Source code for texar.tf.modules.encoders.hierarchical_encoders

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various encoders that encode data with hierarchical structure.
"""

import collections

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMStateTuple
from tensorflow.python.util import nest    # pylint: disable=E0611

from texar.tf.modules.encoders.encoder_base import EncoderBase
from texar.tf.utils import utils

# pylint: disable=invalid-name, too-many-arguments, too-many-locals

__all__ = [
    "HierarchicalRNNEncoder"
]


class HierarchicalRNNEncoder(EncoderBase):
    """A hierarchical encoder that stacks basic RNN encoders into two layers.
    Can be used to encode long, structured sequences, e.g. paragraphs, dialog
    history, etc.

    Args:
        encoder_major (optional): An instance of a subclass of
            :class:`~texar.tf.modules.RNNEncoderBase`.
            The high-level encoder taking final states from the low-level
            encoder as its inputs. If not specified, an encoder is created
            as specified in :attr:`hparams["encoder_major_type"]` and
            :attr:`hparams["encoder_major_hparams"]`.
        encoder_minor (optional): An instance of a subclass of
            :class:`~texar.tf.modules.RNNEncoderBase`.
            The low-level encoder. If not specified, an encoder is created
            as specified in :attr:`hparams["encoder_minor_type"]` and
            :attr:`hparams["encoder_minor_hparams"]`.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure and
            default values.

    See :meth:`_build` for the inputs and outputs of the encoder.

    .. document private functions
    .. automethod:: _build
    """

    def __init__(self, encoder_major=None, encoder_minor=None, hparams=None):
        EncoderBase.__init__(self, hparams)

        encoder_major_hparams = utils.get_instance_kwargs(
            None, self._hparams.encoder_major_hparams)
        encoder_minor_hparams = utils.get_instance_kwargs(
            None, self._hparams.encoder_minor_hparams)

        if encoder_major is not None:
            self._encoder_major = encoder_major
        else:
            with tf.variable_scope(self.variable_scope.name):
                with tf.variable_scope('encoder_major'):
                    self._encoder_major = utils.check_or_get_instance(
                        self._hparams.encoder_major_type,
                        encoder_major_hparams,
                        ['texar.tf.modules.encoders', 'texar.tf.custom'])

        if encoder_minor is not None:
            self._encoder_minor = encoder_minor
        elif self._hparams.config_share:
            with tf.variable_scope(self.variable_scope.name):
                with tf.variable_scope('encoder_minor'):
                    self._encoder_minor = utils.check_or_get_instance(
                        self._hparams.encoder_major_type,
                        encoder_major_hparams,
                        ['texar.tf.modules.encoders', 'texar.tf.custom'])
        else:
            with tf.variable_scope(self.variable_scope.name):
                with tf.variable_scope('encoder_minor'):
                    self._encoder_minor = utils.check_or_get_instance(
                        self._hparams.encoder_minor_type,
                        encoder_minor_hparams,
                        ['texar.tf.modules.encoders', 'texar.tf.custom'])

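    # Construction sketch (illustrative only, assuming
    # `from texar.tf.modules import UnidirectionalRNNEncoder`): the two
    # levels can either be created from hyperparameters or wrapped around
    # existing encoder instances.
    #
    #     # Default: both levels are UnidirectionalRNNEncoders built from
    #     # hparams.
    #     encoder = HierarchicalRNNEncoder()
    #
    #     # Or wrap pre-built encoders.
    #     encoder = HierarchicalRNNEncoder(
    #         encoder_major=UnidirectionalRNNEncoder(),
    #         encoder_minor=UnidirectionalRNNEncoder())
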
    @staticmethod
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values.

        .. role:: python(code)
            :language: python

        .. code-block:: python

            {
                "encoder_major_type": "UnidirectionalRNNEncoder",
                "encoder_major_hparams": {},
                "encoder_minor_type": "UnidirectionalRNNEncoder",
                "encoder_minor_hparams": {},
                "config_share": False,
                "name": "hierarchical_encoder"
            }

        Here:

        "encoder_major_type": str or class or instance
            The high-level encoder. Can be an RNN encoder class, its name or
            module path, or a class instance.
            Ignored if `encoder_major` is given to the encoder constructor.

        "encoder_major_hparams": dict
            The hyperparameters for the high-level encoder. The high-level
            encoder is created with
            :python:`encoder_class(hparams=encoder_major_hparams)`.
            Ignored if `encoder_major` is given to the encoder constructor,
            or if "encoder_major_type" is an encoder instance.

        "encoder_minor_type": str or class or instance
            The low-level encoder. Can be an RNN encoder class, its name or
            module path, or a class instance.
            Ignored if `encoder_minor` is given to the encoder constructor,
            or if "config_share" is True.

        "encoder_minor_hparams": dict
            The hyperparameters for the low-level encoder. The low-level
            encoder is created with
            :python:`encoder_class(hparams=encoder_minor_hparams)`.
            Ignored if `encoder_minor` is given to the encoder constructor,
            or if "config_share" is True,
            or if "encoder_minor_type" is an encoder instance.

        "config_share": bool
            Whether to use `encoder_major`'s hyperparameters to construct
            `encoder_minor`.

        "name": str
            Name of the encoder.
        """
        hparams = {
            "name": "hierarchical_encoder",
            "encoder_major_type": "UnidirectionalRNNEncoder",
            "encoder_major_hparams": {},
            "encoder_minor_type": "UnidirectionalRNNEncoder",
            "encoder_minor_hparams": {},
            "config_share": False,
            "@no_typecheck": [
                'encoder_major_hparams',
                'encoder_minor_hparams'
            ]
        }
        hparams.update(EncoderBase.default_hparams())
        return hparams

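    # Configuration sketch (illustrative only): with `config_share=True` the
    # minor encoder is built from the major encoder's type and hparams, so a
    # single cell configuration applies to both levels. The nested keys below
    # assume the default RNN encoders' own hyperparameter layout.
    #
    #     hparams = {
    #         "encoder_major_hparams": {
    #             "rnn_cell": {"kwargs": {"num_units": 512}}
    #         },
    #         "config_share": True,
    #     }
    #     encoder = HierarchicalRNNEncoder(hparams=hparams)
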
    def _build(self,
               inputs,
               order='btu',
               medium=None,
               sequence_length_major=None,
               sequence_length_minor=None,
               **kwargs):
        """Encodes the inputs.

        Args:
            inputs: A 4-D tensor of shape `[B, T, U, dim]`, where

                - B: batch_size
                - T: the max length of high-level sequences. E.g., the max \
                  number of utterances in dialog history.
                - U: the max length of low-level sequences. E.g., the max \
                  length of each utterance in dialog history.
                - dim: embedding dimension

                The order of the first three dimensions can be changed
                according to :attr:`order`.
            order: A 3-char string containing 'b', 't', and 'u' that
                specifies the order of the inputs dimensions above. The
                following four values are accepted:

                - **'btu'**: None of the encoders are time-major.
                - **'utb'**: Both encoders are time-major.
                - **'tbu'**: The major encoder is time-major.
                - **'ubt'**: The minor encoder is time-major.

            medium (optional): A list of callables that subsequently process
                the final states of the minor encoder and obtain the inputs
                for the major encoder. If not specified, :meth:`flatten` is
                used for processing the minor's final states.
            sequence_length_major (optional): The `sequence_length` argument
                sent to the major encoder. This is a 1-D Tensor of shape
                `[B]`.
            sequence_length_minor (optional): The `sequence_length` argument
                sent to the minor encoder. It can be either a 1-D Tensor of
                shape `[B*T]`, or a 2-D Tensor of shape `[B, T]` or `[T, B]`
                according to :attr:`order`.
            **kwargs: Other keyword arguments for the major and minor
                encoders, such as `initial_state`, etc. Note that
                `sequence_length` and `time_major` must not be included
                here; `time_major` is derived from :attr:`order`
                automatically. By default, arguments are sent to both the
                major and minor encoders. To specify which encoder an
                argument should be sent to, add '_minor'/'_major' as its
                suffix.

                Note that `initial_state_minor` must have a batch dimension
                of size `B*T`. If you have an initial state with a batch
                dimension of size `T`, use :meth:`tile_initial_state_minor`
                to tile it according to `order`.

        Returns:
            A tuple `(outputs, final_state)` produced by the major encoder.
            See the return values of the `_build()` method of the respective
            encoder class for details.

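        Example:
            A minimal usage sketch. ``embedded_history``, ``num_turns``, and
            ``utt_lengths`` are placeholder tensors, not part of this
            module; calling the encoder invokes :meth:`_build`:

            .. code-block:: python

                # embedded_history: [batch_size, max_turns, max_utt_len, dim]
                encoder = HierarchicalRNNEncoder()
                outputs, final_state = encoder(
                    inputs=embedded_history,
                    order='btu',
                    sequence_length_major=num_turns,    # shape [B]
                    sequence_length_minor=utt_lengths)  # shape [B, T]
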
""" def _kwargs_split(kwargs): kwargs_minor, kwargs_major = {}, {} for k, v in kwargs.items(): if len(k) >= 6 and k[-6:] == ['_minor']: kwargs_minor[k[:-6]] = v if len(k) >= 6 and k[-6:] == ['_major']: kwargs_major[k[:-6]] = v return kwargs_minor, kwargs_major kwargs_minor, kwargs_major = _kwargs_split(kwargs) if sequence_length_minor is not None: sequence_length_minor = tf.reshape(sequence_length_minor, [-1]) kwargs_minor['sequence_length'] = sequence_length_minor kwargs_major['sequence_length'] = sequence_length_major expand, shape = self._get_flatten_order( order, kwargs_minor, kwargs_major, tf.shape(inputs)) inputs = tf.reshape(inputs, shape + [inputs.shape[3]]) _, states_minor = self._encoder_minor(inputs, **kwargs_minor) self.states_minor_before_medium = states_minor if medium is None: states_minor = self.flatten(states_minor) else: if not isinstance(medium, collections.Sequence): medium = [medium] for fn in medium: if isinstance(fn, str) and fn == 'flatten': states_minor = self.flatten(states_minor) else: states_minor = fn(states_minor) self.states_minor_after_medium = states_minor states_minor = tf.reshape( states_minor, tf.concat([expand, tf.shape(states_minor)[1:]], 0)) outputs_major, states_major = self._encoder_major(states_minor, **kwargs_major) # Add trainable variables of `self._cell` which may be constructed # externally if not self._built: self._add_trainable_variable( self._encoder_minor.trainable_variables) self._add_trainable_variable( self._encoder_major.trainable_variables) self._built = True return outputs_major, states_major
    @staticmethod
    def tile_initial_state_minor(initial_state, order, inputs_shape):
        """Tiles an initial state to be used for the minor encoder.

        The batch dimension of :attr:`initial_state` must equal `T`. The
        state will be copied `B` times and used to start encoding each
        low-level sequence. For example, the first utterance in each dialog
        history in the batch will have the same initial state.

        Args:
            initial_state: Initial state with a batch dimension of size `T`.
            order (str): The dimension order of the inputs. Must be the same
                as used in :meth:`_build`.
            inputs_shape: Shape of `inputs` for :meth:`_build`. Can usually
                be obtained with `tf.shape(inputs)`.

        Returns:
            A tiled initial state with a batch dimension of size `B*T`.
        """
        def _nest_tile(t, multiplier):
            return nest.map_structure(lambda x: tf.tile(x, multiplier), t)

        if order == 'btu':
            return _nest_tile(initial_state, inputs_shape[0])
        elif order == 'ubt':
            return _nest_tile(initial_state, inputs_shape[1])
        elif order == 'utb':
            return tf.contrib.seq2seq.tile_batch(initial_state,
                                                 inputs_shape[2])
        elif order == 'tbu':
            return tf.contrib.seq2seq.tile_batch(initial_state,
                                                 inputs_shape[1])
        else:
            raise ValueError('Unknown order: {}'.format(order))

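    # Usage sketch (illustrative only, placeholder names): an initial state
    # whose batch dimension is T (one state per position in the high-level
    # sequence) is tiled to batch dimension B*T before being passed to
    # :meth:`_build` as `initial_state_minor`.
    #
    #     tiled = HierarchicalRNNEncoder.tile_initial_state_minor(
    #         init_state, order='btu', inputs_shape=tf.shape(inputs))
    #     outputs, state = encoder(
    #         inputs, order='btu', initial_state_minor=tiled)
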
    @staticmethod
    def _get_flatten_order(order, kwargs_minor, kwargs_major, shape):
        if order == 'btu':
            kwargs_minor.setdefault('time_major', False)
            kwargs_major.setdefault('time_major', False)
            expand = shape[0:2]
            shape = [shape[0] * shape[1], shape[2]]
        elif order == 'utb':
            kwargs_minor.setdefault('time_major', True)
            kwargs_major.setdefault('time_major', True)
            expand = shape[1:3]
            shape = [shape[0], shape[1] * shape[2]]
        elif order == 'tbu':
            kwargs_minor.setdefault('time_major', False)
            kwargs_major.setdefault('time_major', True)
            expand = shape[0:2]
            shape = [shape[0] * shape[1], shape[2]]
        elif order == 'ubt':
            kwargs_minor.setdefault('time_major', True)
            kwargs_major.setdefault('time_major', False)
            expand = shape[1:3]
            shape = [shape[0], shape[1] * shape[2]]
        else:
            raise ValueError('Unknown order: {}'.format(order))

        return expand, shape

    @staticmethod
    def flatten(x):
        """Flattens a cell state by concatenating a sequence of cell states
        along the last dimension. If the cell states are
        :tf_main:`LSTMStateTuple <contrib/rnn/LSTMStateTuple>`, only the
        hidden `LSTMStateTuple.h` is used.

        This process is used by default if :attr:`medium` is not provided
        to :meth:`_build`.
        """
        if isinstance(x, LSTMStateTuple):
            return x.h
        if isinstance(x, collections.Sequence):
            return tf.concat(
                [HierarchicalRNNEncoder.flatten(v) for v in x], -1)
        else:
            return x

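    # Behavior sketch (illustrative only): for a two-layer LSTM minor
    # encoder whose final state is a tuple of two LSTMStateTuples with
    # `num_units` hidden units each, `flatten` keeps only the `h` parts and
    # concatenates them, yielding a `[B*T, 2 * num_units]` tensor that is
    # fed to the major encoder.
    #
    #     (LSTMStateTuple(c0, h0), LSTMStateTuple(c1, h1))
    #         --> tf.concat([h0, h1], -1)
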
    @property
    def encoder_major(self):
        """The high-level encoder.
        """
        return self._encoder_major

    @property
    def encoder_minor(self):
        """The low-level encoder.
        """
        return self._encoder_minor