
# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
End-to-end memory network described in
(Sukhbaatar et al.) End-To-End Memory Networks
"""

import tensorflow as tf

from texar.tf.module_base import ModuleBase
from texar.tf.modules.embedders import WordEmbedder, PositionEmbedder
from texar.tf.utils.mode import switch_dropout
from texar.tf.modules.memory.embed_fns import default_memnet_embed_fn_hparams

# pylint: disable=invalid-name, too-many-instance-attributes, too-many-arguments
# pylint: disable=too-many-locals

__all__ = [
    'MemNetBase',
    'MemNetRNNLike',
]


class MemNetSingleLayer(ModuleBase):
    """An A-C layer for memory network.

    Args:
        H (optional): The matrix :attr:`H` multiplied to :attr:`o` at the end.
        hparams (dict or HParams, optional): Memory network single layer
            hyperparameters. If it is not specified, the default hyperparameter
            setting is used. See :attr:`default_hparams` for the structure and
            default values.
    """

    def __init__(self, H=None, hparams=None):
        ModuleBase.__init__(self, hparams)

        self._H = H

    @staticmethod
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                "name": "memnet_single_layer"
            }

        Here:

        "name": str
            Name of the memory network single layer.
        """
        return {
            "name": "memnet_single_layer"
        }

    def _build(self, u, m, c, **kwargs):
        """An A-C operation with memory and query vector.

        Args:
            u (Tensor): The input query `Tensor` of shape `[None, memory_dim]`.
            m (Tensor): Output of A operation. Should be in shape
                `[None, memory_size, memory_dim]`.
            c (Tensor): Output of C operation. Should be in shape
                `[None, memory_size, memory_dim]`.

        Returns:
            A `Tensor` of the same shape as :attr:`u`.
        """
        # Input memory representation: attention weights over the memory
        # entries, p = softmax(u^T m)  -- equ. (1) in (Sukhbaatar et al.)
        p = tf.matmul(m, tf.expand_dims(u, axis=2))
        p = tf.transpose(p, perm=[0, 2, 1])
        p = tf.nn.softmax(p)  # equ. (1)

        # Output memory representation: weighted sum of the C-embedded
        # entries, o = sum_i p_i c_i  -- equ. (2)
        o = tf.matmul(p, c)  # equ. (2)
        o = tf.squeeze(o, axis=[1])

        if self._H is not None:
            u = tf.matmul(u, self._H)  # RNN-like style
        u_ = tf.add(u, o)  # u^{k+1} = H u^k + o^k

        if not self._built:
            self._add_internal_trainable_variables()
            if self._H is not None:
                self._add_trainable_variable(self._H)
            self._built = True

        return u_
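

# A minimal usage sketch (not part of the original Texar API). It shows the
# tensor shapes expected by one A-C hop as documented in
# `MemNetSingleLayer._build`: `batch_size=32`, `memory_size=20`, and
# `memory_dim=100` are illustrative values, and the random tensors stand in
# for already-embedded memory entries and an embedded query.
def _example_single_layer_hop():
    layer = MemNetSingleLayer()          # no `H`: u^{k+1} = u^k + o^k
    u = tf.random_normal([32, 100])      # query: [batch_size, memory_dim]
    m = tf.random_normal([32, 20, 100])  # A-embedded memory entries
    c = tf.random_normal([32, 20, 100])  # C-embedded memory entries
    # One hop of attention; the output has the same shape as `u`: [32, 100].
    return layer(u, m, c)
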


class MemNetBase(ModuleBase):
    """Base class inherited by all memory network classes.

    Args:
        raw_memory_dim (int): Dimension size of raw memory entries
            (before embedding). For example, if a raw memory entry is a word,
            this is the **vocabulary size** (imagine a one-hot representation
            of the word). If a raw memory entry is a dense vector, this is
            the dimension size of the vector.
        input_embed_fn (optional): A callable that embeds raw memory entries
            as inputs. This corresponds to the `A` embedding operation in
            (Sukhbaatar et al.). If not provided, a default embedding
            operation is created as specified in :attr:`hparams`. See
            :meth:`~texar.tf.modules.MemNetBase.get_default_embed_fn`
            for details.
        output_embed_fn (optional): A callable that embeds raw memory entries
            as outputs. This corresponds to the `C` embedding operation in
            (Sukhbaatar et al.). If not provided, a default embedding
            operation is created as specified in :attr:`hparams`. See
            :meth:`~texar.tf.modules.MemNetBase.get_default_embed_fn`
            for details.
        query_embed_fn (optional): A callable that embeds the query. This
            corresponds to the `B` embedding operation in
            (Sukhbaatar et al.). If not provided and "use_B" is True in
            :attr:`hparams`, a default embedding operation is created as
            specified in :attr:`hparams`. See
            :meth:`~texar.tf.modules.MemNetBase.get_default_embed_fn`
            for details.
            Notice: if you'd like to customize this callable, please follow
            the same number and order of dimensions as in `input_embed_fn`
            or `output_embed_fn`, and assume that the 2nd dimension of its
            input and output (which corresponds to `memory_size`) is 1.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure and
            default values.
    """

    def __init__(self, raw_memory_dim, input_embed_fn=None,
                 output_embed_fn=None, query_embed_fn=None, hparams=None):
        ModuleBase.__init__(self, hparams)

        self._raw_memory_dim = raw_memory_dim

        self._n_hops = self._hparams.n_hops
        self._relu_dim = self._hparams.relu_dim
        self._memory_size = self._hparams.memory_size

        with tf.variable_scope(self.variable_scope):
            self._A, self._C, self._B, self._memory_dim = self._build_embed_fn(
                input_embed_fn, output_embed_fn, query_embed_fn)

            self.H = None
            if self.hparams.use_H:
                self.H = tf.get_variable(
                    name="H", shape=[self._memory_dim, self._memory_dim])

    def _build_embed_fn(self, input_embed_fn, output_embed_fn,
                        query_embed_fn):
        # Optionally creates the default A, C, and B embed_fn's, and infers
        # `memory_dim` from any embed_fn that is created.
        memory_dim = self.hparams.memory_dim
        mdim_A, mdim_C, mdim_B = None, None, None

        A = input_embed_fn
        if input_embed_fn is None:
            A, mdim_A = self.get_default_embed_fn(
                self._memory_size, self._hparams.A)
            memory_dim = mdim_A

        C = output_embed_fn
        if output_embed_fn is None:
            C, mdim_C = self.get_default_embed_fn(
                self._memory_size, self._hparams.C)
            if mdim_A is not None and mdim_A != mdim_C:
                raise ValueError('Embedding config `A` and `C` must have '
                                 'the same output dimension.')
            memory_dim = mdim_C

        B = query_embed_fn
        if query_embed_fn is None and self._hparams.use_B:
            B, mdim_B = self.get_default_embed_fn(1, self._hparams.B)
            if mdim_A is not None and mdim_A != mdim_B:
                raise ValueError('Embedding config `A` and `B` must have '
                                 'the same output dimension.')
            if mdim_C is not None and mdim_C != mdim_B:
                raise ValueError('Embedding config `C` and `B` must have '
                                 'the same output dimension.')
            memory_dim = mdim_B

        return A, C, B, memory_dim

    def get_default_embed_fn(self, memory_size, embed_fn_hparams):
        """Creates a default embedding function. Can be used for the A, C,
        or B operation. For the B operation (i.e., `query_embed_fn`),
        :attr:`memory_size` must be 1.

        The function combines memory embedding and temporal embedding, with
        the combination method specified by "combine_mode" in
        :attr:`embed_fn_hparams`.

        .. role:: python(code)
            :language: python

        Args:
            memory_size (int): Number of memory entries to embed. This
                determines the position size of the temporal embedder.
            embed_fn_hparams (dict or HParams): Hyperparameters of the
                embedding function. See
                :func:`~texar.tf.modules.default_memnet_embed_fn`
                for details.

        Returns:
            A tuple `(embed_fn, memory_dim)`, where

            - **`memory_dim`** is the dimension of the memory entry \
            embedding, inferred from :attr:`embed_fn_hparams`.

                - If `combine_mode` == 'add', `memory_dim` is the \
                embedder dimension.
                - If `combine_mode` == 'concat', `memory_dim` is the sum \
                of the memory embedder dimension and the temporal embedder \
                dimension.

            - **`embed_fn`** is an embedding function that takes in memory \
            and returns the memory embedding. \
            Specifically, the function has signature \
            :python:`memory_embedding = embed_fn(memory=None, soft_memory=None)` \
            where one of `memory` and `soft_memory` is provided (but not \
            both).

            Args:
                memory: An `int` Tensor of shape
                    `[batch_size, memory_size]`
                    containing memory indexes used for embedding lookup.
                soft_memory: A Tensor of shape
                    `[batch_size, memory_size, raw_memory_dim]`
                    containing soft weights used to mix the embedding
                    vectors.

            Returns:
                A Tensor of shape `[batch_size, memory_size, memory_dim]`
                containing the memory entry embeddings.
        """
        # Memory embedder
        embedder = WordEmbedder(
            vocab_size=self._raw_memory_dim,
            hparams=embed_fn_hparams["embedding"]
        )
        # Temporal embedder
        temporal_embedder = PositionEmbedder(
            position_size=memory_size,
            hparams=embed_fn_hparams["temporal_embedding"]
        )

        combine = embed_fn_hparams['combine_mode']
        if combine == 'add':
            if embedder.dim != temporal_embedder.dim:
                raise ValueError('`embedding` and `temporal_embedding` must '
                                 'have the same dimension for "add" '
                                 'combination.')
            memory_dim = embedder.dim
        elif combine == 'concat':
            memory_dim = embedder.dim + temporal_embedder.dim
        else:
            # Fail early; otherwise `memory_dim` would be unbound below.
            raise ValueError('Unknown combine method: {}'.format(combine))

        def _embed_fn(memory, soft_memory, mode=None):
            if memory is None and soft_memory is None:
                raise ValueError(
                    "Either `memory` or `soft_memory` is required.")
            if memory is not None and soft_memory is not None:
                raise ValueError(
                    "Must not specify `memory` and `soft_memory` at the "
                    "same time.")

            embedded_memory = embedder(
                ids=memory, soft_ids=soft_memory, mode=mode)
            temporal_embedded = temporal_embedder(
                sequence_length=tf.constant([memory_size]), mode=mode)
            temporal_embedded = tf.tile(
                temporal_embedded, [tf.shape(embedded_memory)[0], 1, 1])

            if combine == 'add':
                return tf.add(embedded_memory, temporal_embedded)
            elif combine == 'concat':
                return tf.concat([embedded_memory, temporal_embedded],
                                 axis=-1)
            else:
                raise ValueError(
                    'Unknown combine method: {}'.format(combine))

        return _embed_fn, memory_dim

    @staticmethod
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                "n_hops": 1,
                "memory_dim": 100,
                "relu_dim": 50,
                "memory_size": 100,
                "A": default_embed_fn_hparams,
                "C": default_embed_fn_hparams,
                "B": default_embed_fn_hparams,
                "use_B": False,
                "use_H": False,
                "dropout_rate": 0,
                "variational": False,
                "name": "memnet",
            }

        Here:

        "n_hops": int
            Number of hops.

        "memory_dim": int
            Memory dimension, i.e., the dimension size of a memory entry
            embedding. Ignored if at least one of the embedding functions
            is created according to :attr:`hparams`. In this case
            :attr:`memory_dim` is inferred from the created embed_fn.

        "relu_dim": int
            Number of elements (out of :attr:`memory_dim`) that have relu
            applied at the end of each hop. Should be no less than 0 and
            no more than :attr:`memory_dim`.

        "memory_size": int
            Number of entries in memory. For example, the number of
            sentences {x_i} in Fig.1(a) of (Sukhbaatar et al.)
            End-To-End Memory Networks.

        "use_B": bool
            Whether to create the query embedding function. Ignored if
            `query_embed_fn` is given to the constructor.

        "use_H": bool
            Whether to perform a linear transformation with matrix `H` at
            the end of each A-C layer.

        "dropout_rate": float
            The dropout rate to apply to the output of each hop. Should
            be between 0 and 1. E.g., `dropout_rate=0.1` would drop out
            10% of the units.

        "variational": bool
            Whether to share dropout masks after each hop.

        "name": str
            Name of the module.
        """
        return {
            "n_hops": 1,
            "memory_dim": 100,
            "relu_dim": 50,
            "memory_size": 100,
            "A": default_memnet_embed_fn_hparams(),
            "C": default_memnet_embed_fn_hparams(),
            "B": default_memnet_embed_fn_hparams(),
            "use_B": False,
            "use_H": False,
            "dropout_rate": 0,
            "variational": False,
            "name": "memnet",
        }

    def _build(self, memory, query, **kwargs):
        raise NotImplementedError

    @property
    def memory_size(self):
        """The memory size."""
        return self._memory_size

    @property
    def raw_memory_dim(self):
        """The dimension of a memory element (or vocabulary size)."""
        return self._raw_memory_dim

    @property
    def memory_dim(self):
        """The dimension of the embedded memory and of all vectors in
        hops."""
        return self._memory_dim
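

# A minimal sketch of the default embed_fn described in
# `get_default_embed_fn` (not part of the original Texar API; all shapes and
# sizes are illustrative). `MemNetBase` is instantiable here because only its
# `_build` is abstract, and `net.hparams.A` reuses the default A-operation
# embedding hyperparameters from `default_hparams`.
def _example_default_embed_fn():
    net = MemNetBase(raw_memory_dim=1000,      # e.g., vocabulary size
                     hparams={"memory_size": 20})
    embed_fn, memory_dim = net.get_default_embed_fn(
        memory_size=20, embed_fn_hparams=net.hparams.A)
    memory = tf.zeros([32, 20], dtype=tf.int32)  # entry ids in [0, 1000)
    # Word embedding + temporal (position) embedding, combined according to
    # "combine_mode"; the result has shape [32, 20, memory_dim].
    return embed_fn(memory=memory, soft_memory=None)
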
class MemNetRNNLike(MemNetBase):
    """An implementation of the multi-layer end-to-end memory network with
    RNN-like weight tying described in
    (Sukhbaatar et al.) End-To-End Memory Networks.

    See :meth:`~texar.tf.modules.MemNetBase.get_default_embed_fn` for the
    default embedding functions. Customized embedding functions must follow
    the same signature.

    Args:
        raw_memory_dim (int): Dimension size of raw memory entries
            (before embedding). For example, if a raw memory entry is a word,
            this is the **vocabulary size** (imagine a one-hot representation
            of the word). If a raw memory entry is a dense vector, this is
            the dimension size of the vector.
        input_embed_fn (optional): A callable that embeds raw memory entries
            as inputs. This corresponds to the `A` embedding operation in
            (Sukhbaatar et al.). If not provided, a default embedding
            operation is created as specified in :attr:`hparams`. See
            :meth:`~texar.tf.modules.MemNetBase.get_default_embed_fn`
            for details.
        output_embed_fn (optional): A callable that embeds raw memory entries
            as outputs. This corresponds to the `C` embedding operation in
            (Sukhbaatar et al.). If not provided, a default embedding
            operation is created as specified in :attr:`hparams`. See
            :meth:`~texar.tf.modules.MemNetBase.get_default_embed_fn`
            for details.
        query_embed_fn (optional): A callable that embeds the query. This
            corresponds to the `B` embedding operation in
            (Sukhbaatar et al.). If not provided and "use_B" is True in
            :attr:`hparams`, a default embedding operation is created as
            specified in :attr:`hparams`. See
            :meth:`~texar.tf.modules.MemNetBase.get_default_embed_fn`
            for details.
            A customized `query_embed_fn` must follow the signature of the
            default embed_fn, where `memory_size` must be 1.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure and
            default values.
    """

    def __init__(self, raw_memory_dim, input_embed_fn=None,
                 output_embed_fn=None, query_embed_fn=None, hparams=None):
        MemNetBase.__init__(self, raw_memory_dim, input_embed_fn,
                            output_embed_fn, query_embed_fn, hparams)

        with tf.variable_scope(self.variable_scope):
            self._AC = MemNetSingleLayer(
                self.H, hparams={"name": "AC"})

            self._W = tf.layers.Dense(
                units=raw_memory_dim, use_bias=False, name="W")

    @staticmethod
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                "n_hops": 1,
                "memory_dim": 100,
                "relu_dim": 50,
                "memory_size": 100,
                "A": default_embed_fn_hparams,
                "C": default_embed_fn_hparams,
                "B": default_embed_fn_hparams,
                "use_B": False,
                "use_H": True,
                "dropout_rate": 0,
                "variational": False,
                "name": "memnet_rnnlike",
            }

        Here:

        "n_hops": int
            Number of hops.

        "memory_dim": int
            Memory dimension, i.e., the dimension size of a memory entry
            embedding. Ignored if at least one of the embedding functions
            is created according to :attr:`hparams`. In this case
            :attr:`memory_dim` is inferred from the created embed_fn.

        "relu_dim": int
            Number of elements (out of :attr:`memory_dim`) that have relu
            applied at the end of each hop. Should be no less than 0 and
            no more than :attr:`memory_dim`.

        "memory_size": int
            Number of entries in memory. For example, the number of
            sentences {x_i} in Fig.1(a) of (Sukhbaatar et al.)
            End-To-End Memory Networks.

        "use_B": bool
            Whether to create the query embedding function. Ignored if
            `query_embed_fn` is given to the constructor.

        "use_H": bool
            Whether to perform a linear transformation with matrix `H` at
            the end of each A-C layer.

        "dropout_rate": float
            The dropout rate to apply to the output of each hop. Should
            be between 0 and 1. E.g., `dropout_rate=0.1` would drop out
            10% of the units.

        "variational": bool
            Whether to share dropout masks after each hop.

        "name": str
            Name of the module.
        """
        hparams = MemNetBase.default_hparams()
        hparams.update({
            "use_H": True,
            "name": "memnet_rnnlike"
        })
        return hparams

    def _build(self, memory=None, query=None, soft_memory=None,
               soft_query=None, mode=None, **kwargs):
        """Passes the :attr:`memory` and :attr:`query` through the memory
        network and returns the logits after the final matrix.

        Only one of :attr:`memory` and :attr:`soft_memory` can be specified.
        They should not be specified at the same time.

        Args:
            memory (optional): Memory used in A/C operations. By default, it
                should be an integer tensor of shape
                `[batch_size, memory_size]`,
                containing the ids to embed if provided.
            query (optional): Query vectors as the initial input of the
                memory network.
                If you'd like to apply some transformation (e.g., embedding)
                on it before it's fed into the network, please set `use_B` to
                True and provide `query_embed_fn` when constructing this
                instance.
                If `query_embed_fn` is set to
                :meth:`~texar.tf.modules.MemNetBase.get_default_embed_fn`,
                it should be of shape `[batch_size]`.
                If `use_B` is not set, it should be of shape
                `[batch_size, memory_dim]`.
            soft_memory (optional): Soft memory used in A/C operations. By
                default, it should be a tensor of shape
                `[batch_size, memory_size, raw_memory_dim]`,
                containing the weights used to mix the embedding vectors.
                If you'd like to apply a matrix multiplication on the memory,
                this option can also be used.
            soft_query (optional): Query vectors as the initial input of the
                memory network.
                If you'd like to apply some transformation (e.g., embedding)
                on it before it's fed into the network, please set `use_B` to
                True and provide `query_embed_fn` when constructing this
                instance.
                Similar to :attr:`soft_memory`, if `query_embed_fn` is set to
                :meth:`~texar.tf.modules.MemNetBase.get_default_embed_fn`,
                then it must be of shape `[batch_size, raw_memory_dim]`.
                Ignored if `use_B` is not set.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. If `None`, dropout
                is controlled by :func:`texar.tf.global_mode`.

        Returns:
            A `Tensor` of shape `[batch_size, raw_memory_dim]` containing
            the logits.
        """
        if self._B is not None:
            def _unsqueeze(x):
                return x if x is None else tf.expand_dims(x, 1)
            query = tf.squeeze(
                self._B(_unsqueeze(query), _unsqueeze(soft_query),
                        mode=mode), 1)
        self._u = [query]
        self._m = self._A(memory, soft_memory, mode=mode)
        self._c = self._C(memory, soft_memory, mode=mode)

        keep_prob = switch_dropout(1 - self.hparams.dropout_rate, mode=mode)
        if self.hparams.variational:
            # Sample a single dropout mask and share it across all hops.
            with tf.variable_scope("variational_dropout"):
                noise = tf.random_uniform(tf.shape(self._u[-1]))
                random_tensor = keep_prob + noise
                binary_tensor = tf.floor(random_tensor)

            def _variational_dropout(val):
                return tf.math.div(val, keep_prob) * binary_tensor

        for _ in range(self._n_hops):
            u_ = self._AC(self._u[-1], self._m, self._c)
            if self._relu_dim == 0:
                pass
            elif self._relu_dim == self._memory_dim:
                u_ = tf.nn.relu(u_)
            elif 0 < self._relu_dim < self._memory_dim:
                # Apply relu to the last `relu_dim` elements only.
                linear_part = u_[:, :self._memory_dim - self._relu_dim]
                relu_part = u_[:, self._memory_dim - self._relu_dim:]
                relued_part = tf.nn.relu(relu_part)
                u_ = tf.concat(axis=1, values=[linear_part, relued_part])
            else:
                raise ValueError(
                    "relu_dim = {} is illegal".format(self._relu_dim))
            if self.hparams.variational:
                u_ = _variational_dropout(u_)
            else:
                u_ = tf.nn.dropout(u_, keep_prob)
            self._u.append(u_)

        logits = self._W(self._u[-1])

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return logits
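

# An end-to-end sketch of `MemNetRNNLike` (illustrative shapes and values,
# not part of the original Texar API). With the default `use_B=False`, the
# query is fed directly as an already-embedded vector of size `memory_dim`
# (100 under the default embed_fn hparams), and the output logits range over
# the raw memory vocabulary.
def _example_memnet_rnnlike():
    memnet = MemNetRNNLike(
        raw_memory_dim=1000,
        hparams={"n_hops": 3, "memory_size": 20, "relu_dim": 50})
    memory = tf.zeros([32, 20], dtype=tf.int32)        # raw entry ids
    query = tf.random_normal([32, memnet.memory_dim])  # embedded query
    # Three A-C hops followed by the final `W` matrix -> logits [32, 1000].
    return memnet(memory=memory, query=query)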