# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Beam search decoding for RNN decoders.
"""

import tensorflow as tf
from tensorflow.contrib.seq2seq import \
    dynamic_decode, AttentionWrapperState, AttentionWrapper, \
    BeamSearchDecoder, tile_batch

from texar.tf.modules.decoders.rnn_decoder_base import RNNDecoderBase
# pylint: disable=too-many-arguments, protected-access, too-many-locals
# pylint: disable=invalid-name

__all__ = [
    "beam_search_decode"
]


def _get_initial_state(initial_state,
                       tiled_initial_state,
                       cell,
                       batch_size,
                       beam_width,
                       dtype):
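    """Returns a decoding initial state whose batch dimension has size
    `batch_size * beam_width`.

    If `tiled_initial_state` is not given, tiles `initial_state` (or, if that
    is also `None`, the cell's zero state) along the batch dimension with
    `tile_batch`. For an `AttentionWrapper` cell, a plain cell state is
    additionally wrapped into an `AttentionWrapperState`.
    """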
    if tiled_initial_state is None:
        if isinstance(initial_state, AttentionWrapperState):
            raise ValueError(
                '`initial_state` must not be an AttentionWrapperState. Use '
                'a plain cell state instead, which will be wrapped into an '
                'AttentionWrapperState automatically.')
        if initial_state is None:
            tiled_initial_state = cell.zero_state(batch_size * beam_width,
                                                  dtype)
        else:
            tiled_initial_state = tile_batch(initial_state,
                                             multiplier=beam_width)

    if isinstance(cell, AttentionWrapper) and \
            not isinstance(tiled_initial_state, AttentionWrapperState):
        zero_state = cell.zero_state(batch_size * beam_width, dtype)
        tiled_initial_state = zero_state.clone(cell_state=tiled_initial_state)

    return tiled_initial_state
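
# A minimal sketch of the tiling performed above (illustrative only; the
# tensor values here are hypothetical). `tile_batch` repeats each batch
# entry `beam_width` times, so every beam hypothesis starts from the same
# per-example state:
#
#     state = tf.constant([[1., 2.], [3., 4.]])   # batch_size = 2
#     tiled = tile_batch(state, multiplier=3)     # beam_width = 3
#     # tiled == [[1., 2.], [1., 2.], [1., 2.],
#     #           [3., 4.], [3., 4.], [3., 4.]]   # shape [6, 2]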


def beam_search_decode(decoder_or_cell,
                       embedding,
                       start_tokens,
                       end_token,
                       beam_width,
                       initial_state=None,
                       tiled_initial_state=None,
                       output_layer=None,
                       length_penalty_weight=0.0,
                       max_decoding_length=None,
                       output_time_major=False,
                       **kwargs):
    """Performs beam search decoding.

    Args:
        decoder_or_cell: An instance of a subclass of
            :class:`~texar.tf.modules.RNNDecoderBase`, or an instance of
            :tf_main:`RNNCell <contrib/rnn/RNNCell>`. The decoder or RNN cell
            to perform decoding.
        embedding: A callable that takes a vector tensor of indexes (e.g.,
            an instance of a subclass of
            :class:`~texar.tf.modules.EmbedderBase`), or the :attr:`params`
            argument for
            :tf_main:`tf.nn.embedding_lookup <nn/embedding_lookup>`.
        start_tokens: An `int32` vector shaped `[batch_size]`, the start
            tokens.
        end_token: An `int32` scalar, the token that marks the end of
            decoding.
        beam_width (int): Python integer, the number of beams.
        initial_state (optional): Initial state of decoding. If `None`
            (default), zero state is used.

            The state must **not** be tiled with
            :tf_main:`tile_batch <contrib/seq2seq/tile_batch>`. If you have
            an already-tiled initial state, use
            :attr:`tiled_initial_state` instead.

            In the case of attention RNN decoder, `initial_state` must
            **not** be an
            :tf_main:`AttentionWrapperState <contrib/seq2seq/AttentionWrapperState>`.
            Instead, it must be a state of the wrapped `RNNCell`, which will
            be wrapped into an `AttentionWrapperState` automatically.

            Ignored if :attr:`tiled_initial_state` is given.
        tiled_initial_state (optional): Initial state that has been tiled
            (typically with
            :tf_main:`tile_batch <contrib/seq2seq/tile_batch>`) so that the
            batch dimension has size `batch_size * beam_width`.

            In the case of attention RNN decoder, this can be either a state
            of the wrapped `RNNCell`, or an `AttentionWrapperState`.

            If not given, :attr:`initial_state` is used.
        output_layer (optional): A :tf_main:`Layer <layers/Layer>` instance
            to apply to the RNN output prior to storing the result or
            sampling. If `None` and :attr:`decoder_or_cell` is a decoder,
            the decoder's output layer will be used.
        length_penalty_weight: Float weight to penalize length. Disabled
            with `0.0` (default).
        max_decoding_length (optional): An int scalar Tensor indicating the
            maximum allowed number of decoding steps. If `None` (default),
            decoding continues until the end token is encountered.
        output_time_major (bool): If `True`, outputs are returned as
            time-major tensors. If `False` (default), outputs are returned
            as batch-major tensors.
        **kwargs: Other keyword arguments for
            :tf_main:`dynamic_decode <contrib/seq2seq/dynamic_decode>`,
            except the argument `maximum_iterations`, which is set to
            :attr:`max_decoding_length`.

    Returns:
        A tuple `(outputs, final_state, sequence_length)`, where

        - outputs: An instance of :tf_main:`FinalBeamSearchDecoderOutput \
          <contrib/seq2seq/FinalBeamSearchDecoderOutput>`.
        - final_state: An instance of :tf_main:`BeamSearchDecoderState \
          <contrib/seq2seq/BeamSearchDecoderState>`.
        - sequence_length: A Tensor of shape `[batch_size]` containing \
          the lengths of samples.

    Example:

        .. code-block:: python

            ## Beam search with basic RNN decoder

            embedder = WordEmbedder(vocab_size=data.vocab.size)
            decoder = BasicRNNDecoder(vocab_size=data.vocab.size)

            outputs, _, _ = beam_search_decode(
                decoder_or_cell=decoder,
                embedding=embedder,
                start_tokens=[data.vocab.bos_token_id] * 100,
                end_token=data.vocab.eos_token_id,
                beam_width=5,
                max_decoding_length=60)

            sample_ids = sess.run(outputs.predicted_ids)
            sample_text = tx.utils.map_ids_to_strs(sample_ids[:, :, 0],
                                                   data.vocab)
            print(sample_text)
            # [
            #   the first sequence sample .
            #   the second sequence sample .
            #   ...
            # ]

        .. code-block:: python

            ## Beam search with attention RNN decoder

            # Encodes the source
            enc_embedder = WordEmbedder(data.source_vocab.size, ...)
            encoder = UnidirectionalRNNEncoder(...)

            enc_outputs, enc_state = encoder(
                inputs=enc_embedder(data_batch['source_text_ids']),
                sequence_length=data_batch['source_length'])

            # Decodes while attending to the source
            dec_embedder = WordEmbedder(vocab_size=data.target_vocab.size, ...)
            decoder = AttentionRNNDecoder(
                memory=enc_outputs,
                memory_sequence_length=data_batch['source_length'],
                vocab_size=data.target_vocab.size)

            # Beam search
            outputs, _, _ = beam_search_decode(
                decoder_or_cell=decoder,
                embedding=dec_embedder,
                start_tokens=[data.target_vocab.bos_token_id] * 100,
                end_token=data.target_vocab.eos_token_id,
                beam_width=5,
                initial_state=enc_state,
                max_decoding_length=60)
    """
    if isinstance(decoder_or_cell, RNNDecoderBase):
        cell = decoder_or_cell._get_beam_search_cell(beam_width=beam_width)
    elif isinstance(decoder_or_cell, tf.contrib.rnn.RNNCell):
        cell = decoder_or_cell
    else:
        raise ValueError("`decoder_or_cell` must be an instance of a "
                         "subclass of either `RNNDecoderBase` or `RNNCell`.")

    start_tokens = tf.convert_to_tensor(
        start_tokens, dtype=tf.int32, name="start_tokens")
    if start_tokens.get_shape().ndims != 1:
        raise ValueError("`start_tokens` must be a vector")
    batch_size = tf.size(start_tokens)

    initial_state = _get_initial_state(
        initial_state, tiled_initial_state, cell,
        batch_size, beam_width, tf.float32)

    if output_layer is None and isinstance(decoder_or_cell, RNNDecoderBase):
        output_layer = decoder_or_cell.output_layer

    def _decode():
        beam_decoder = BeamSearchDecoder(
            cell=cell,
            embedding=embedding,
            start_tokens=start_tokens,
            end_token=end_token,
            initial_state=initial_state,
            beam_width=beam_width,
            output_layer=None if output_layer is tf.identity else output_layer,
            length_penalty_weight=length_penalty_weight)

        if 'maximum_iterations' in kwargs:
            raise ValueError('Use `max_decoding_length` to set the maximum '
                             'allowed number of decoding steps.')
        outputs, final_state, _ = dynamic_decode(
            decoder=beam_decoder,
            output_time_major=output_time_major,
            maximum_iterations=max_decoding_length,
            **kwargs)

        return outputs, final_state, final_state.lengths

    if isinstance(decoder_or_cell, RNNDecoderBase):
        vs = decoder_or_cell.variable_scope
        with tf.variable_scope(vs, reuse=tf.AUTO_REUSE):
            return _decode()
    else:
        return _decode()
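
# A minimal usage sketch for the `tiled_initial_state` path, which the
# docstring describes but does not demonstrate. Variable names (`enc_state`,
# `dec_embedder`, `decoder`, `batch_size`) are hypothetical and assume the
# attention-decoder setup from the docstring example. If the encoder state
# has already been tiled, pass it via `tiled_initial_state` instead of
# `initial_state`:
#
#     tiled_state = tile_batch(enc_state, multiplier=5)
#     outputs, final_state, lengths = beam_search_decode(
#         decoder_or_cell=decoder,
#         embedding=dec_embedder,
#         start_tokens=[data.target_vocab.bos_token_id] * batch_size,
#         end_token=data.target_vocab.eos_token_id,
#         beam_width=5,
#         tiled_initial_state=tiled_state,
#         max_decoding_length=60)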