Source code for texar.tf.data.data.paired_text_data

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Paired text data that consists of source text and target text.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import copy

import tensorflow as tf

from texar.tf.utils import utils
from texar.tf.utils.dtypes import is_callable
from texar.tf.data.data.mono_text_data import _default_mono_text_dataset_hparams
from texar.tf.data.data.text_data_base import TextDataBase
from texar.tf.data.data.mono_text_data import MonoTextData
from texar.tf.data.data_utils import count_file_lines
from texar.tf.data.data import dataset_utils as dsutils
from texar.tf.data.vocabulary import Vocab, SpecialTokens
from texar.tf.data.embedding import Embedding

# pylint: disable=invalid-name, arguments-differ, not-context-manager
# pylint: disable=protected-access, too-many-arguments

__all__ = [
    "_default_paired_text_dataset_hparams",
    "PairedTextData"
]


def _default_paired_text_dataset_hparams():
    """Returns hyperparameters of a paired text dataset with default values.

    See :meth:`texar.tf.data.PairedTextData.default_hparams` for details.
    """
    source_hparams = _default_mono_text_dataset_hparams()
    source_hparams["bos_token"] = None
    source_hparams["data_name"] = "source"
    target_hparams = _default_mono_text_dataset_hparams()
    target_hparams.update(
        {
            "vocab_share": False,
            "embedding_init_share": False,
            "processing_share": False,
            "data_name": "target"
        }
    )
    return {
        "source_dataset": source_hparams,
        "target_dataset": target_hparams
    }

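# A minimal sketch (not part of the library) of how these defaults are
# customized; the paths "data/train.src", "data/train.tgt", and "vocab.txt"
# are hypothetical placeholders. In practice the same fields are usually
# supplied directly in the `hparams` passed to :class:`PairedTextData`.
#
#     hparams = _default_paired_text_dataset_hparams()
#     hparams["source_dataset"].update(
#         {"files": "data/train.src", "vocab_file": "vocab.txt"})
#     hparams["target_dataset"].update(
#         {"files": "data/train.tgt", "vocab_share": True})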

# pylint: disable=too-many-instance-attributes, too-many-public-methods
class PairedTextData(TextDataBase):
    """Text data processor that reads parallel source and target text.
    This can be used in, e.g., seq2seq models.

    Args:
        hparams (dict): Hyperparameters. See :meth:`default_hparams` for the
            defaults.

    By default, the processor reads raw data files, performs tokenization,
    batching and other pre-processing steps, and results in a TF Dataset
    whose element is a python `dict` including six fields:

        - "source_text":
            A string Tensor of shape `[batch_size, max_time]` containing
            the **raw** text tokens of source sequences. `max_time` is the
            length of the longest sequence in the batch.
            Short sequences in the batch are padded with **empty string**.
            By default only EOS token is appended to each sequence.
            Out-of-vocabulary tokens are **NOT** replaced with UNK.
        - "source_text_ids":
            An `int64` Tensor of shape `[batch_size, max_time]` containing
            the token indexes of source sequences.
        - "source_length":
            An `int` Tensor of shape `[batch_size]` containing the length
            of each source sequence in the batch (including BOS and/or EOS
            if added).
        - "target_text":
            A string Tensor as "source_text" but for target sequences. By
            default both BOS and EOS are added.
        - "target_text_ids":
            An `int64` Tensor as "source_text_ids" but for target sequences.
        - "target_length":
            An `int` Tensor of shape `[batch_size]` as "source_length" but
            for target sequences.

    If :attr:`'variable_utterance'` is set to `True` in
    :attr:`'source_dataset'` and/or :attr:`'target_dataset'` of
    :attr:`hparams`, the corresponding fields "source_*" and/or "target_*"
    are respectively changed to contain variable utterance text data, as in
    :class:`~texar.tf.data.MonoTextData`.

    The above field names can be accessed through :attr:`source_text_name`,
    :attr:`source_text_id_name`, :attr:`source_length_name`,
    :attr:`source_utterance_cnt_name`, and those prefixed with `target_`,
    respectively.

    Example:

        .. code-block:: python

            hparams={
                'source_dataset': {'files': 's', 'vocab_file': 'vs'},
                'target_dataset': {'files': ['t1', 't2'], 'vocab_file': 'vt'},
                'batch_size': 1
            }
            data = PairedTextData(hparams)
            iterator = DataIterator(data)
            batch = iterator.get_next()

            iterator.switch_to_dataset(sess) # initializes the dataset
            batch_ = sess.run(batch)
            # batch_ == {
            #     'source_text': [['source', 'sequence', '<EOS>']],
            #     'source_text_ids': [[5, 10, 2]],
            #     'source_length': [3],
            #     'target_text': [['<BOS>', 'target', 'sequence', '1', '<EOS>']],
            #     'target_text_ids': [[1, 6, 10, 20, 2]],
            #     'target_length': [5]
            # }
    """

    def __init__(self, hparams):
        TextDataBase.__init__(self, hparams)
        with tf.name_scope(self.name, self.default_hparams()["name"]):
            self._make_data()
    @staticmethod
    def default_hparams():
        """Returns a dictionary of default hyperparameters.

        .. code-block:: python

            {
                # (1) Hyperparams specific to text dataset
                "source_dataset": {
                    "files": [],
                    "compression_type": None,
                    "vocab_file": "",
                    "embedding_init": {},
                    "delimiter": " ",
                    "max_seq_length": None,
                    "length_filter_mode": "truncate",
                    "pad_to_max_seq_length": False,
                    "bos_token": None,
                    "eos_token": "<EOS>",
                    "other_transformations": [],
                    "variable_utterance": False,
                    "utterance_delimiter": "|||",
                    "max_utterance_cnt": 5,
                    "data_name": "source",
                },
                "target_dataset": {
                    # ...
                    # Same fields are allowed as in "source_dataset" with the
                    # same default values, except the following new
                    # fields/values:
                    "bos_token": "<BOS>",
                    "vocab_share": False,
                    "embedding_init_share": False,
                    "processing_share": False,
                    "data_name": "target"
                },
                # (2) General hyperparams
                "num_epochs": 1,
                "batch_size": 64,
                "allow_smaller_final_batch": True,
                "shuffle": True,
                "shuffle_buffer_size": None,
                "shard_and_shuffle": False,
                "num_parallel_calls": 1,
                "prefetch_buffer_size": 0,
                "max_dataset_size": -1,
                "seed": None,
                "name": "paired_text_data",
                # (3) Bucketing
                "bucket_boundaries": [],
                "bucket_batch_sizes": None,
                "bucket_length_fn": None,
            }

        Here:

        1. Hyperparameters in the :attr:`"source_dataset"` and
           :attr:`"target_dataset"` fields have the same definition as those
           in :meth:`texar.tf.data.MonoTextData.default_hparams`, for source
           and target text, respectively.

           For the new hyperparameters in "target_dataset":

            "vocab_share": bool
                Whether to share the vocabulary of source.
                If `True`, the vocab file of target is ignored.

            "embedding_init_share": bool
                Whether to share the embedding initial value of source. If
                `True`, :attr:`"embedding_init"` of target is ignored.
                :attr:`"vocab_share"` must be `True` to share the embedding
                initial value.

            "processing_share": bool
                Whether to share the processing configurations of source,
                including "delimiter", "bos_token", "eos_token", and
                "other_transformations".

        2. For the **general** hyperparameters, see
           :meth:`texar.tf.data.DataBase.default_hparams` for details.

        3. For **bucketing** hyperparameters, see
           :meth:`texar.tf.data.MonoTextData.default_hparams` for details,
           except that the default "bucket_length_fn" is the maximum sequence
           length of source and target sequences.
        """
        hparams = TextDataBase.default_hparams()
        hparams["name"] = "paired_text_data"
        hparams.update(_default_paired_text_dataset_hparams())
        return hparams
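    # A minimal illustrative sketch (not part of the library) of the sharing
    # options documented above: the target side reuses the source vocabulary,
    # embedding initialization, and processing configuration. File names such
    # as "train.src", "train.tgt", "vocab.txt", and "emb.txt" are hypothetical
    # placeholders.
    #
    #     hparams = {
    #         'source_dataset': {
    #             'files': 'train.src',
    #             'vocab_file': 'vocab.txt',
    #             'embedding_init': {'file': 'emb.txt', 'dim': 300},
    #         },
    #         'target_dataset': {
    #             'files': 'train.tgt',
    #             'vocab_share': True,            # reuse source vocab
    #             'embedding_init_share': True,   # requires vocab_share=True
    #             'processing_share': True,       # reuse delimiter/BOS/EOS
    #         },
    #         'batch_size': 32,
    #     }
    #     data = PairedTextData(hparams)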
    @staticmethod
    def make_vocab(src_hparams, tgt_hparams):
        """Reads vocab files and returns source vocab and target vocab.

        Args:
            src_hparams (dict or HParams): Hyperparameters of source dataset.
            tgt_hparams (dict or HParams): Hyperparameters of target dataset.

        Returns:
            A pair of :class:`texar.tf.data.Vocab` instances. The two
            instances may be the same objects if source and target vocabs
            are shared and have the same other configs.
        """
        src_vocab = MonoTextData.make_vocab(src_hparams)

        if tgt_hparams["processing_share"]:
            tgt_bos_token = src_hparams["bos_token"]
            tgt_eos_token = src_hparams["eos_token"]
        else:
            tgt_bos_token = tgt_hparams["bos_token"]
            tgt_eos_token = tgt_hparams["eos_token"]
        tgt_bos_token = utils.default_str(tgt_bos_token, SpecialTokens.BOS)
        tgt_eos_token = utils.default_str(tgt_eos_token, SpecialTokens.EOS)
        if tgt_hparams["vocab_share"]:
            if tgt_bos_token == src_vocab.bos_token and \
                    tgt_eos_token == src_vocab.eos_token:
                tgt_vocab = src_vocab
            else:
                tgt_vocab = Vocab(src_hparams["vocab_file"],
                                  bos_token=tgt_bos_token,
                                  eos_token=tgt_eos_token)
        else:
            tgt_vocab = Vocab(tgt_hparams["vocab_file"],
                              bos_token=tgt_bos_token,
                              eos_token=tgt_eos_token)

        return src_vocab, tgt_vocab

    @staticmethod
    def make_embedding(src_emb_hparams, src_token_to_id_map,
                       tgt_emb_hparams=None, tgt_token_to_id_map=None,
                       emb_init_share=False):
        """Optionally loads source and target embeddings from files
        (if provided), and returns respective
        :class:`texar.tf.data.Embedding` instances.
        """
        src_embedding = MonoTextData.make_embedding(src_emb_hparams,
                                                    src_token_to_id_map)

        if emb_init_share:
            tgt_embedding = src_embedding
        else:
            tgt_emb_file = tgt_emb_hparams["file"]
            tgt_embedding = None
            if tgt_emb_file is not None and tgt_emb_file != "":
                tgt_embedding = Embedding(tgt_token_to_id_map, tgt_emb_hparams)

        return src_embedding, tgt_embedding

    def _make_dataset(self):
        src_dataset = tf.data.TextLineDataset(
            self._hparams.source_dataset.files,
            compression_type=self._hparams.source_dataset.compression_type)
        tgt_dataset = tf.data.TextLineDataset(
            self._hparams.target_dataset.files,
            compression_type=self._hparams.target_dataset.compression_type)
        return tf.data.Dataset.zip((src_dataset, tgt_dataset))

    @staticmethod
    def _get_name_prefix(src_hparams, tgt_hparams):
        name_prefix = [
            src_hparams["data_name"], tgt_hparams["data_name"]]
        if name_prefix[0] == name_prefix[1]:
            raise ValueError("'data_name' of source and target "
                             "datasets cannot be the same.")
        return name_prefix

    @staticmethod
    def _make_processor(src_hparams, tgt_hparams, data_spec, name_prefix):
        # Create source data decoder
        data_spec_i = data_spec.get_ith_data_spec(0)
        src_decoder, src_trans, data_spec_i = MonoTextData._make_processor(
            src_hparams, data_spec_i, chained=False)
        data_spec.set_ith_data_spec(0, data_spec_i, 2)

        # Create target data decoder
        tgt_proc_hparams = tgt_hparams
        if tgt_hparams["processing_share"]:
            tgt_proc_hparams = copy.copy(src_hparams)
            try:
                tgt_proc_hparams["variable_utterance"] = \
                    tgt_hparams["variable_utterance"]
            except TypeError:
                tgt_proc_hparams.variable_utterance = \
                    tgt_hparams["variable_utterance"]
        data_spec_i = data_spec.get_ith_data_spec(1)
        tgt_decoder, tgt_trans, data_spec_i = MonoTextData._make_processor(
            tgt_proc_hparams, data_spec_i, chained=False)
        data_spec.set_ith_data_spec(1, data_spec_i, 2)

        tran_fn = dsutils.make_combined_transformation(
            [[src_decoder] + src_trans, [tgt_decoder] + tgt_trans],
            name_prefix=name_prefix)

        data_spec.add_spec(name_prefix=name_prefix)

        return tran_fn, data_spec

    @staticmethod
    def _make_length_filter(src_hparams, tgt_hparams,
                            src_length_name, tgt_length_name,
                            src_decoder, tgt_decoder):
        src_filter_fn = MonoTextData._make_length_filter(
            src_hparams, src_length_name, src_decoder)
        tgt_filter_fn = MonoTextData._make_length_filter(
            tgt_hparams, tgt_length_name, tgt_decoder)
        combined_filter_fn = dsutils._make_combined_filter_fn(
            [src_filter_fn, tgt_filter_fn])
        return combined_filter_fn

    def _process_dataset(self, dataset, hparams, data_spec):
        name_prefix = PairedTextData._get_name_prefix(
            hparams["source_dataset"], hparams["target_dataset"])
        tran_fn, data_spec = self._make_processor(
            hparams["source_dataset"], hparams["target_dataset"],
            data_spec, name_prefix=name_prefix)
        num_parallel_calls = hparams["num_parallel_calls"]
        dataset = dataset.map(
            lambda *args: tran_fn(dsutils.maybe_tuple(args)),
            num_parallel_calls=num_parallel_calls)

        # Filters by length
        src_length_name = dsutils._connect_name(
            data_spec.name_prefix[0],
            data_spec.decoder[0].length_tensor_name)
        tgt_length_name = dsutils._connect_name(
            data_spec.name_prefix[1],
            data_spec.decoder[1].length_tensor_name)
        filter_fn = self._make_length_filter(
            hparams["source_dataset"], hparams["target_dataset"],
            src_length_name, tgt_length_name,
            data_spec.decoder[0], data_spec.decoder[1])
        if filter_fn:
            dataset = dataset.filter(filter_fn)

        # Truncates data count
        dataset = dataset.take(hparams["max_dataset_size"])

        return dataset, data_spec

    def _make_bucket_length_fn(self):
        length_fn = self._hparams.bucket_length_fn
        if not length_fn:
            length_fn = lambda x: tf.maximum(
                x[self.source_length_name], x[self.target_length_name])
        elif not is_callable(length_fn):
            # pylint: disable=redefined-variable-type
            length_fn = utils.get_function(length_fn, ["texar.tf.custom"])
        return length_fn

    def _make_padded_shapes(self, dataset, src_decoder, tgt_decoder):
        src_text_and_id_shapes = {}
        if self._hparams.source_dataset.pad_to_max_seq_length:
            src_text_and_id_shapes = \
                MonoTextData._make_padded_text_and_id_shapes(
                    dataset, self._hparams.source_dataset, src_decoder,
                    self.source_text_name, self.source_text_id_name)

        tgt_text_and_id_shapes = {}
        if self._hparams.target_dataset.pad_to_max_seq_length:
            tgt_text_and_id_shapes = \
                MonoTextData._make_padded_text_and_id_shapes(
                    dataset, self._hparams.target_dataset, tgt_decoder,
                    self.target_text_name, self.target_text_id_name)

        padded_shapes = dataset.output_shapes
        padded_shapes.update(src_text_and_id_shapes)
        padded_shapes.update(tgt_text_and_id_shapes)

        return padded_shapes

    def _make_data(self):
        self._src_vocab, self._tgt_vocab = self.make_vocab(
            self._hparams.source_dataset, self._hparams.target_dataset)

        tgt_hparams = self._hparams.target_dataset
        if not tgt_hparams.vocab_share and tgt_hparams.embedding_init_share:
            raise ValueError("embedding_init can be shared only when vocab "
                             "is shared. Got `vocab_share=False, "
                             "emb_init_share=True`.")
        self._src_embedding, self._tgt_embedding = self.make_embedding(
            self._hparams.source_dataset.embedding_init,
            self._src_vocab.token_to_id_map_py,
            self._hparams.target_dataset.embedding_init,
            self._tgt_vocab.token_to_id_map_py,
            self._hparams.target_dataset.embedding_init_share)

        # Create dataset
        dataset = self._make_dataset()
        dataset, dataset_size = self._shuffle_dataset(
            dataset, self._hparams, self._hparams.source_dataset.files)
        self._dataset_size = dataset_size

        # Processing.
        data_spec = dsutils._DataSpec(
            dataset=dataset, dataset_size=self._dataset_size,
            vocab=[self._src_vocab, self._tgt_vocab],
            embedding=[self._src_embedding, self._tgt_embedding])
        dataset, data_spec = self._process_dataset(
            dataset, self._hparams, data_spec)
        self._data_spec = data_spec
        self._decoder = data_spec.decoder
        self._src_decoder = data_spec.decoder[0]
        self._tgt_decoder = data_spec.decoder[1]

        # Batching
        length_fn = self._make_bucket_length_fn()
        padded_shapes = self._make_padded_shapes(
            dataset, self._src_decoder, self._tgt_decoder)
        dataset = self._make_batch(
            dataset, self._hparams, length_fn, padded_shapes)

        # Prefetching
        if self._hparams.prefetch_buffer_size > 0:
            dataset = dataset.prefetch(self._hparams.prefetch_buffer_size)

        self._dataset = dataset
    def list_items(self):
        """Returns the list of item names that the data can produce.

        Returns:
            A list of strings.
        """
        return list(self._dataset.output_types.keys())
    @property
    def dataset(self):
        """The dataset.
        """
        return self._dataset
    def dataset_size(self):
        """Returns the number of data instances in the dataset.

        Note that this is the total data count in the raw files, before any
        filtering and truncation.
        """
        if not self._dataset_size:
            # pylint: disable=attribute-defined-outside-init
            self._dataset_size = count_file_lines(
                self._hparams.source_dataset.files)
        return self._dataset_size
    @property
    def vocab(self):
        """A pair of :class:`~texar.tf.data.Vocab` instances that are the
        source and target vocabs, respectively.
        """
        return self._src_vocab, self._tgt_vocab

    @property
    def source_vocab(self):
        """The source vocab, an instance of :class:`~texar.tf.data.Vocab`.
        """
        return self._src_vocab

    @property
    def target_vocab(self):
        """The target vocab, an instance of :class:`~texar.tf.data.Vocab`.
        """
        return self._tgt_vocab

    @property
    def source_embedding_init_value(self):
        """The `Tensor` containing the embedding value of source data
        loaded from file. `None` if embedding is not specified.
        """
        if self._src_embedding is None:
            return None
        return self._src_embedding.word_vecs

    @property
    def target_embedding_init_value(self):
        """The `Tensor` containing the embedding value of target data
        loaded from file. `None` if embedding is not specified.
        """
        if self._tgt_embedding is None:
            return None
        return self._tgt_embedding.word_vecs
    def embedding_init_value(self):
        """A pair of `Tensor` containing the embedding values of source and
        target data loaded from file.
        """
        src_emb = self.source_embedding_init_value
        tgt_emb = self.target_embedding_init_value
        return src_emb, tgt_emb
    @property
    def source_text_name(self):
        """The name of the source text tensor, "source_text" by default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[0],
            self._src_decoder.text_tensor_name)
        return name

    @property
    def source_length_name(self):
        """The name of the source length tensor, "source_length" by default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[0],
            self._src_decoder.length_tensor_name)
        return name

    @property
    def source_text_id_name(self):
        """The name of the source text index tensor, "source_text_ids" by
        default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[0],
            self._src_decoder.text_id_tensor_name)
        return name

    @property
    def source_utterance_cnt_name(self):
        """The name of the source text utterance count tensor,
        "source_utterance_cnt" by default.
        """
        if not self._hparams.source_dataset.variable_utterance:
            raise ValueError(
                "`utterance_cnt_name` of source data is undefined.")
        name = dsutils._connect_name(
            self._data_spec.name_prefix[0],
            self._src_decoder.utterance_cnt_tensor_name)
        return name

    @property
    def target_text_name(self):
        """The name of the target text tensor, "target_text" by default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[1],
            self._tgt_decoder.text_tensor_name)
        return name

    @property
    def target_length_name(self):
        """The name of the target length tensor, "target_length" by default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[1],
            self._tgt_decoder.length_tensor_name)
        return name

    @property
    def target_text_id_name(self):
        """The name of the target text index tensor, "target_text_ids" by
        default.
        """
        name = dsutils._connect_name(
            self._data_spec.name_prefix[1],
            self._tgt_decoder.text_id_tensor_name)
        return name

    @property
    def target_utterance_cnt_name(self):
        """The name of the target text utterance count tensor,
        "target_utterance_cnt" by default.
        """
        if not self._hparams.target_dataset.variable_utterance:
            raise ValueError(
                "`utterance_cnt_name` of target data is undefined.")
        name = dsutils._connect_name(
            self._data_spec.name_prefix[1],
            self._tgt_decoder.utterance_cnt_tensor_name)
        return name

    @property
    def text_name(self):
        """The name of text tensor, "text" by default.
        """
        return self._src_decoder.text_tensor_name

    @property
    def length_name(self):
        """The name of length tensor, "length" by default.
        """
        return self._src_decoder.length_tensor_name

    @property
    def text_id_name(self):
        """The name of text index tensor, "text_ids" by default.
        """
        return self._src_decoder.text_id_tensor_name

    @property
    def utterance_cnt_name(self):
        """The name of the text utterance count tensor, "utterance_cnt" by
        default.
        """
        if self._hparams.source_dataset.variable_utterance:
            return self._src_decoder.utterance_cnt_tensor_name
        if self._hparams.target_dataset.variable_utterance:
            return self._tgt_decoder.utterance_cnt_tensor_name
        raise ValueError("`utterance_cnt_name` is not defined.")
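# A minimal usage sketch (assumptions: the hypothetical `hparams` shown in the
# comments above, and `DataIterator` from texar.tf.data), showing how the
# name properties map onto the items the dataset produces. Illustrative only;
# this code is not executed by the library.
#
#     data = PairedTextData(hparams)
#     print(data.list_items())
#     # e.g. ['source_text', 'source_text_ids', 'source_length',
#     #       'target_text', 'target_text_ids', 'target_length']
#
#     iterator = DataIterator(data)
#     batch = iterator.get_next()
#     src_ids = batch[data.source_text_id_name]   # "source_text_ids"
#     tgt_lens = batch[data.target_length_name]   # "target_length"
#
#     src_vocab, tgt_vocab = data.vocab               # Vocab instances
#     src_emb, tgt_emb = data.embedding_init_value()  # None if not configured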