Source code for texar.tf.data.data.paired_text_data

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Paired text data that consists of source text and target text.
"""

import copy

import tensorflow as tf

from texar.tf.utils import utils
from texar.tf.utils.dtypes import is_callable
from texar.tf.data.data.mono_text_data import _default_mono_text_dataset_hparams
from texar.tf.data.data.text_data_base import TextDataBase
from texar.tf.data.data.mono_text_data import MonoTextData
from texar.tf.data.data_utils import count_file_lines
from texar.tf.data.data import dataset_utils as dsutils
from texar.tf.data.vocabulary import Vocab, SpecialTokens
from texar.tf.data.embedding import Embedding

# pylint: disable=invalid-name, arguments-differ, not-context-manager
# pylint: disable=protected-access, too-many-arguments

__all__ = [
    "_default_paired_text_dataset_hparams",
    "PairedTextData"
]


def _default_paired_text_dataset_hparams():
    """Returns hyperparameters of a paired text dataset with default values.

    See :meth:`texar.tf.data.PairedTextData.default_hparams` for details.
    """
    source_hparams = _default_mono_text_dataset_hparams()
    source_hparams["bos_token"] = None
    source_hparams["data_name"] = "source"
    target_hparams = _default_mono_text_dataset_hparams()
    target_hparams.update(
        {
            "vocab_share": False,
            "embedding_init_share": False,
            "processing_share": False,
            "data_name": "target"
        }
    )
    return {
        "source_dataset": source_hparams,
        "target_dataset": target_hparams
    }
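
# An illustrative sketch (not part of the library) of how the defaults built
# above differ between the two sides: the source side drops the BOS token by
# default, while the target side adds the sharing flags. For example:
#
#     defaults = _default_paired_text_dataset_hparams()
#     defaults["source_dataset"]["bos_token"]      # None
#     defaults["source_dataset"]["data_name"]      # "source"
#     defaults["target_dataset"]["vocab_share"]    # False
#     defaults["target_dataset"]["data_name"]      # "target"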


# pylint: disable=too-many-instance-attributes, too-many-public-methods
class PairedTextData(TextDataBase):
    """Text data processor that reads parallel source and target text.
    This can be used in, e.g., seq2seq models.

    Args:
        hparams (dict): Hyperparameters. See :meth:`default_hparams` for the
            defaults.

    By default, the processor reads raw data files, performs tokenization,
    batching and other pre-processing steps, and results in a TF Dataset
    whose element is a python `dict` including six fields:

        - "source_text":
            A string Tensor of shape `[batch_size, max_time]` containing
            the **raw** text tokens of source sequences. `max_time` is the
            length of the longest sequence in the batch. Short sequences in
            the batch are padded with **empty string**. By default only the
            EOS token is appended to each sequence. Out-of-vocabulary tokens
            are **NOT** replaced with UNK.
        - "source_text_ids":
            An `int64` Tensor of shape `[batch_size, max_time]` containing
            the token indexes of source sequences.
        - "source_length":
            An `int` Tensor of shape `[batch_size]` containing the length of
            each source sequence in the batch (including BOS and/or EOS if
            added).
        - "target_text":
            A string Tensor as "source_text" but for target sequences. By
            default both BOS and EOS are added.
        - "target_text_ids":
            An `int64` Tensor as "source_text_ids" but for target sequences.
        - "target_length":
            An `int` Tensor of shape `[batch_size]` as "source_length" but
            for target sequences.

    If :attr:`'variable_utterance'` is set to `True` in
    :attr:`'source_dataset'` and/or :attr:`'target_dataset'` of
    :attr:`hparams`, the corresponding fields "source_*" and/or "target_*"
    are respectively changed to contain variable utterance text data, as in
    :class:`~texar.tf.data.MonoTextData`.

    The above field names can be accessed through :attr:`source_text_name`,
    :attr:`source_text_id_name`, :attr:`source_length_name`,
    :attr:`source_utterance_cnt_name`, and those prefixed with `target_`,
    respectively.

    Example:

        .. code-block:: python

            hparams={
                'source_dataset': {'files': 's', 'vocab_file': 'vs'},
                'target_dataset': {'files': ['t1', 't2'], 'vocab_file': 'vt'},
                'batch_size': 1
            }
            data = PairedTextData(hparams)
            iterator = DataIterator(data)
            batch = iterator.get_next()

            iterator.switch_to_dataset(sess) # initializes the dataset
            batch_ = sess.run(batch)
            # batch_ == {
            #     'source_text': [['source', 'sequence', '<EOS>']],
            #     'source_text_ids': [[5, 10, 2]],
            #     'source_length': [3],
            #     'target_text': [['<BOS>', 'target', 'sequence', '1', '<EOS>']],
            #     'target_text_ids': [[1, 6, 10, 20, 2]],
            #     'target_length': [5]
            # }
    """

    def __init__(self, hparams):
        TextDataBase.__init__(self, hparams)
        with tf.name_scope(self.name, self.default_hparams()["name"]):
            self._make_data()
    @staticmethod
    def default_hparams():
        """Returns a dictionary of default hyperparameters.

        .. code-block:: python

            {
                # (1) Hyperparams specific to text dataset
                "source_dataset": {
                    "files": [],
                    "compression_type": None,
                    "vocab_file": "",
                    "embedding_init": {},
                    "delimiter": " ",
                    "max_seq_length": None,
                    "length_filter_mode": "truncate",
                    "pad_to_max_seq_length": False,
                    "bos_token": None,
                    "eos_token": "<EOS>",
                    "other_transformations": [],
                    "variable_utterance": False,
                    "utterance_delimiter": "|||",
                    "max_utterance_cnt": 5,
                    "data_name": "source",
                },
                "target_dataset": {
                    # ...
                    # Same fields are allowed as in "source_dataset", with the
                    # same default values, except for the following new
                    # fields/values:
                    "bos_token": "<BOS>",
                    "vocab_share": False,
                    "embedding_init_share": False,
                    "processing_share": False,
                    "data_name": "target"
                },
                # (2) General hyperparams
                "num_epochs": 1,
                "batch_size": 64,
                "allow_smaller_final_batch": True,
                "shuffle": True,
                "shuffle_buffer_size": None,
                "shard_and_shuffle": False,
                "num_parallel_calls": 1,
                "prefetch_buffer_size": 0,
                "max_dataset_size": -1,
                "seed": None,
                "name": "paired_text_data",
                # (3) Bucketing
                "bucket_boundaries": [],
                "bucket_batch_sizes": None,
                "bucket_length_fn": None,
            }

        Here:

        1. Hyperparameters in the :attr:`"source_dataset"` and
           :attr:`"target_dataset"` fields have the same definition as those
           in :meth:`texar.tf.data.MonoTextData.default_hparams`, for source
           and target text, respectively.

           For the new hyperparameters in "target_dataset":

           "vocab_share": bool
               Whether to share the vocabulary of source. If `True`, the
               vocab file of target is ignored.

           "embedding_init_share": bool
               Whether to share the embedding initial value of source. If
               `True`, :attr:`"embedding_init"` of target is ignored.
               :attr:`"vocab_share"` must be `True` to share the embedding
               initial value.

           "processing_share": bool
               Whether to share the processing configurations of source,
               including "delimiter", "bos_token", "eos_token", and
               "other_transformations".

        2. For the **general** hyperparameters, see
           :meth:`texar.tf.data.DataBase.default_hparams` for details.

        3. For **bucketing** hyperparameters, see
           :meth:`texar.tf.data.MonoTextData.default_hparams` for details,
           except that the default "bucket_length_fn" is the maximum sequence
           length of source and target sequences.
        """
        hparams = TextDataBase.default_hparams()
        hparams["name"] = "paired_text_data"
        hparams.update(_default_paired_text_dataset_hparams())
        return hparams
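
    # A minimal usage sketch (comments only, not part of the library) of
    # overriding the defaults above so that the target side reuses the source
    # vocabulary and embedding. The file names 'data.src', 'data.tgt', and
    # 'vocab.txt' are hypothetical placeholders.
    #
    #     hparams = {
    #         'source_dataset': {'files': 'data.src', 'vocab_file': 'vocab.txt'},
    #         'target_dataset': {'files': 'data.tgt',
    #                            'vocab_share': True,
    #                            'embedding_init_share': True},
    #         'batch_size': 32,
    #     }
    #     data = PairedTextData(hparams)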
    @staticmethod
    def make_vocab(src_hparams, tgt_hparams):
        """Reads vocab files and returns source vocab and target vocab.

        Args:
            src_hparams (dict or HParams): Hyperparameters of source dataset.
            tgt_hparams (dict or HParams): Hyperparameters of target dataset.

        Returns:
            A pair of :class:`texar.tf.data.Vocab` instances. The two
            instances may be the same objects if source and target vocabs
            are shared and have the same other configs.
        """
        src_vocab = MonoTextData.make_vocab(src_hparams)

        if tgt_hparams["processing_share"]:
            tgt_bos_token = src_hparams["bos_token"]
            tgt_eos_token = src_hparams["eos_token"]
        else:
            tgt_bos_token = tgt_hparams["bos_token"]
            tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(tgt_bos_token, SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(tgt_eos_token, SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == src_vocab.bos_token and \
                    tgt_eos_token == src_vocab.eos_token:
                tgt_vocab = src_vocab
            else:
                tgt_vocab = Vocab(src_hparams["vocab_file"],
                                  bos_token=tgt_bos_token,
                                  eos_token=tgt_eos_token)
        else:
            tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                              bos_token=tgt_bos_token,
                              eos_token=tgt_eos_token)

        return src_vocab, tgt_vocab

    @staticmethod
    def make_embedding(src_emb_hparams, src_token_to_id_map,
                       tgt_emb_hparams=None, tgt_token_to_id_map=None,
                       emb_init_share=False):
        """Optionally loads source and target embeddings from files (if
        provided), and returns respective :class:`texar.tf.data.Embedding`
        instances.
        """
        src_embedding = MonoTextData.make_embedding(src_emb_hparams,
                                                    src_token_to_id_map)

        if emb_init_share:
            tgt_embedding = src_embedding
        else:
            tgt_emb_file = tgt_emb_hparams["file"]
            tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                tgt_embedding = Embedding(tgt_token_to_id_map, tgt_emb_hparams)

        return src_embedding, tgt_embedding

    def _make_dataset(self):
        src_dataset = tf.data.TextLineDataset(
            self._hparams.source_dataset.files,
            compression_type=self._hparams.source_dataset.compression_type)
        tgt_dataset = tf.data.TextLineDataset(
            self._hparams.target_dataset.files,
            compression_type=self._hparams.target_dataset.compression_type)
        return tf.data.Dataset.zip((src_dataset, tgt_dataset))

    @staticmethod
    def _get_name_prefix(src_hparams, tgt_hparams):
        name_prefix = [src_hparams["data_name"], tgt_hparams["data_name"]]
        if name_prefix[0] == name_prefix[1]:
            raise ValueError("'data_name' of source and target "
                             "datasets cannot be the same.")
        return name_prefix

    @staticmethod
    def _make_processor(src_hparams, tgt_hparams, data_spec, name_prefix):
        # Create source data decoder
        data_spec_i = data_spec.get_ith_data_spec(0)
        src_decoder, src_trans, data_spec_i = MonoTextData._make_processor(
            src_hparams, data_spec_i, chained=False)
        data_spec.set_ith_data_spec(0, data_spec_i, 2)

        # Create target data decoder
        tgt_proc_hparams = tgt_hparams
        if tgt_hparams["processing_share"]:
            tgt_proc_hparams = copy.copy(src_hparams)
            try:
                tgt_proc_hparams["variable_utterance"] = \
                    tgt_hparams["variable_utterance"]
            except TypeError:
                tgt_proc_hparams.variable_utterance = \
                    tgt_hparams["variable_utterance"]

        data_spec_i = data_spec.get_ith_data_spec(1)
        tgt_decoder, tgt_trans, data_spec_i = MonoTextData._make_processor(
            tgt_proc_hparams, data_spec_i, chained=False)
        data_spec.set_ith_data_spec(1, data_spec_i, 2)

        tran_fn = dsutils.make_combined_transformation(
            [[src_decoder] + src_trans, [tgt_decoder] + tgt_trans],
            name_prefix=name_prefix)

        data_spec.add_spec(name_prefix=name_prefix)

        return tran_fn, data_spec

    @staticmethod
    def _make_length_filter(src_hparams, tgt_hparams,
                            src_length_name, tgt_length_name,
                            src_decoder, tgt_decoder):
        src_filter_fn = MonoTextData._make_length_filter(
            src_hparams, src_length_name, src_decoder)
        tgt_filter_fn = MonoTextData._make_length_filter(
            tgt_hparams, tgt_length_name, tgt_decoder)
        combined_filter_fn = dsutils._make_combined_filter_fn(
            [src_filter_fn, tgt_filter_fn])
        return combined_filter_fn

    def _process_dataset(self, dataset, hparams, data_spec):
        name_prefix = PairedTextData._get_name_prefix(
            hparams["source_dataset"], hparams["target_dataset"])

        tran_fn, data_spec = self._make_processor(
            hparams["source_dataset"], hparams["target_dataset"],
            data_spec, name_prefix=name_prefix)

        num_parallel_calls = hparams["num_parallel_calls"]
        dataset = dataset.map(
            lambda *args: tran_fn(dsutils.maybe_tuple(args)),
            num_parallel_calls=num_parallel_calls)

        # Filters by length
        src_length_name = dsutils._connect_name(
            data_spec.name_prefix[0],
            data_spec.decoder[0].length_tensor_name)
        tgt_length_name = dsutils._connect_name(
            data_spec.name_prefix[1],
            data_spec.decoder[1].length_tensor_name)
        filter_fn = self._make_length_filter(
            hparams["source_dataset"], hparams["target_dataset"],
            src_length_name, tgt_length_name,
            data_spec.decoder[0], data_spec.decoder[1])
        if filter_fn:
            dataset = dataset.filter(filter_fn)

        # Truncates data count
        dataset = dataset.take(hparams["max_dataset_size"])

        return dataset, data_spec

    def _make_bucket_length_fn(self):
        length_fn = self._hparams.bucket_length_fn
        if not length_fn:
            length_fn = lambda x: tf.maximum(
                x[self.source_length_name], x[self.target_length_name])
        elif not is_callable(length_fn):
            # pylint: disable=redefined-variable-type
            length_fn = utils.get_function(length_fn, ["texar.tf.custom"])
        return length_fn

    def _make_padded_shapes(self, dataset, src_decoder, tgt_decoder):
        src_text_and_id_shapes = {}
        if self._hparams.source_dataset.pad_to_max_seq_length:
            src_text_and_id_shapes = \
                MonoTextData._make_padded_text_and_id_shapes(
                    dataset, self._hparams.source_dataset, src_decoder,
                    self.source_text_name, self.source_text_id_name)

        tgt_text_and_id_shapes = {}
        if self._hparams.target_dataset.pad_to_max_seq_length:
            tgt_text_and_id_shapes = \
                MonoTextData._make_padded_text_and_id_shapes(
                    dataset, self._hparams.target_dataset, tgt_decoder,
                    self.target_text_name, self.target_text_id_name)

        padded_shapes = dataset.output_shapes
        padded_shapes.update(src_text_and_id_shapes)
        padded_shapes.update(tgt_text_and_id_shapes)

        return padded_shapes

    def _make_data(self):
        self._src_vocab, self._tgt_vocab = self.make_vocab(
            self._hparams.source_dataset, self._hparams.target_dataset)

        tgt_hparams = self._hparams.target_dataset
        if not tgt_hparams.vocab_share and tgt_hparams.embedding_init_share:
            raise ValueError("embedding_init can be shared only when vocab "
                             "is shared. Got `vocab_share=False, "
                             "emb_init_share=True`.")
        self._src_embedding, self._tgt_embedding = self.make_embedding(
            self._hparams.source_dataset.embedding_init,
            self._src_vocab.token_to_id_map_py,
            self._hparams.target_dataset.embedding_init,
            self._tgt_vocab.token_to_id_map_py,
            self._hparams.target_dataset.embedding_init_share)

        # Create dataset
        dataset = self._make_dataset()
        dataset, dataset_size = self._shuffle_dataset(
            dataset, self._hparams, self._hparams.source_dataset.files)
        self._dataset_size = dataset_size

        # Processing.
        data_spec = dsutils._DataSpec(
            dataset=dataset, dataset_size=self._dataset_size,
            vocab=[self._src_vocab, self._tgt_vocab],
            embedding=[self._src_embedding, self._tgt_embedding])
        dataset, data_spec = self._process_dataset(
            dataset, self._hparams, data_spec)
        self._data_spec = data_spec
        self._decoder = data_spec.decoder
        self._src_decoder = data_spec.decoder[0]
        self._tgt_decoder = data_spec.decoder[1]

        # Batching
        length_fn = self._make_bucket_length_fn()
        padded_shapes = self._make_padded_shapes(
            dataset, self._src_decoder, self._tgt_decoder)
        dataset = self._make_batch(
            dataset, self._hparams, length_fn, padded_shapes)

        # Prefetching
        if self._hparams.prefetch_buffer_size > 0:
            dataset = dataset.prefetch(self._hparams.prefetch_buffer_size)

        self._dataset = dataset
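
    # A rough orientation sketch (comments only) of the pipeline assembled by
    # `_make_data` above:
    #
    #     TextLineDataset(source files) zipped with TextLineDataset(target files)
    #         -> shuffle (`_shuffle_dataset`)
    #         -> per-side decoding/transformations (`_process_dataset`)
    #         -> length filtering and `take(max_dataset_size)`
    #         -> padded/bucketed batching (`_make_batch`)
    #         -> optional prefetching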
    def list_items(self):
        """Returns the list of item names that the data can produce.

        Returns:
            A list of strings.
        """
        return list(self._dataset.output_types.keys())
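
    # For example (a sketch based on the class docstring above), with the
    # default field names and no variable-utterance fields, `list_items`
    # returns something like:
    #
    #     ['source_text', 'source_text_ids', 'source_length',
    #      'target_text', 'target_text_ids', 'target_length']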
    @property
    def dataset(self):
        """The dataset.
        """
        return self._dataset
    def dataset_size(self):
        """Returns the number of data instances in the dataset.

        Note that this is the total data count in the raw files, before
        any filtering and truncation.
        """
        if not self._dataset_size:
            # pylint: disable=attribute-defined-outside-init
            self._dataset_size = count_file_lines(
                self._hparams.source_dataset.files)
        return self._dataset_size
    @property
    def vocab(self):
        """A pair of :class:`~texar.tf.data.Vocab` instances that are the
        source and target vocabs, respectively.
        """
        return self._src_vocab, self._tgt_vocab

    @property
    def source_vocab(self):
        """The source vocab, an instance of :class:`~texar.tf.data.Vocab`.
        """
        return self._src_vocab

    @property
    def target_vocab(self):
        """The target vocab, an instance of :class:`~texar.tf.data.Vocab`.
        """
        return self._tgt_vocab

    @property
    def source_embedding_init_value(self):
        """The `Tensor` containing the embedding value of source data
        loaded from file. `None` if embedding is not specified.
        """
        if self._src_embedding is None:
            return None
        return self._src_embedding.word_vecs

    @property
    def target_embedding_init_value(self):
        """The `Tensor` containing the embedding value of target data
        loaded from file. `None` if embedding is not specified.
        """
        if self._tgt_embedding is None:
            return None
        return self._tgt_embedding.word_vecs
    def embedding_init_value(self):
        """A pair of `Tensor` containing the embedding values of source and
        target data loaded from file.
        """
        src_emb = self.source_embedding_init_value
        tgt_emb = self.target_embedding_init_value
        return src_emb, tgt_emb
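
    # A hedged sketch (comments only) of how the loaded embedding values
    # might be used to initialize embedder modules. The use of
    # `tx.modules.WordEmbedder` here is an assumption for illustration, not
    # something prescribed by this class:
    #
    #     src_emb, tgt_emb = data.embedding_init_value()
    #     if src_emb is not None:
    #         src_embedder = tx.modules.WordEmbedder(init_value=src_emb)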
    @property
    def source_text_name(self):
        """The name of the source text tensor, "source_text" by default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[0],
            self._src_decoder.text_tensor_name)
        return name

    @property
    def source_length_name(self):
        """The name of the source length tensor, "source_length" by default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[0],
            self._src_decoder.length_tensor_name)
        return name

    @property
    def source_text_id_name(self):
        """The name of the source text index tensor, "source_text_ids" by
        default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[0],
            self._src_decoder.text_id_tensor_name)
        return name

    @property
    def source_utterance_cnt_name(self):
        """The name of the source text utterance count tensor,
        "source_utterance_cnt" by default.
        """
        if not self._hparams.source_dataset.variable_utterance:
            raise ValueError(
                "`utterance_cnt_name` of source data is undefined.")
        name = dsutils._connect_name(
            self._data_spec.name_prefix[0],
            self._src_decoder.utterance_cnt_tensor_name)
        return name

    @property
    def target_text_name(self):
        """The name of the target text tensor, "target_text" by default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[1],
            self._tgt_decoder.text_tensor_name)
        return name

    @property
    def target_length_name(self):
        """The name of the target length tensor, "target_length" by default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[1],
            self._tgt_decoder.length_tensor_name)
        return name

    @property
    def target_text_id_name(self):
        """The name of the target text index tensor, "target_text_ids" by
        default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[1],
            self._tgt_decoder.text_id_tensor_name)
        return name

    @property
    def target_utterance_cnt_name(self):
        """The name of the target text utterance count tensor,
        "target_utterance_cnt" by default.
        """
        if not self._hparams.target_dataset.variable_utterance:
            raise ValueError(
                "`utterance_cnt_name` of target data is undefined.")
        name = dsutils._connect_name(
            self._data_spec.name_prefix[1],
            self._tgt_decoder.utterance_cnt_tensor_name)
        return name

    @property
    def text_name(self):
        """The name of the text tensor, "text" by default.
        """
        return self._src_decoder.text_tensor_name

    @property
    def length_name(self):
        """The name of the length tensor, "length" by default.
        """
        return self._src_decoder.length_tensor_name

    @property
    def text_id_name(self):
        """The name of the text index tensor, "text_ids" by default.
        """
        return self._src_decoder.text_id_tensor_name

    @property
    def utterance_cnt_name(self):
        """The name of the text utterance count tensor, "utterance_cnt" by
        default.
        """
        if self._hparams.source_dataset.variable_utterance:
            return self._src_decoder.utterance_cnt_tensor_name
        if self._hparams.target_dataset.variable_utterance:
            return self._tgt_decoder.utterance_cnt_tensor_name
        raise ValueError("`utterance_cnt_name` is not defined.")