# Source code for texar.tf.data.data.data_iterators

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various data iterator classes.
"""

import tensorflow as tf

from texar.tf.data.data.data_base import DataBase
from texar.tf.utils.variables import get_unique_named_variable_scope

__all__ = [
    "DataIteratorBase",
    "DataIterator",
    "TrainTestDataIterator",
    "FeedableDataIterator",
    "TrainTestFeedableDataIterator"
]


class DataIteratorBase(object):
    """Base class for all data iterator classes to inherit.

    A data iterator is a wrapper of :tf_main:`tf.data.Iterator <data/Iterator>`,
    and can switch between and iterate through **multiple** datasets.

    Args:
        datasets: Datasets to iterate through. This can be:

            - A single instance of :tf_main:`tf.data.Dataset <data/Dataset>` \
            or instance of subclass of :class:`~texar.tf.data.DataBase`.
            - A `dict` that maps dataset name to \
            instance of :tf_main:`tf.data.Dataset <data/Dataset>` or \
            subclass of :class:`~texar.tf.data.DataBase`.
            - A `list` of instances of subclasses of \
            :class:`texar.tf.data.DataBase`. The name of instances \
            (:attr:`texar.tf.data.DataBase.name`) must be unique.

    Raises:
        ValueError: If `datasets` is empty, is a list containing
            non-`DataBase` elements, or contains duplicate dataset names.
    """

    def __init__(self, datasets):
        self._default_dataset_name = 'data'
        # Normalize every accepted input form into a `{name: dataset}` dict.
        if isinstance(datasets, (tf.data.Dataset, DataBase)):
            datasets = {self._default_dataset_name: datasets}
        elif isinstance(datasets, (list, tuple)):
            if any(not isinstance(d, DataBase) for d in datasets):
                raise ValueError("`datasets` must be a non-empty list of "
                                 "`tx.data.DataBase` instances.")
            num_datasets = len(datasets)
            datasets = {d.name: d for d in datasets}
            # Duplicate names collapse dict entries; detect via size shrink.
            if len(datasets) < num_datasets:
                raise ValueError("Names of datasets must be unique.")

        # Unwrap DataBase instances to their underlying tf.data.Dataset so
        # subclasses only ever deal with raw datasets.
        _datasets = {}
        for k, v in datasets.items():
            _datasets[k] = v if isinstance(v, tf.data.Dataset) else v.dataset
        self._datasets = _datasets

        if len(self._datasets) <= 0:
            raise ValueError("`datasets` must not be empty.")

    @property
    def num_datasets(self):
        """Number of datasets.
        """
        return len(self._datasets)

    @property
    def dataset_names(self):
        """A list of dataset names.
        """
        return list(self._datasets.keys())
class DataIterator(DataIteratorBase):
    """Data iterator that switches and iterates through multiple datasets.

    This is a wrapper of TF reinitializable :tf_main:`iterator <data/Iterator>`.

    Args:
        datasets: Datasets to iterate through. This can be:

            - A single instance of :tf_main:`tf.data.Dataset <data/Dataset>` \
            or instance of subclass of :class:`~texar.tf.data.DataBase`.
            - A `dict` that maps dataset name to \
            instance of :tf_main:`tf.data.Dataset <data/Dataset>` or \
            subclass of :class:`~texar.tf.data.DataBase`.
            - A `list` of instances of subclasses of \
            :class:`texar.tf.data.DataBase`. The name of instances \
            (:attr:`texar.tf.data.DataBase.name`) must be unique.

    Example:

        .. code-block:: python

            train_data = MonoTextData(hparams_train)
            test_data = MonoTextData(hparams_test)
            iterator = DataIterator({'train': train_data, 'test': test_data})
            batch = iterator.get_next()

            sess = tf.Session()

            for _ in range(200):  # Run 200 epochs of train/test
                # Starts iterating through training data from the beginning
                iterator.switch_to_dataset(sess, 'train')
                while True:
                    try:
                        train_batch_ = sess.run(batch)
                    except tf.errors.OutOfRangeError:
                        print("End of training epoch.")
                        break
                # Starts iterating through test data from the beginning
                iterator.switch_to_dataset(sess, 'test')
                while True:
                    try:
                        test_batch_ = sess.run(batch)
                    except tf.errors.OutOfRangeError:
                        print("End of test epoch.")
                        break
    """

    def __init__(self, datasets):
        DataIteratorBase.__init__(self, datasets)

        self._variable_scope = get_unique_named_variable_scope('data_iterator')
        with tf.variable_scope(self._variable_scope):
            # A reinitializable iterator is built from the structure of one
            # dataset; all datasets are assumed to share output types and
            # shapes. Sorting makes the choice deterministic.
            first_dataset = self._datasets[sorted(self.dataset_names)[0]]
            self._iterator = tf.data.Iterator.from_structure(
                first_dataset.output_types,
                first_dataset.output_shapes)
            # One initializer op per dataset; running an op re-points the
            # shared iterator at the beginning of that dataset.
            self._iterator_init_ops = {
                name: self._iterator.make_initializer(d)
                for name, d in self._datasets.items()
            }

    def switch_to_dataset(self, sess, dataset_name=None):
        """Re-initializes the iterator of a given dataset and starts iterating
        over the dataset (from the beginning).

        Args:
            sess: The current tf session.
            dataset_name (optional): Name of the dataset. If not provided,
                there must be only one Dataset.

        Raises:
            ValueError: If `dataset_name` is omitted while multiple datasets
                exist, or names an unknown dataset.
        """
        if dataset_name is None:
            if self.num_datasets > 1:
                raise ValueError("`dataset_name` is required if there are "
                                 "more than one datasets.")
            dataset_name = next(iter(self._datasets))
        if dataset_name not in self._datasets:
            # Format the name into the message instead of passing it as a
            # second exception argument (which rendered as a tuple).
            raise ValueError("Dataset not found: {}".format(dataset_name))
        sess.run(self._iterator_init_ops[dataset_name])

    def get_next(self):
        """Returns the next element of the activated dataset.
        """
        return self._iterator.get_next()
class TrainTestDataIterator(DataIterator):
    """Data iterator that alternates between train, val, and test datasets.

    :attr:`train`, :attr:`val`, and :attr:`test` can be instance of
    either :tf_main:`tf.data.Dataset <data/Dataset>` or subclass of
    :class:`~texar.tf.data.DataBase`. At least one of them must be provided.

    This is a wrapper of :class:`~texar.tf.data.DataIterator`.

    Args:
        train (optional): Training data.
        val (optional): Validation data.
        test (optional): Test data.

    Raises:
        ValueError: If none of `train`, `val`, and `test` is provided.

    Example:

        .. code-block:: python

            train_data = MonoTextData(hparams_train)
            val_data = MonoTextData(hparams_val)
            iterator = TrainTestDataIterator(train=train_data, val=val_data)
            batch = iterator.get_next()

            sess = tf.Session()

            for _ in range(200):  # Run 200 epochs of train/val
                # Starts iterating through training data from the beginning
                iterator.switch_to_train_data(sess)
                while True:
                    try:
                        train_batch_ = sess.run(batch)
                    except tf.errors.OutOfRangeError:
                        print("End of training epoch.")
                        break
                # Starts iterating through val data from the beginning
                iterator.switch_to_val_data(sess)
                while True:
                    try:
                        val_batch_ = sess.run(batch)
                    except tf.errors.OutOfRangeError:
                        print("End of val epoch.")
                        break
    """

    def __init__(self, train=None, val=None, test=None):
        # Collect only the splits that were actually provided, under
        # fixed canonical names.
        dataset_dict = {}
        self._train_name = 'train'
        self._val_name = 'val'
        self._test_name = 'test'
        if train is not None:
            dataset_dict[self._train_name] = train
        if val is not None:
            dataset_dict[self._val_name] = val
        if test is not None:
            dataset_dict[self._test_name] = test
        if len(dataset_dict) == 0:
            raise ValueError("At least one of `train`, `val`, and `test` "
                             "must be provided.")

        DataIterator.__init__(self, dataset_dict)

    def switch_to_train_data(self, sess):
        """Starts to iterate through training data (from the beginning).

        Args:
            sess: The current tf session.

        Raises:
            ValueError: If no training data was provided.
        """
        if self._train_name not in self._datasets:
            raise ValueError("Training data not provided.")
        self.switch_to_dataset(sess, self._train_name)

    def switch_to_val_data(self, sess):
        """Starts to iterate through val data (from the beginning).

        Args:
            sess: The current tf session.

        Raises:
            ValueError: If no validation data was provided.
        """
        if self._val_name not in self._datasets:
            raise ValueError("Val data not provided.")
        self.switch_to_dataset(sess, self._val_name)

    def switch_to_test_data(self, sess):
        """Starts to iterate through test data (from the beginning).

        Args:
            sess: The current tf session.

        Raises:
            ValueError: If no test data was provided.
        """
        if self._test_name not in self._datasets:
            raise ValueError("Test data not provided.")
        self.switch_to_dataset(sess, self._test_name)
class FeedableDataIterator(DataIteratorBase):
    """Data iterator that iterates through **multiple** datasets and switches
    between datasets.

    The iterator can switch to a dataset and resume from where we left off
    last time we visited the dataset. This is a wrapper of TF feedable
    :tf_main:`iterator <data/Iterator>`.

    Args:
        datasets: Datasets to iterate through. This can be:

            - A single instance of :tf_main:`tf.data.Dataset <data/Dataset>` \
            or instance of subclass of :class:`~texar.tf.data.DataBase`.
            - A `dict` that maps dataset name to \
            instance of :tf_main:`tf.data.Dataset <data/Dataset>` or \
            subclass of :class:`~texar.tf.data.DataBase`.
            - A `list` of instances of subclasses of \
            :class:`texar.tf.data.DataBase`. The name of instances \
            (:attr:`texar.tf.data.DataBase.name`) must be unique.

    Example:

        .. code-block:: python

            train_data = MonoTextData(hparams={'num_epochs': 200, ...})
            test_data = MonoTextData(hparams_test)
            iterator = FeedableDataIterator({'train': train_data,
                                             'test': test_data})
            batch = iterator.get_next()

            sess = tf.Session()

            def _eval_epoch():  # Iterate through test data for one epoch
                # Initialize and start from beginning of test data
                iterator.initialize_dataset(sess, 'test')
                while True:
                    try:
                        feed_dict = {
                            # Read from test data
                            iterator.handle: iterator.get_handle(sess, 'test')
                        }
                        test_batch_ = sess.run(batch, feed_dict=feed_dict)
                    except tf.errors.OutOfRangeError:
                        print("End of test epoch.")
                        break

            # Initialize and start from beginning of training data
            iterator.initialize_dataset(sess, 'train')

            step = 0
            while True:
                try:
                    feed_dict = {
                        # Read from training data
                        iterator.handle: iterator.get_handle(sess, 'train')
                    }
                    train_batch_ = sess.run(batch, feed_dict=feed_dict)

                    step += 1
                    if step % 200 == 0:  # Evaluate periodically
                        _eval_epoch()
                except tf.errors.OutOfRangeError:
                    print("End of training.")
                    break
    """

    def __init__(self, datasets):
        DataIteratorBase.__init__(self, datasets)

        self._variable_scope = get_unique_named_variable_scope(
            'feedable_data_iterator')
        with tf.variable_scope(self._variable_scope):
            # Placeholder fed at run time with a string handle that selects
            # which dataset to read from.
            self._handle = tf.placeholder(tf.string, shape=[], name='handle')
            # All datasets must share output types/shapes, so any one of
            # them can define the iterator structure. Sort for determinism.
            first_dataset = self._datasets[sorted(self.dataset_names)[0]]
            self._iterator = tf.data.Iterator.from_string_handle(
                self._handle,
                first_dataset.output_types,
                first_dataset.output_shapes)
            # Each dataset keeps its own iterator so that progress is
            # preserved when switching between datasets.
            self._dataset_iterators = {
                name: dataset.make_initializable_iterator()
                for name, dataset in self._datasets.items()
            }

    def get_handle(self, sess, dataset_name=None):
        """Returns a dataset handle used to feed the :attr:`handle`
        placeholder to fetch data from the dataset.

        Args:
            sess: The current tf session.
            dataset_name (optional): Name of the dataset. If not provided,
                there must be only one Dataset.

        Returns:
            A string handle to be fed to the :attr:`handle` placeholder.

        Raises:
            ValueError: If `dataset_name` is omitted while multiple datasets
                exist, or names an unknown dataset.

        Example:

            .. code-block:: python

                next_element = iterator.get_next()
                train_handle = iterator.get_handle(sess, 'train')
                # Gets the next training element
                ne_ = sess.run(next_element,
                               feed_dict={iterator.handle: train_handle})
        """
        if dataset_name is None:
            if self.num_datasets > 1:
                raise ValueError("`dataset_name` is required if there are "
                                 "more than one datasets.")
            dataset_name = next(iter(self._datasets))
        if dataset_name not in self._datasets:
            # Format the name into the message instead of passing it as a
            # second exception argument (which rendered as a tuple).
            raise ValueError("Dataset not found: {}".format(dataset_name))
        return sess.run(self._dataset_iterators[dataset_name].string_handle())

    def restart_dataset(self, sess, dataset_name=None):
        """Restarts datasets so that next iteration will fetch data from
        the beginning of the datasets.

        Args:
            sess: The current tf session.
            dataset_name (optional): A dataset name or a list of dataset names
                that specifies which dataset(s) to restart. If `None`, all
                datasets are restarted.
        """
        self.initialize_dataset(sess, dataset_name)

    def initialize_dataset(self, sess, dataset_name=None):
        """Initializes datasets. A dataset must be initialized before being
        used.

        Args:
            sess: The current tf session.
            dataset_name (optional): A dataset name or a list of dataset names
                that specifies which dataset(s) to initialize. If `None`, all
                datasets are initialized.
        """
        if dataset_name is None:
            dataset_name = self.dataset_names
        if not isinstance(dataset_name, (tuple, list)):
            dataset_name = [dataset_name]
        for name in dataset_name:
            sess.run(self._dataset_iterators[name].initializer)

    def get_next(self):
        """Returns the next element of the activated dataset.
        """
        return self._iterator.get_next()

    @property
    def handle(self):
        """The handle placeholder that can be fed with a dataset handle to
        fetch data from the dataset.
        """
        return self._handle
class TrainTestFeedableDataIterator(FeedableDataIterator):
    """Feedable data iterator that alternates between train, val, and test
    datasets.

    This is a wrapper of :class:`~texar.tf.data.FeedableDataIterator`.
    The iterator can switch to a dataset and resume from where it was
    left off when it was visited last time.

    :attr:`train`, :attr:`val`, and :attr:`test` can be instance of
    either :tf_main:`tf.data.Dataset <data/Dataset>` or subclass of
    :class:`~texar.tf.data.DataBase`. At least one of them must be provided.

    Args:
        train (optional): Training data.
        val (optional): Validation data.
        test (optional): Test data.

    Raises:
        ValueError: If none of `train`, `val`, and `test` is provided.

    Example:

        .. code-block:: python

            train_data = MonoTextData(hparams={'num_epochs': 200, ...})
            test_data = MonoTextData(hparams_test)
            iterator = TrainTestFeedableDataIterator(train=train_data,
                                                     test=test_data)
            batch = iterator.get_next()

            sess = tf.Session()

            def _eval_epoch():  # Iterate through test data for one epoch
                # Initialize and start from beginning of test data
                iterator.initialize_test_dataset(sess)
                while True:
                    try:
                        feed_dict = {
                            # Read from test data
                            iterator.handle: iterator.get_test_handle(sess)
                        }
                        test_batch_ = sess.run(batch, feed_dict=feed_dict)
                    except tf.errors.OutOfRangeError:
                        print("End of test epoch.")
                        break

            # Initialize and start from beginning of training data
            iterator.initialize_train_dataset(sess)

            step = 0
            while True:
                try:
                    feed_dict = {
                        # Read from training data
                        iterator.handle: iterator.get_train_handle(sess)
                    }
                    train_batch_ = sess.run(batch, feed_dict=feed_dict)

                    step += 1
                    if step % 200 == 0:  # Evaluate periodically
                        _eval_epoch()
                except tf.errors.OutOfRangeError:
                    print("End of training.")
                    break
    """

    def __init__(self, train=None, val=None, test=None):
        # Collect only the splits that were actually provided, under
        # fixed canonical names.
        dataset_dict = {}
        self._train_name = 'train'
        self._val_name = 'val'
        self._test_name = 'test'
        if train is not None:
            dataset_dict[self._train_name] = train
        if val is not None:
            dataset_dict[self._val_name] = val
        if test is not None:
            dataset_dict[self._test_name] = test
        if len(dataset_dict) == 0:
            raise ValueError("At least one of `train`, `val`, and `test` "
                             "must be provided.")

        FeedableDataIterator.__init__(self, dataset_dict)

    def get_train_handle(self, sess):
        """Returns the handle of the training dataset. The handle can be used
        to feed the :attr:`handle` placeholder to fetch training data.

        Args:
            sess: The current tf session.

        Returns:
            A string handle to be fed to the :attr:`handle` placeholder.

        Raises:
            ValueError: If no training data was provided.

        Example:

            .. code-block:: python

                next_element = iterator.get_next()
                train_handle = iterator.get_train_handle(sess)
                # Gets the next training element
                ne_ = sess.run(next_element,
                               feed_dict={iterator.handle: train_handle})
        """
        if self._train_name not in self._datasets:
            raise ValueError("Training data not provided.")
        return self.get_handle(sess, self._train_name)

    def get_val_handle(self, sess):
        """Returns the handle of the validation dataset. The handle can be
        used to feed the :attr:`handle` placeholder to fetch validation data.

        Args:
            sess: The current tf session.

        Returns:
            A string handle to be fed to the :attr:`handle` placeholder.

        Raises:
            ValueError: If no validation data was provided.
        """
        if self._val_name not in self._datasets:
            raise ValueError("Val data not provided.")
        return self.get_handle(sess, self._val_name)

    def get_test_handle(self, sess):
        """Returns the handle of the test dataset. The handle can be used
        to feed the :attr:`handle` placeholder to fetch test data.

        Args:
            sess: The current tf session.

        Returns:
            A string handle to be fed to the :attr:`handle` placeholder.

        Raises:
            ValueError: If no test data was provided.
        """
        if self._test_name not in self._datasets:
            raise ValueError("Test data not provided.")
        return self.get_handle(sess, self._test_name)

    def restart_train_dataset(self, sess):
        """Restarts the training dataset so that next iteration will fetch
        data from the beginning of the training dataset.

        Args:
            sess: The current tf session.

        Raises:
            ValueError: If no training data was provided.
        """
        if self._train_name not in self._datasets:
            raise ValueError("Training data not provided.")
        self.restart_dataset(sess, self._train_name)

    def restart_val_dataset(self, sess):
        """Restarts the validation dataset so that next iteration will fetch
        data from the beginning of the validation dataset.

        Args:
            sess: The current tf session.

        Raises:
            ValueError: If no validation data was provided.
        """
        if self._val_name not in self._datasets:
            raise ValueError("Val data not provided.")
        self.restart_dataset(sess, self._val_name)

    def restart_test_dataset(self, sess):
        """Restarts the test dataset so that next iteration will fetch
        data from the beginning of the test dataset.

        Args:
            sess: The current tf session.

        Raises:
            ValueError: If no test data was provided.
        """
        if self._test_name not in self._datasets:
            raise ValueError("Test data not provided.")
        self.restart_dataset(sess, self._test_name)