Source code for texar.tf.data.data.multi_aligned_data

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data consisting of multiple aligned parts.
"""

import copy

import tensorflow as tf

from texar.tf.hyperparams import HParams
from texar.tf.utils import utils
from texar.tf.utils.dtypes import is_str, is_callable
from texar.tf.data.data.text_data_base import TextDataBase
from texar.tf.data.data.scalar_data import ScalarData, \
    _default_scalar_dataset_hparams
from texar.tf.data.data.tfrecord_data import TFRecordData, \
    _default_tfrecord_dataset_hparams
from texar.tf.data.data.mono_text_data import MonoTextData, \
    _default_mono_text_dataset_hparams
from texar.tf.data.data_utils import count_file_lines
from texar.tf.data.data import dataset_utils as dsutils
from texar.tf.data.vocabulary import Vocab, SpecialTokens
from texar.tf.data.embedding import Embedding

# pylint: disable=invalid-name, arguments-differ
# pylint: disable=protected-access, too-many-instance-attributes

__all__ = [
    "_default_dataset_hparams",
    "MultiAlignedData"
]


class _DataTypes(object):  # pylint: disable=no-init, too-few-public-methods
    """Enumeration of data types.
    """
    TEXT = "text"
    INT = "int"
    FLOAT = "float"
    TF_RECORD = "tf_record"


def _is_text_data(data_type):
    return data_type == _DataTypes.TEXT


def _is_scalar_data(data_type):
    return data_type == _DataTypes.INT or data_type == _DataTypes.FLOAT


def _is_tfrecord_data(data_type):
    return data_type == _DataTypes.TF_RECORD


def _default_dataset_hparams(data_type=None):
    """Returns hyperparameters of a dataset with default values.

    See :meth:`texar.tf.data.MultiAlignedData.default_hparams` for details.
    """
    if not data_type or _is_text_data(data_type):
        hparams = _default_mono_text_dataset_hparams()
        hparams.update({
            "data_type": _DataTypes.TEXT,
            "vocab_share_with": None,
            "embedding_init_share_with": None,
            "processing_share_with": None,
        })
    elif _is_scalar_data(data_type):
        hparams = _default_scalar_dataset_hparams()
    elif _is_tfrecord_data(data_type):
        hparams = _default_tfrecord_dataset_hparams()
        hparams.update({
            "data_type": _DataTypes.TF_RECORD,
        })
    else:
        raise ValueError("Unknown data type: %s" % data_type)
    return hparams
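
# A minimal sketch (not part of the original module) exercising the
# per-type defaults above; the assertions reflect the branches of
# `_default_dataset_hparams`.
def _demo_default_dataset_hparams():
    text_defaults = _default_dataset_hparams()  # same as data_type="text"
    scalar_defaults = _default_dataset_hparams(_DataTypes.INT)
    record_defaults = _default_dataset_hparams(_DataTypes.TF_RECORD)

    # Only text datasets carry the "*_share_with" fields.
    assert "vocab_share_with" in text_defaults
    assert "vocab_share_with" not in scalar_defaults
    assert record_defaults["data_type"] == _DataTypes.TF_RECORD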


class MultiAlignedData(TextDataBase):
    """Data consisting of multiple aligned parts.

    Args:
        hparams (dict): Hyperparameters. See :meth:`default_hparams` for the
            defaults.

    The processor can read any number of parallel fields as specified in
    the "datasets" list of :attr:`hparams`, and result in a TF Dataset whose
    element is a python `dict` containing data fields from each of the
    specified datasets. Fields from a text dataset or TFRecord dataset have
    names prefixed by its "data_name". Fields from a scalar dataset are
    named by its "data_name".

    Example:

        .. code-block:: python

            hparams={
                'datasets': [
                    {'files': 'a.txt', 'vocab_file': 'v.a', 'data_name': 'x'},
                    {'files': 'b.txt', 'vocab_file': 'v.b', 'data_name': 'y'},
                    {'files': 'c.txt', 'data_type': 'int', 'data_name': 'z'}
                ],
                'batch_size': 1
            }
            data = MultiAlignedData(hparams)
            iterator = DataIterator(data)
            batch = iterator.get_next()

            iterator.switch_to_dataset(sess)  # initializes the dataset
            batch_ = sess.run(batch)
            # batch_ == {
            #     'x_text': [['<BOS>', 'x', 'sequence', '<EOS>']],
            #     'x_text_ids': [['1', '5', '10', '2']],
            #     'x_length': [4],
            #     'y_text': [['<BOS>', 'y', 'sequence', '1', '<EOS>']],
            #     'y_text_ids': [['1', '6', '10', '20', '2']],
            #     'y_length': [5],
            #     'z': [1000],
            # }

            ...

            hparams={
                'datasets': [
                    {'files': 'd.txt', 'vocab_file': 'v.d', 'data_name': 'm'},
                    {
                        'files': 'd.tfrecord',
                        'data_type': 'tf_record',
                        'feature_original_types': {
                            'image': ['tf.string', 'FixedLenFeature']
                        },
                        'image_options': {
                            'image_feature_name': 'image',
                            'resize_height': 512,
                            'resize_width': 512,
                        },
                        'data_name': 't',
                    }
                ],
                'batch_size': 1
            }
            data = MultiAlignedData(hparams)
            iterator = DataIterator(data)
            batch = iterator.get_next()

            iterator.switch_to_dataset(sess)  # initializes the dataset
            batch_ = sess.run(batch)
            # batch_ == {
            #     'm_text': [['<BOS>', 'NewYork', 'City', 'Map', '<EOS>']],
            #     'm_text_ids': [['1', '100', '80', '65', '2']],
            #     'm_length': [5],
            #
            #     # 't_image' is a list of a `numpy.ndarray` image in this
            #     # example; its width and height both equal 512.
            #     't_image': [...]
            # }
    """

    def __init__(self, hparams):
        TextDataBase.__init__(self, hparams)
        # Defaultizes the hyperparameters of each dataset
        datasets_hparams = self._hparams.datasets
        defaultized_datasets_hparams = []
        for ds_hpms in datasets_hparams:
            data_type = ds_hpms.get("data_type", None)
            defaultized_ds_hpms = HParams(ds_hpms,
                                          _default_dataset_hparams(data_type))
            defaultized_datasets_hparams.append(defaultized_ds_hpms)
        self._hparams.datasets = defaultized_datasets_hparams

        with tf.name_scope(self.name, self.default_hparams()["name"]):
            self._make_data()
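
    # A hedged illustration of the defaultization in `__init__`: a
    # user-supplied dataset dict is merged with the type-specific defaults,
    # with user values taking precedence, e.g.
    #
    #     ds = HParams({'files': 'c.txt', 'data_type': 'int',
    #                   'data_name': 'z'},
    #                  _default_dataset_hparams('int'))
    #     # ds.data_name == 'z'; the remaining scalar-dataset fields take
    #     # their default values.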
    @staticmethod
    def default_hparams():
        """Returns a dictionary of default hyperparameters.

        .. code-block:: python

            {
                # (1) Hyperparams specific to text dataset
                "datasets": [],
                # (2) General hyperparams
                "num_epochs": 1,
                "batch_size": 64,
                "allow_smaller_final_batch": True,
                "shuffle": True,
                "shuffle_buffer_size": None,
                "shard_and_shuffle": False,
                "num_parallel_calls": 1,
                "prefetch_buffer_size": 0,
                "max_dataset_size": -1,
                "seed": None,
                "name": "multi_aligned_data",
            }

        Here:

        1. "datasets" is a list of `dict`, each of which specifies a
           dataset of text, scalar, or TFRecord type. The
           :attr:`"data_name"` field of each dataset is used as the name
           prefix of the data fields from the respective dataset. The
           :attr:`"data_name"` fields of different datasets must be unique.

           - For a scalar dataset, the allowed hyperparameters and default
             values are the same as the "dataset" field of
             :meth:`texar.tf.data.ScalarData.default_hparams`. Note that
             :attr:`"data_type"` must be explicitly specified
             (either "int" or "float").
           - For a TFRecord dataset, the allowed hyperparameters and
             default values are the same as the "dataset" field of
             :meth:`texar.tf.data.TFRecordData.default_hparams`. Note that
             :attr:`"data_type"` must be explicitly specified
             ("tf_record").
           - For a text dataset, the allowed hyperparameters and default
             values are the same as the "dataset" field of
             :meth:`texar.tf.data.MonoTextData.default_hparams`, with
             several extra hyperparameters:

             "data_type": str
                 The type of the dataset, one of {"text", "int", "float",
                 "tf_record"}. If set to "int" or "float", the dataset is
                 considered to be a scalar dataset. If set to "tf_record",
                 the dataset is considered to be a TFRecord dataset. If not
                 specified or set to "text", the dataset is considered to
                 be a text dataset.

             "vocab_share_with": int, optional
                 Share the vocabulary of a preceding text dataset with the
                 specified index in the list (starting from 0). The
                 specified dataset must be a text dataset, and must have an
                 index smaller than the current dataset. If specified, the
                 vocab file of the current dataset is ignored. Default is
                 `None` which disables vocab sharing.

             "embedding_init_share_with": int, optional
                 Share the embedding initial value of a preceding text
                 dataset with the specified index in the list (starting
                 from 0). The specified dataset must be a text dataset, and
                 must have an index smaller than the current dataset. If
                 specified, the :attr:`"embedding_init"` field of the
                 current dataset is ignored. Default is `None` which
                 disables initial value sharing.

             "processing_share_with": int, optional
                 Share the processing configurations of a preceding text
                 dataset with the specified index in the list (starting
                 from 0). The specified dataset must be a text dataset, and
                 must have an index smaller than the current dataset. If
                 specified, the relevant fields of the current dataset are
                 ignored, including "delimiter", "bos_token", "eos_token",
                 and "other_transformations". Default is `None` which
                 disables processing sharing.

        2. For the **general** hyperparameters, see
           :meth:`texar.tf.data.DataBase.default_hparams` for details.
        """
        hparams = TextDataBase.default_hparams()
        hparams["name"] = "multi_aligned_data"
        hparams["datasets"] = []
        return hparams
    @staticmethod
    def _raise_sharing_error(err_data, shr_data, hparam_name):
        raise ValueError(
            "Must only share specifications with a preceding dataset. "
            "Dataset %d has '%s=%d'" % (err_data, hparam_name, shr_data))

    @staticmethod
    def make_vocab(hparams):
        """Makes a list of vocabs based on the hparams.

        Args:
            hparams (list): A list of dataset hyperparameters.

        Returns:
            A list of :class:`texar.tf.data.Vocab` instances. Some instances
            may be the same object if they are set to be shared and have
            the same other configurations.
        """
        if not isinstance(hparams, (list, tuple)):
            hparams = [hparams]

        vocabs = []
        for i, hparams_i in enumerate(hparams):
            if not _is_text_data(hparams_i["data_type"]):
                vocabs.append(None)
                continue

            proc_shr = hparams_i["processing_share_with"]
            if proc_shr is not None:
                bos_token = hparams[proc_shr]["bos_token"]
                eos_token = hparams[proc_shr]["eos_token"]
            else:
                bos_token = hparams_i["bos_token"]
                eos_token = hparams_i["eos_token"]
            bos_token = utils.default_str(
                bos_token, SpecialTokens.BOS)
            eos_token = utils.default_str(
                eos_token, SpecialTokens.EOS)

            vocab_shr = hparams_i["vocab_share_with"]
            if vocab_shr is not None:
                if vocab_shr >= i:
                    MultiAlignedData._raise_sharing_error(
                        i, vocab_shr, "vocab_share_with")
                if not vocabs[vocab_shr]:
                    raise ValueError("Cannot share vocab with dataset %d "
                                     "which does not have a vocab."
                                     % vocab_shr)
                if bos_token == vocabs[vocab_shr].bos_token and \
                        eos_token == vocabs[vocab_shr].eos_token:
                    vocab = vocabs[vocab_shr]
                else:
                    vocab = Vocab(hparams[vocab_shr]["vocab_file"],
                                  bos_token=bos_token,
                                  eos_token=eos_token)
            else:
                vocab = Vocab(hparams_i["vocab_file"],
                              bos_token=bos_token,
                              eos_token=eos_token)
            vocabs.append(vocab)

        return vocabs

    @staticmethod
    def make_embedding(hparams, vocabs):
        """Optionally loads embeddings from files (if provided), and
        returns respective :class:`texar.tf.data.Embedding` instances.
        """
        if not isinstance(hparams, (list, tuple)):
            hparams = [hparams]

        embs = []
        for i, hparams_i in enumerate(hparams):
            if not _is_text_data(hparams_i["data_type"]):
                embs.append(None)
                continue

            emb_shr = hparams_i["embedding_init_share_with"]
            if emb_shr is not None:
                if emb_shr >= i:
                    MultiAlignedData._raise_sharing_error(
                        i, emb_shr, "embedding_init_share_with")
                if not embs[emb_shr]:
                    raise ValueError("Cannot share embedding with dataset %d "
                                     "which does not have an embedding."
                                     % emb_shr)
                if emb_shr != hparams_i["vocab_share_with"]:
                    raise ValueError("'embedding_init_share_with' != "
                                     "'vocab_share_with'. embedding_init "
                                     "can be shared only when vocab is "
                                     "shared.")
                emb = embs[emb_shr]
            else:
                emb = None
                emb_file = hparams_i["embedding_init"]["file"]
                if emb_file and emb_file != "":
                    emb = Embedding(vocabs[i].token_to_id_map_py,
                                    hparams_i["embedding_init"])
            embs.append(emb)

        return embs
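
    # A hedged sketch of `make_vocab` with vocab sharing; it assumes fully
    # defaultized dataset hparams (as produced in `__init__`), and the file
    # names are hypothetical:
    #
    #     vocabs = MultiAlignedData.make_vocab([
    #         HParams({'files': 'src.txt', 'vocab_file': 'v.txt',
    #                  'data_name': 'src'}, _default_dataset_hparams()),
    #         HParams({'files': 'tgt.txt', 'vocab_share_with': 0,
    #                  'data_name': 'tgt'}, _default_dataset_hparams()),
    #     ])
    #     # vocabs[1] is vocabs[0] whenever the BOS/EOS tokens also match.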
    def _make_dataset(self):
        datasets = []
        for hparams_i in self._hparams.datasets:
            dtype = hparams_i.data_type
            if _is_text_data(dtype) or _is_scalar_data(dtype):
                dataset = tf.data.TextLineDataset(
                    hparams_i.files,
                    compression_type=hparams_i.compression_type)
                datasets.append(dataset)
            elif _is_tfrecord_data(dtype):
                dataset = tf.data.TFRecordDataset(filenames=hparams_i.files)
                num_shards = hparams_i.num_shards
                shard_id = hparams_i.shard_id
                if num_shards is not None and shard_id is not None:
                    dataset = dataset.shard(num_shards, shard_id)
                datasets.append(dataset)
            else:
                raise ValueError("Unknown data type: %s" % dtype)
        return tf.data.Dataset.zip(tuple(datasets))

    # @staticmethod
    # def _get_name_prefix(dataset_hparams):
    #     def _dtype_conflict(dtype_1, dtype_2):
    #         conflict = ((dtype_1 == dtype_2) or
    #                     (dtype_1 in {_DataTypes.INT, _DataTypes.FLOAT} and
    #                      dtype_2 in {_DataTypes.INT, _DataTypes.FLOAT}))
    #         return conflict
    #
    #     name_prefix = [hpms["data_name"] for hpms in dataset_hparams]
    #     name_prefix_dict = {}
    #     for i, np in enumerate(name_prefix):
    #         ids = name_prefix_dict.get(np, [])
    #         for j in ids:
    #             if _dtype_conflict(dataset_hparams[j]["data_type"],
    #                                dataset_hparams[i]["data_type"]):
    #                 raise ValueError(
    #                     "'data_name' of the datasets with compatible "
    #                     "data_types cannot be the same: %d-th dataset and "
    #                     "%d-th dataset have the same name '%s'" %
    #                     (i, j, name_prefix[i]))
    #         ids.append(i)
    #         name_prefix_dict[np] = ids
    #     return name_prefix

    @staticmethod
    def _get_name_prefix(dataset_hparams):
        name_prefix = [hpms["data_name"] for hpms in dataset_hparams]
        for i in range(1, len(name_prefix)):
            # Checks for duplicates against all preceding data names
            if name_prefix[i] in name_prefix[:i]:
                raise ValueError("Data name duplicated: %s" % name_prefix[i])
        return name_prefix

    @staticmethod
    def _make_processor(dataset_hparams, data_spec, name_prefix):
        processors = []
        for i, hparams_i in enumerate(dataset_hparams):
            data_spec_i = data_spec.get_ith_data_spec(i)

            data_type = hparams_i["data_type"]
            if _is_text_data(data_type):
                tgt_proc_hparams = hparams_i
                proc_shr = hparams_i["processing_share_with"]
                if proc_shr is not None:
                    tgt_proc_hparams = copy.copy(dataset_hparams[proc_shr])
                    try:
                        tgt_proc_hparams["variable_utterance"] = \
                            hparams_i["variable_utterance"]
                    except TypeError:
                        tgt_proc_hparams.variable_utterance = \
                            hparams_i["variable_utterance"]

                processor, data_spec_i = MonoTextData._make_processor(
                    tgt_proc_hparams, data_spec_i)
            elif _is_scalar_data(data_type):
                processor, data_spec_i = ScalarData._make_processor(
                    hparams_i, data_spec_i, name_prefix='')
            elif _is_tfrecord_data(data_type):
                processor, data_spec_i = TFRecordData._make_processor(
                    hparams_i, data_spec_i, name_prefix='')
            else:
                raise ValueError("Unsupported data type: %s" % data_type)

            processors.append(processor)
            data_spec.set_ith_data_spec(i, data_spec_i, len(dataset_hparams))

        tran_fn = dsutils.make_combined_transformation(
            processors, name_prefix=name_prefix)

        data_spec.add_spec(name_prefix=name_prefix)

        return tran_fn, data_spec

    @staticmethod
    def _make_length_filter(dataset_hparams, length_name, decoder):
        filter_fns = []
        for i, hpms in enumerate(dataset_hparams):
            if not _is_text_data(hpms["data_type"]):
                filter_fn = None
            else:
                filter_fn = MonoTextData._make_length_filter(
                    hpms, length_name[i], decoder[i])
            filter_fns.append(filter_fn)
        combined_filter_fn = dsutils._make_combined_filter_fn(filter_fns)
        return combined_filter_fn

    def _process_dataset(self, dataset, hparams, data_spec):
        name_prefix = self._get_name_prefix(hparams["datasets"])
        # pylint: disable=attribute-defined-outside-init
        self._name_to_id = {v: k for k, v in enumerate(name_prefix)}

        tran_fn, data_spec = self._make_processor(
            hparams["datasets"], data_spec, name_prefix)

        num_parallel_calls = hparams["num_parallel_calls"]
        dataset = dataset.map(
            lambda *args: tran_fn(dsutils.maybe_tuple(args)),
            num_parallel_calls=num_parallel_calls)

        # Filters by length
        def _get_length_name(i):
            if not _is_text_data(hparams["datasets"][i]["data_type"]):
                return None
            name = dsutils._connect_name(
                data_spec.name_prefix[i],
                data_spec.decoder[i].length_tensor_name)
            return name
        filter_fn = self._make_length_filter(
            hparams["datasets"],
            [_get_length_name(i) for i in range(len(hparams["datasets"]))],
            data_spec.decoder)
        if filter_fn:
            dataset = dataset.filter(filter_fn)

        # Truncates data count
        dataset = dataset.take(hparams["max_dataset_size"])

        return dataset, data_spec

    def _make_bucket_length_fn(self):
        length_fn = self._hparams.bucket_length_fn
        if not length_fn:
            # Uses the length of the first text dataset
            text_idx = None
            for i, hparams_i in enumerate(self._hparams.datasets):
                if _is_text_data(hparams_i["data_type"]):
                    text_idx = i
                    break
            if text_idx is None:
                raise ValueError("Undefined `length_fn`.")
            length_fn = lambda x: x[self.length_name(text_idx)]
        elif not is_callable(length_fn):
            # pylint: disable=redefined-variable-type
            length_fn = utils.get_function(length_fn, ["texar.tf.custom"])
        return length_fn

    def _make_padded_shapes(self, dataset, decoders):
        padded_shapes = dataset.output_shapes
        for i, hparams_i in enumerate(self._hparams.datasets):
            if not _is_text_data(hparams_i["data_type"]):
                continue
            if not hparams_i["pad_to_max_seq_length"]:
                continue
            text_and_id_shapes = MonoTextData._make_padded_text_and_id_shapes(
                dataset, hparams_i, decoders[i],
                self.text_name(i), self.text_id_name(i))

            padded_shapes.update(text_and_id_shapes)

        return padded_shapes

    def _make_data(self):
        self._vocab = self.make_vocab(self._hparams.datasets)
        self._embedding = self.make_embedding(self._hparams.datasets,
                                              self._vocab)

        # Create dataset
        dataset = self._make_dataset()
        dataset, dataset_size = self._shuffle_dataset(
            dataset, self._hparams, self._hparams.datasets[0].files)
        self._dataset_size = dataset_size

        # Processing
        data_spec = dsutils._DataSpec(dataset=dataset,
                                      dataset_size=self._dataset_size,
                                      vocab=self._vocab,
                                      embedding=self._embedding)
        dataset, data_spec = self._process_dataset(
            dataset, self._hparams, data_spec)
        self._data_spec = data_spec
        self._decoder = data_spec.decoder

        # Batching
        length_fn = self._make_bucket_length_fn()
        padded_shapes = self._make_padded_shapes(dataset, self._decoder)
        dataset = self._make_batch(
            dataset, self._hparams, length_fn, padded_shapes)

        # Prefetching
        if self._hparams.prefetch_buffer_size > 0:
            dataset = dataset.prefetch(self._hparams.prefetch_buffer_size)

        self._dataset = dataset
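
    # A hedged configuration sketch for bucketing (see
    # `_make_bucket_length_fn`): a custom `bucket_length_fn` can be given
    # in the general hparams; otherwise the length field of the first text
    # dataset is used. The field name 'x_length' below is hypothetical.
    #
    #     hparams = {
    #         'datasets': [...],
    #         'bucket_boundaries': [10, 20],
    #         'bucket_length_fn': lambda x: x['x_length'],
    #     }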
    def list_items(self):
        """Returns the list of item names that the data can produce.

        Returns:
            A list of strings.
        """
        return list(self._dataset.output_types.keys())
    @property
    def dataset(self):
        """The dataset.
        """
        return self._dataset
    def dataset_size(self):
        """Returns the number of data instances in the dataset.

        Note that this is the total data count in the raw files, before
        any filtering and truncation.
        """
        if not self._dataset_size:
            # pylint: disable=attribute-defined-outside-init
            self._dataset_size = count_file_lines(
                self._hparams.datasets[0].files)
        return self._dataset_size
    def _maybe_name_to_id(self, name_or_id):
        if is_str(name_or_id):
            if name_or_id not in self._name_to_id:
                raise ValueError("Unknown data name: {}".format(name_or_id))
            return self._name_to_id[name_or_id]
        return name_or_id
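
    # For instance, with datasets named ['x', 'y', 'z'] (a hypothetical
    # configuration), `_maybe_name_to_id('y')` returns 1, while an int
    # argument such as `_maybe_name_to_id(1)` is returned unchanged.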
    def vocab(self, name_or_id):
        """Returns the :class:`~texar.tf.data.Vocab` of the text dataset by
        its name or id. `None` if the dataset is not of text type.

        Args:
            name_or_id (str or int): Data name or the index of text dataset.
        """
        i = self._maybe_name_to_id(name_or_id)
        return self._vocab[i]
    def embedding_init_value(self, name_or_id):
        """Returns the `Tensor` of the embedding initial value of the
        dataset by its name or id. `None` if the dataset is not of text
        type.
        """
        i = self._maybe_name_to_id(name_or_id)
        return self._embedding[i]
    def text_name(self, name_or_id):
        """The name of the text tensor of the text dataset by its name or
        id. If the dataset is not of text type, returns `None`.
        """
        i = self._maybe_name_to_id(name_or_id)
        if not _is_text_data(self._hparams.datasets[i]["data_type"]):
            return None
        name = dsutils._connect_name(
            self._data_spec.name_prefix[i],
            self._data_spec.decoder[i].text_tensor_name)
        return name
    def length_name(self, name_or_id):
        """The name of the length tensor of the text dataset by its name or
        id. If the dataset is not of text type, returns `None`.
        """
        i = self._maybe_name_to_id(name_or_id)
        if not _is_text_data(self._hparams.datasets[i]["data_type"]):
            return None
        name = dsutils._connect_name(
            self._data_spec.name_prefix[i],
            self._data_spec.decoder[i].length_tensor_name)
        return name
    def text_id_name(self, name_or_id):
        """The name of the text id tensor of the text dataset by its name
        or id. If the dataset is not of text type, returns `None`.
        """
        i = self._maybe_name_to_id(name_or_id)
        if not _is_text_data(self._hparams.datasets[i]["data_type"]):
            return None
        name = dsutils._connect_name(
            self._data_spec.name_prefix[i],
            self._data_spec.decoder[i].text_id_tensor_name)
        return name
    def utterance_cnt_name(self, name_or_id):
        """The name of the utterance count tensor of the text dataset by
        its name or id. If the dataset is not variable utterance text data,
        returns `None`.
        """
        i = self._maybe_name_to_id(name_or_id)
        if not _is_text_data(self._hparams.datasets[i]["data_type"]) or \
                not self._hparams.datasets[i]["variable_utterance"]:
            return None
        name = dsutils._connect_name(
            self._data_spec.name_prefix[i],
            self._data_spec.decoder[i].utterance_cnt_tensor_name)
        return name
    def data_name(self, name_or_id):
        """The name of the data tensor of the scalar dataset by its name or
        id. If the dataset is not a scalar dataset, returns `None`.
        """
        i = self._maybe_name_to_id(name_or_id)
        if not _is_scalar_data(self._hparams.datasets[i]["data_type"]):
            return None
        name = dsutils._connect_name(
            self._data_spec.name_prefix[i],
            self._data_spec.decoder[i].data_tensor_name)
        return name
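
# A hedged usage sketch (the names 'x' and 'z' are hypothetical): querying
# the per-field tensor names of a constructed `MultiAlignedData` whose
# datasets are a text dataset named 'x' and an int dataset named 'z'.
def _demo_field_names(data):
    assert data.text_name('x') == 'x_text'
    assert data.text_id_name('x') == 'x_text_ids'
    assert data.length_name('x') == 'x_length'
    assert data.vocab('z') is None      # scalar datasets have no vocab
    assert data.text_name('z') is None  # and no text tensor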