# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for RNN decoders.
"""
# pylint: disable=too-many-arguments, no-name-in-module
# pylint: disable=too-many-branches, protected-access, too-many-locals
# pylint: disable=arguments-differ, unused-argument
import copy
import tensorflow as tf
from tensorflow.contrib.seq2seq import Decoder as TFDecoder
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import nest
from texar.tf.core import layers
from texar.tf.utils import utils
from texar.tf.utils.mode import is_train_mode, is_train_mode_py
from texar.tf.modules.decoders.dynamic_decode import dynamic_decode
from texar.tf.module_base import ModuleBase
from texar.tf.modules.decoders import rnn_decoder_helpers
from texar.tf.utils.dtypes import is_callable
from texar.tf.utils.shapes import shape_list
from texar.tf.modules.decoders import tf_helpers as tx_helper
__all__ = [
"RNNDecoderBase",
"_make_output_layer"
]
def _make_output_layer_from_tensor(output_layer_tensor, vocab_size,
output_layer_bias, variable_scope):
"""Creates a dense layer from a Tensor. Used to tie word embedding
with the output layer weight.
"""
affine_bias = None
if output_layer_bias:
with tf.variable_scope(variable_scope):
affine_bias = tf.get_variable('affine_bias', [vocab_size])
def _outputs_to_logits(outputs):
shape = shape_list(outputs)
dim = shape[-1]
outputs = tf.reshape(outputs, [-1, dim])
logits = tf.matmul(outputs, output_layer_tensor)
if affine_bias is not None:
logits += affine_bias
logits = tf.reshape(logits, shape[:-1] + [vocab_size])
return logits
return _outputs_to_logits
def _make_output_layer(output_layer, vocab_size,
output_layer_bias, variable_scope):
"""Makes a decoder output layer.
"""
_vocab_size = vocab_size
if is_callable(output_layer):
_output_layer = output_layer
elif tf.contrib.framework.is_tensor(output_layer):
_vocab_size = shape_list(output_layer)[1]
_output_layer = _make_output_layer_from_tensor(
output_layer, _vocab_size, output_layer_bias, variable_scope)
elif output_layer is None:
if _vocab_size is None:
raise ValueError(
"Either `output_layer` or `vocab_size` must be provided. "
"Set `output_layer=tf.identity` if no output layer is "
"wanted.")
with tf.variable_scope(variable_scope):
# pylint: disable=redefined-variable-type
_output_layer = tf.layers.Dense(
units=_vocab_size, use_bias=output_layer_bias)
else:
raise ValueError(
"output_layer should be a callable layer, a tensor, or None. "
"Unsupported type: ", type(output_layer)
)
return _output_layer, _vocab_size
[docs]class RNNDecoderBase(ModuleBase, TFDecoder):
"""Base class inherited by all RNN decoder classes.
See :class:`~texar.tf.modules.BasicRNNDecoder` for the argumenrts.
See :meth:`_build` for the inputs and outputs of RNN decoders in general.
.. document private functions
.. automethod:: _build
"""
def __init__(self,
cell=None,
vocab_size=None,
output_layer=None,
cell_dropout_mode=None,
hparams=None):
ModuleBase.__init__(self, hparams)
self._helper = None
self._initial_state = None
# Make rnn cell
with tf.variable_scope(self.variable_scope):
if cell is not None:
self._cell = cell
else:
self._cell = layers.get_rnn_cell(
self._hparams.rnn_cell, cell_dropout_mode)
self._beam_search_cell = None
# Make the output layer
self._output_layer, self._vocab_size = _make_output_layer(
output_layer, vocab_size, self._hparams.output_layer_bias,
self.variable_scope)
self.max_decoding_length = None
[docs] @staticmethod
def default_hparams():
"""Returns a dictionary of hyperparameters with default values.
The hyperparameters are the same as in
:meth:`~texar.tf.modules.BasicRNNDecoder.default_hparams` of
:class:`~texar.tf.modules.BasicRNNDecoder`, except that the default
"name" here is "rnn_decoder".
"""
return {
"rnn_cell": layers.default_rnn_cell_hparams(),
"helper_train": rnn_decoder_helpers.default_helper_train_hparams(),
"helper_infer": rnn_decoder_helpers.default_helper_infer_hparams(),
"max_decoding_length_train": None,
"max_decoding_length_infer": None,
"name": "rnn_decoder",
"output_layer_bias": True,
}
[docs] def _build(self,
decoding_strategy="train_greedy",
initial_state=None,
inputs=None,
sequence_length=None,
embedding=None,
start_tokens=None,
end_token=None,
softmax_temperature=None,
max_decoding_length=None,
impute_finished=False,
output_time_major=False,
input_time_major=False,
helper=None,
mode=None,
**kwargs):
"""Performs decoding. This is a shared interface for both
:class:`~texar.tf.modules.BasicRNNDecoder` and
:class:`~texar.tf.modules.AttentionRNNDecoder`.
The function provides **3 ways** to specify the
decoding method, with varying flexibility:
1. The :attr:`decoding_strategy` argument: A string taking value of:
- **"train_greedy"**: decoding in teacher-forcing fashion
(i.e., feeding `ground truth` to decode the next step), and each
sample is obtained by taking the `argmax` of the RNN output
logits. Arguments
:attr:`(inputs, sequence_length, input_time_major)` are required
for this strategy, and argument :attr:`embedding` is optional.
- **"infer_greedy"**: decoding in inference fashion (i.e., feeding
the `generated` sample to decode the next step), and each sample
is obtained by taking the `argmax` of the RNN output logits.
Arguments :attr:`(embedding, start_tokens, end_token)` are
required for this strategy, and argument
:attr:`max_decoding_length` is optional.
- **"infer_sample"**: decoding in inference fashion, and each
sample is obtained by `random sampling` from the RNN output
distribution. Arguments
:attr:`(embedding, start_tokens, end_token)` are
required for this strategy, and argument
:attr:`max_decoding_length` is optional.
This argument is used only when argument :attr:`helper` is `None`.
Example:
.. code-block:: python
embedder = WordEmbedder(vocab_size=data.vocab.size)
decoder = BasicRNNDecoder(vocab_size=data.vocab.size)
# Teacher-forcing decoding
outputs_1, _, _ = decoder(
decoding_strategy='train_greedy',
inputs=embedder(data_batch['text_ids']),
sequence_length=data_batch['length']-1)
# Random sample decoding. Gets 100 sequence samples
outputs_2, _, sequence_length = decoder(
decoding_strategy='infer_sample',
start_tokens=[data.vocab.bos_token_id]*100,
end_token=data.vocab.eos.token_id,
embedding=embedder,
max_decoding_length=60)
2. The :attr:`helper` argument: An instance of subclass of
:class:`texar.tf.modules.Helper`. This
provides a superset of decoding strategies than above, for example:
- :class:`~texar.tf.modules.TrainingHelper` corresponding to the
"train_greedy" strategy.
- :class:`~texar.tf.modules.GreedyEmbeddingHelper` and
:class:`~texar.tf.modules.SampleEmbeddingHelper` corresponding to
the "infer_greedy" and "infer_sample", respectively.
- :class:`~texar.tf.modules.TopKSampleEmbeddingHelper` for Top-K
sample decoding.
- :class:`ScheduledEmbeddingTrainingHelper` and
:class:`ScheduledOutputTrainingHelper` for scheduled
sampling.
- :class:`~texar.tf.modules.SoftmaxEmbeddingHelper` and
:class:`~texar.tf.modules.GumbelSoftmaxEmbeddingHelper` for
soft decoding and gradient backpropagation.
Helpers give the maximal flexibility of configuring the decoding
strategy.
Example:
.. code-block:: python
embedder = WordEmbedder(vocab_size=data.vocab.size)
decoder = BasicRNNDecoder(vocab_size=data.vocab.size)
# Teacher-forcing decoding, same as above with
# `decoding_strategy='train_greedy'`
helper_1 = tx.modules.TrainingHelper(
inputs=embedders(data_batch['text_ids']),
sequence_length=data_batch['length']-1)
outputs_1, _, _ = decoder(helper=helper_1)
# Gumbel-softmax decoding
helper_2 = GumbelSoftmaxEmbeddingHelper(
embedding=embedder,
start_tokens=[data.vocab.bos_token_id]*100,
end_token=data.vocab.eos_token_id,
tau=0.1)
outputs_2, _, sequence_length = decoder(
max_decoding_length=60, helper=helper_2)
3. :attr:`hparams["helper_train"]` and :attr:`hparams["helper_infer"]`:
Specifying the helper through hyperparameters. Train and infer
strategy is toggled based on :attr:`mode`. Appriopriate arguments
(e.g., :attr:`inputs`, :attr:`start_tokens`, etc) are selected to
construct the helper. Additional arguments for helper constructor
can be provided either through :attr:`**kwargs`, or through
:attr:`hparams["helper_train/infer"]["kwargs"]`.
This means is used only when both :attr:`decoding_strategy` and
:attr:`helper` are `None`.
Example:
.. code-block:: python
h = {
"helper_infer": {
"type": "GumbelSoftmaxEmbeddingHelper",
"kwargs": { "tau": 0.1 }
}
}
embedder = WordEmbedder(vocab_size=data.vocab.size)
decoder = BasicRNNDecoder(vocab_size=data.vocab.size,
hparams=h)
# Gumbel-softmax decoding
output, _, _ = decoder(
decoding_strategy=None, # Sets to None explicit
embedding=embedder,
start_tokens=[data.vocab.bos_token_id]*100,
end_token=data.vocab.eos_token_id,
max_decoding_length=60,
mode=tf.estimator.ModeKeys.PREDICT)
# PREDICT mode also shuts down dropout
Args:
decoding_strategy (str): A string specifying the decoding
strategy. Different arguments are required based on the
strategy. Ignored if :attr:`helper` is given.
initial_state (optional): Initial state of decoding.
If `None` (default), zero state is used.
inputs (optional): Input tensors for teacher forcing decoding.
Used when :attr:`decoding_strategy` is set to
``"train_greedy"``, or when `hparams`-configured helper is used.
- If `embedding` is `None`, `inputs` is directly fed to
the decoder. E.g., in `"train_greedy"` strategy, `inputs` must
be a 3D Tensor of shape `[batch_size, max_time, emb_dim]` (or
`[max_time, batch_size, emb_dim]` if
`input_time_major` == `True`).
- If `embedding` is given, `inputs` is used as index to look up
embeddings and feed in the decoder. E.g., if `embedding` is
an instance of :class:`~texar.tf.modules.WordEmbedder`,
then :attr:`inputs` is usually a 2D int Tensor
`[batch_size, max_time]` (or `[max_time, batch_size]` if
`input_time_major` == `True`) containing the token indexes.
sequence_length (optional): A 1D int Tensor containing the
sequence length of :attr:`inputs`.
Used when `decoding_strategy="train_greedy"` or
`hparams`-configured helper is used.
embedding (optional): Embedding used when:
- "infer_greedy" or "infer_sample" `decoding_strategy` is
used. This can be a callable or the `params` argument for
:tf_main:`embedding_lookup <nn/embedding_lookup>`.
If a callable, it can take a vector tensor of token `ids`,
or take two arguments (`ids`, `times`), where `ids`
is a vector tensor of token ids, and `times` is a vector
tensor
of time steps (i.e., position ids). The latter case can be
used when attr:`embedding` is a combination of word embedding
and position embedding. `embedding` is required in this case.
- "train_greedy" `decoding_strategy` is used.
This can be a callable or the `params` argument for
:tf_main:`embedding_lookup <nn/embedding_lookup>`.
If a callable, it can take :attr:`inputs` and returns
the input embedding. `embedding` is optional in this case.
start_tokens (optional): A int Tensor of shape `[batch_size]`,
the start tokens. Used when `decoding_strategy="infer_greedy"`
or `"infer_sample"`, or when the helper specified in `hparams`
is used.
Example:
.. code-block:: python
data = tx.data.MonoTextData(hparams)
iterator = DataIterator(data)
batch = iterator.get_next()
bos_token_id = data.vocab.bos_token_id
start_tokens=tf.ones_like(batch['length'])*bos_token_id
end_token (optional): A int 0D Tensor, the token that marks end
of decoding.
Used when `decoding_strategy="infer_greedy"` or
`"infer_sample"`, or when the helper specified in `hparams`
is used.
softmax_temperature (optional): A float 0D Tensor, value to divide
the logits by before computing the softmax. Larger values
(above 1.0) result in more random samples. Must > 0. If `None`,
1.0 is used.
Used when `decoding_strategy="infer_sample"`.
max_decoding_length: A int scalar Tensor indicating the maximum
allowed number of decoding steps. If `None` (default), either
`hparams["max_decoding_length_train"]` or
`hparams["max_decoding_length_infer"]` is used
according to :attr:`mode`.
impute_finished (bool): If `True`, then states for batch
entries which are marked as finished get copied through and
the corresponding outputs get zeroed out. This causes some
slowdown at each time step, but ensures that the final state
and outputs have the correct values and that backprop ignores
time steps that were marked as finished.
output_time_major (bool): If `True`, outputs are returned as
time major tensors. If `False` (default), outputs are returned
as batch major tensors.
input_time_major (optional): Whether the :attr:`inputs` tensor is
time major.
Used when `decoding_strategy="train_greedy"` or
`hparams`-configured helper is used.
helper (optional): An instance of
:class:`texar.tf.modules.Helper`
that defines the decoding strategy. If given,
`decoding_strategy`
and helper configs in :attr:`hparams` are ignored.
mode (str, optional): A string taking value in
:tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`. If
`TRAIN`, training related hyperparameters are used (e.g.,
`hparams['max_decoding_length_train']`), otherwise,
inference related hyperparameters are used (e.g.,
`hparams['max_decoding_length_infer']`).
If `None` (default), `TRAIN` mode is used.
**kwargs: Other keyword arguments for constructing helpers
defined by `hparams["helper_trainn"]` or
`hparams["helper_infer"]`.
Returns:
`(outputs, final_state, sequence_lengths)`, where
- **`outputs`**: an object containing the decoder output on all
time steps.
- **`final_state`**: is the cell state of the final time step.
- **`sequence_lengths`**: is an int Tensor of shape `[batch_size]`
containing the length of each sample.
"""
# Helper
if helper is not None:
pass
elif decoding_strategy is not None:
if decoding_strategy == "train_greedy":
helper = rnn_decoder_helpers._get_training_helper(
inputs, sequence_length, embedding, input_time_major)
elif decoding_strategy == "infer_greedy":
helper = tx_helper.GreedyEmbeddingHelper(
embedding, start_tokens, end_token)
elif decoding_strategy == "infer_sample":
helper = tx_helper.SampleEmbeddingHelper(
embedding, start_tokens, end_token, softmax_temperature)
else:
raise ValueError(
"Unknown decoding strategy: {}".format(decoding_strategy))
else:
if is_train_mode_py(mode):
kwargs_ = copy.copy(self._hparams.helper_train.kwargs.todict())
helper_type = self._hparams.helper_train.type
else:
kwargs_ = copy.copy(self._hparams.helper_infer.kwargs.todict())
helper_type = self._hparams.helper_infer.type
kwargs_.update({
"inputs": inputs,
"sequence_length": sequence_length,
"time_major": input_time_major,
"embedding": embedding,
"start_tokens": start_tokens,
"end_token": end_token,
"softmax_temperature": softmax_temperature})
kwargs_.update(kwargs)
helper = rnn_decoder_helpers.get_helper(helper_type, **kwargs_)
self._helper = helper
# Initial state
if initial_state is not None:
self._initial_state = initial_state
else:
self._initial_state = self.zero_state(
batch_size=self.batch_size, dtype=tf.float32)
# Maximum decoding length
max_l = max_decoding_length
if max_l is None:
max_l_train = self._hparams.max_decoding_length_train
if max_l_train is None:
max_l_train = utils.MAX_SEQ_LENGTH
max_l_infer = self._hparams.max_decoding_length_infer
if max_l_infer is None:
max_l_infer = utils.MAX_SEQ_LENGTH
max_l = tf.cond(is_train_mode(mode),
lambda: max_l_train, lambda: max_l_infer)
self.max_decoding_length = max_l
# Decode
outputs, final_state, sequence_lengths = dynamic_decode(
decoder=self, impute_finished=impute_finished,
maximum_iterations=max_l, output_time_major=output_time_major)
if not self._built:
self._add_internal_trainable_variables()
# Add trainable variables of `self._cell` which may be
# constructed externally.
self._add_trainable_variable(
layers.get_rnn_cell_trainable_variables(self._cell))
if isinstance(self._output_layer, tf.layers.Layer):
self._add_trainable_variable(
self._output_layer.trainable_variables)
# Add trainable variables of `self._beam_search_rnn_cell` which
# may already be constructed and used.
if self._beam_search_cell is not None:
self._add_trainable_variable(
self._beam_search_cell.trainable_variables)
self._built = True
return outputs, final_state, sequence_lengths
def _get_beam_search_cell(self, **kwargs):
self._beam_search_cell = self._cell
return self._cell
def _rnn_output_size(self):
size = self._cell.output_size
if self._output_layer is tf.identity:
return size
else:
# To use layer's compute_output_shape, we need to convert the
# RNNCell's output_size entries into shapes with an unknown
# batch size. We then pass this through the layer's
# compute_output_shape and read off all but the first (batch)
# dimensions to get the output size of the rnn with the layer
# applied to the top.
output_shape_with_unknown_batch = nest.map_structure(
lambda s: tensor_shape.TensorShape([None]).concatenate(s),
size)
layer_output_shape = self._output_layer.compute_output_shape(
output_shape_with_unknown_batch)
return nest.map_structure(lambda s: s[1:], layer_output_shape)
@property
def batch_size(self):
return self._helper.batch_size
@property
def output_size(self):
"""Output size of one step.
"""
raise NotImplementedError
@property
def output_dtype(self):
"""Types of output of one step.
"""
raise NotImplementedError
def initialize(self, name=None):
# Inherits from TFDecoder
# All RNN decoder classes must implement this
raise NotImplementedError
def step(self, time, inputs, state, name=None):
# Inherits from TFDecoder
# All RNN decoder classes must implement this
raise NotImplementedError
def finalize(self, outputs, final_state, sequence_lengths):
# Inherits from TFDecoder
# All RNN decoder classes must implement this
raise NotImplementedError
@property
def cell(self):
"""The RNN cell.
"""
return self._cell
[docs] def zero_state(self, batch_size, dtype):
"""Zero state of the RNN cell.
Equivalent to :attr:`decoder.cell.zero_state`.
"""
return self._cell.zero_state(
batch_size=batch_size, dtype=dtype)
@property
def state_size(self):
"""The state size of decoder cell.
Equivalent to :attr:`decoder.cell.state_size`.
"""
return self.cell.state_size
@property
def vocab_size(self):
"""The vocab size.
"""
return self._vocab_size
@property
def output_layer(self):
"""The output layer.
"""
return self._output_layer