# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for seq2seq models.
"""
import tensorflow as tf
from texar.tf.models.model_base import ModelBase
from texar.tf.losses.mle_losses import sequence_sparse_softmax_cross_entropy
from texar.tf.data.data.paired_text_data import PairedTextData
from texar.tf.core.optimization import get_train_op
from texar.tf.hyperparams import HParams
from texar.tf.utils import utils
from texar.tf.utils.variables import collect_trainable_variables
# pylint: disable=too-many-instance-attributes, unused-argument,
# pylint: disable=too-many-arguments, no-self-use
# Public API of this module.
__all__ = ["Seq2seqBase"]
class Seq2seqBase(ModelBase):
    """Common base of all seq2seq model classes.

    The class owns the vocabularies, embedders, encoder, connector and
    decoder of a model and assembles them into a
    :tf_main:`EstimatorSpec <estimator/EstimatorSpec>` in :meth:`_build`.
    Subclasses must implement :meth:`_build_decoder`, :meth:`encode`,
    :meth:`decode` and :meth:`_get_predictions`; the remaining hooks that
    raise `NotImplementedError` are for subclasses to fill in and call from
    their own code.

    .. document private functions
    .. automethod:: _build
    """

    def __init__(self, data_hparams, hparams=None):
        ModelBase.__init__(self, hparams)

        # Data configuration, completed with the `PairedTextData` defaults.
        self._data_hparams = HParams(
            data_hparams, PairedTextData.default_hparams())

        # Vocabularies; populated by `_build_vocab`.
        self._src_vocab = None
        self._tgt_vocab = None

        # Model components; populated by the matching `_build_*` method.
        self._src_embedder = None
        self._tgt_embedder = None
        self._connector = None
        self._encoder = None
        self._decoder = None

    @staticmethod
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                "source_embedder": "WordEmbedder",
                "source_embedder_hparams": {},
                "target_embedder": "WordEmbedder",
                "target_embedder_hparams": {},
                "embedder_share": True,
                "embedder_hparams_share": True,
                "encoder": "UnidirectionalRNNEncoder",
                "encoder_hparams": {},
                "decoder": "BasicRNNDecoder",
                "decoder_hparams": {},
                "decoding_strategy_train": "train_greedy",
                "decoding_strategy_infer": "infer_greedy",
                "beam_search_width": 0,
                "connector": "MLPTransformConnector",
                "connector_hparams": {},
                "optimization": {},
                "name": "seq2seq",
            }

        Here:

        "source_embedder": str or class or instance
            Word embedder for source text. Can be a class, its name or
            module path, or a class instance.

        "source_embedder_hparams": dict
            Hyperparameters for constructing the source embedder. E.g.,
            see :meth:`~texar.tf.modules.WordEmbedder.default_hparams` for
            the hyperparameters of
            :class:`~texar.tf.modules.WordEmbedder`. Ignored if
            "source_embedder" is an instance.

        "target_embedder", "target_embedder_hparams":
            Same as "source_embedder" and "source_embedder_hparams" but
            for the target text embedder.

        "embedder_share": bool
            Whether to share the source and target embedder. If `True`,
            the source embedder is also used to embed target text.

        "embedder_hparams_share": bool
            Whether to share the embedder configurations. If `True`, the
            target embedder is created with "source_embedder_hparams",
            but the two embedders still have different sets of trainable
            variables. Only consulted when "embedder_share" is `False`.

        "encoder", "encoder_hparams":
            Same as "source_embedder" and "source_embedder_hparams" but
            for the encoder.

        "decoder", "decoder_hparams":
            Same as "source_embedder" and "source_embedder_hparams" but
            for the decoder.

        "decoding_strategy_train": str
            The decoding strategy in training mode. See
            :meth:`~texar.tf.modules.RNNDecoderBase._build` for details.

        "decoding_strategy_infer": str
            The decoding strategy in eval/inference mode.

        "beam_search_width": int
            Beam width. If > 1, beam search is used in eval/inference
            mode.

        "connector", "connector_hparams":
            The connector class and hyperparameters. A connector
            transforms an encoder final state to a decoder initial state.

        "optimization": dict
            Hyperparameters for optimizing the model. See
            :func:`~texar.tf.core.default_optimization_hparams` for
            details.

        "name": str
            Name of the model.
        """
        seq2seq_defaults = {
            "name": "seq2seq",
            "source_embedder": "WordEmbedder",
            "source_embedder_hparams": {},
            "target_embedder": "WordEmbedder",
            "target_embedder_hparams": {},
            "embedder_share": True,
            "embedder_hparams_share": True,
            "encoder": "UnidirectionalRNNEncoder",
            "encoder_hparams": {},
            "decoder": "BasicRNNDecoder",
            "decoder_hparams": {},
            "decoding_strategy_train": "train_greedy",
            "decoding_strategy_infer": "infer_greedy",
            "beam_search_width": 0,
            "connector": "MLPTransformConnector",
            "connector_hparams": {},
            "optimization": {},
        }
        # Start from the generic model defaults and overlay the seq2seq ones.
        merged = ModelBase.default_hparams()
        merged.update(seq2seq_defaults)
        return merged

    @staticmethod
    def _instantiate_module(module, kwargs):
        """Passes `module` and `kwargs` to
        :func:`utils.check_or_get_instance`, using "texar.tf.modules" and
        "texar.tf.custom" as the module paths, and returns the result.
        """
        return utils.check_or_get_instance(
            module, kwargs, ["texar.tf.modules", "texar.tf.custom"])

    def _build_vocab(self):
        """Creates the source and target vocabularies from the data
        hyperparameters.
        """
        data_hparams = self._data_hparams
        vocabs = PairedTextData.make_vocab(
            data_hparams.source_dataset, data_hparams.target_dataset)
        self._src_vocab, self._tgt_vocab = vocabs

    def _build_embedders(self):
        """Creates the source embedder and the target embedder.

        The target embedder is the very same object as the source embedder
        when "embedder_share" is set; otherwise it is a separate module
        sized by the target vocabulary.
        """
        hparams = self._hparams

        src_kwargs = {
            "vocab_size": self._src_vocab.size,
            "hparams": hparams.source_embedder_hparams.todict(),
        }
        self._src_embedder = self._instantiate_module(
            hparams.source_embedder, src_kwargs)

        if hparams.embedder_share:
            self._tgt_embedder = self._src_embedder
            return

        tgt_kwargs = {"vocab_size": self._tgt_vocab.size}
        # A separate target embedder may still reuse the source config.
        if hparams.embedder_hparams_share:
            embedder_config = hparams.source_embedder_hparams
        else:
            embedder_config = hparams.target_embedder_hparams
        tgt_kwargs["hparams"] = embedder_config.todict()
        self._tgt_embedder = self._instantiate_module(
            hparams.target_embedder, tgt_kwargs)

    def _build_encoder(self):
        """Creates the encoder from "encoder" and "encoder_hparams"."""
        hparams = self._hparams
        encoder_kwargs = {"hparams": hparams.encoder_hparams.todict()}
        self._encoder = self._instantiate_module(
            hparams.encoder, encoder_kwargs)

    def _build_decoder(self):
        """Creates the decoder. Must be implemented by subclasses."""
        raise NotImplementedError

    def _build_connector(self):
        """Creates the connector, with its output size set to the decoder
        state size. Requires the decoder to have been built already.
        """
        connector_kwargs = {
            "output_size": self._decoder.state_size,
            "hparams": self._hparams.connector_hparams.todict(),
        }
        self._connector = self._instantiate_module(
            self._hparams.connector, connector_kwargs)

    def get_loss(self, decoder_results, features, labels):
        """Computes the training loss.

        The loss is the sequence sparse-softmax cross entropy between the
        decoder logits and the target ids.
        """
        # The first position of every target sequence is dropped from the
        # labels (presumably the BOS token, which is never predicted --
        # confirm against the data pipeline).
        target_ids = labels['target_text_ids'][:, 1:]
        decoder_logits = decoder_results['outputs'].logits
        lengths = decoder_results['sequence_length']
        return sequence_sparse_softmax_cross_entropy(
            labels=target_ids,
            logits=decoder_logits,
            sequence_length=lengths)

    def _get_predictions(self, decoder_results, features, labels, loss=None):
        """Builds the predictions. Must be implemented by subclasses."""
        raise NotImplementedError

    def _get_train_op(self, loss):
        """Creates the train op that minimizes `loss` over the trainable
        variables of the embedders, encoder, connector and decoder.
        """
        components = [
            self._src_embedder,
            self._tgt_embedder,
            self._encoder,
            self._connector,
            self._decoder,
        ]
        trainable = collect_trainable_variables(components)
        return get_train_op(
            loss, variables=trainable, hparams=self._hparams.optimization)

    def _get_eval_metric_ops(self, decoder_results, features, labels):
        """Returns the evaluation metric ops; the base class defines none."""
        return None

    def embed_source(self, features, labels, mode):
        """Embeds the inputs. Must be implemented by subclasses."""
        raise NotImplementedError

    def embed_target(self, features, labels, mode):
        """Embeds the target inputs. Used in training. Must be implemented
        by subclasses.
        """
        raise NotImplementedError

    def encode(self, features, labels, mode):
        """Encodes the inputs. Must be implemented by subclasses."""
        raise NotImplementedError

    def _connect(self, encoder_results, features, labels, mode):
        """Transforms encoder final state into decoder initial state. Must
        be implemented by subclasses.
        """
        raise NotImplementedError

    def decode(self, encoder_results, features, labels, mode):
        """Decodes. Must be implemented by subclasses."""
        raise NotImplementedError

    def _build(self, features, labels, params, mode, config=None):
        """Builds all components, runs encoding and decoding, and returns
        the `tf.estimator.EstimatorSpec` for the given `mode`.

        In PREDICT mode only predictions are produced. In any other mode
        the loss is computed as well, plus the train op (TRAIN) or the
        evaluation metric ops (EVAL).
        """
        # Components are built in dependency order: the embedders need the
        # vocabularies and the connector needs the decoder.
        self._build_vocab()
        self._build_embedders()
        self._build_encoder()
        self._build_decoder()
        self._build_connector()

        encoder_results = self.encode(features, labels, mode)
        decoder_results = self.decode(encoder_results, features, labels, mode)

        mode_keys = tf.estimator.ModeKeys
        loss = train_op = eval_metric_ops = None
        if mode == mode_keys.PREDICT:
            predictions = self._get_predictions(
                decoder_results, features, labels)
        else:
            loss = self.get_loss(decoder_results, features, labels)
            if mode == mode_keys.TRAIN:
                train_op = self._get_train_op(loss)
            elif mode == mode_keys.EVAL:
                eval_metric_ops = self._get_eval_metric_ops(
                    decoder_results, features, labels)
            predictions = self._get_predictions(
                decoder_results, features, labels, loss)

        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            loss=loss,
            train_op=train_op,
            eval_metric_ops=eval_metric_ops)