# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Transformer encoders with multihead self attention.
"""
import tensorflow as tf
from texar.tf.core import layers
from texar.tf.modules.encoders.encoder_base import EncoderBase
from texar.tf.utils.shapes import shape_list
from texar.tf.utils.mode import is_train_mode
from texar.tf.utils import transpose_batch_time
# pylint: disable=too-many-locals, invalid-name, arguments-differ
# pylint: disable=too-many-arguments
__all__ = [
    "MultiheadAttentionEncoder",
]
class MultiheadAttentionEncoder(EncoderBase):
    """Multihead Attention Encoder.

    Args:
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure and
            default values.
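
    Example:

        A minimal construction sketch (the values shown are the defaults,
        so the two lines below are equivalent):

        .. code-block:: python

            attn = MultiheadAttentionEncoder()
            attn = MultiheadAttentionEncoder(
                hparams={'num_heads': 8, 'num_units': 512})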

    .. document private functions
    .. automethod:: _build
    """
    def __init__(self, hparams=None):
        EncoderBase.__init__(self, hparams)

        use_bias = self._hparams.use_bias

        with tf.variable_scope(self.variable_scope):
            if self._hparams.initializer:
                tf.get_variable_scope().set_initializer(
                    layers.get_initializer(self._hparams.initializer))

            self.Q_dense = tf.layers.Dense(self._hparams.num_units,
                                           use_bias=use_bias,
                                           name='query')
            self.K_dense = tf.layers.Dense(self._hparams.num_units,
                                           use_bias=use_bias,
                                           name='key')
            self.V_dense = tf.layers.Dense(self._hparams.num_units,
                                           use_bias=use_bias,
                                           name='value')
            self.O_dense = tf.layers.Dense(self._hparams.output_dim,
                                           use_bias=use_bias,
                                           name='output')
    @staticmethod
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                "initializer": None,
                "num_heads": 8,
                "output_dim": 512,
                "num_units": 512,
                "dropout_rate": 0.1,
                "use_bias": False,
                "name": "multihead_attention"
            }

        Here:

        "initializer": dict, optional
            Hyperparameters of the default initializer that initializes
            variables created in this module.
            See :func:`~texar.tf.core.get_initializer` for details.

        "num_heads": int
            Number of heads for attention calculation.

        "output_dim": int
            Output dimension of the returned tensor.

        "num_units": int
            Hidden dimension of the unsplit attention space.
            Must be divisible by `num_heads`.

        "dropout_rate": float
            Dropout rate in the attention.

        "use_bias": bool
            Whether to use bias when projecting the key, value and query.

        "name": str
            Name of the module.
        """
        return {
            'initializer': None,
            'num_heads': 8,
            'output_dim': 512,
            'num_units': 512,
            'dropout_rate': 0.1,
            'use_bias': False,
            'name': 'multihead_attention',
        }
    def _build(self, queries, memory, memory_attention_bias,
               cache=None, mode=None):
        """Encodes the inputs.

        Args:
            queries: A 3d tensor with shape of `[batch, length_query,
                depth_query]`.
            memory: A 3d tensor with shape of `[batch, length_key,
                depth_key]`. If `None`, self-attention is performed on
                `queries`.
            memory_attention_bias: A bias tensor added to the attention
                logits, broadcastable to
                `[batch, num_heads, length_query, length_key]`
                (e.g., `[batch, 1, 1, length_key]`).
            cache: Memory cache, used only when decoding sentences from
                scratch at inference time.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL` and `PREDICT`. Controls dropout
                mode. If `None` (default), :func:`texar.tf.global_mode`
                is used.

        Returns:
            A Tensor of shape `[batch_size, max_time, dim]` containing the
            encoded vectors.
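
        Example:

            A minimal self-attention sketch (`inputs` is an illustrative
            placeholder tensor, not part of the API):

            .. code-block:: python

                attn = MultiheadAttentionEncoder()
                # Passing `memory=None` makes keys and values be computed
                # from `queries`, i.e., self-attention.
                outputs = attn(queries=inputs, memory=None,
                               memory_attention_bias=None)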
"""
        with tf.variable_scope(self.variable_scope):
            num_heads = self._hparams.num_heads
            num_units = self._hparams.num_units
            if num_units % num_heads:
                raise ValueError("Value depth (%d) must be divisible by "
                                 "the number of attention heads (%d)."
                                 % (num_units, num_heads))
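
            # `_update_and_return` applies the key/value projection and,
            # when a `cache` is given, appends the current step to the
            # cached keys/values (a TensorArray during dynamic decoding,
            # a plain tensor otherwise) so past steps are not recomputed.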
            def _update_and_return(layer, key):
                if memory is None:
                    # Self-attention
                    out = layer(queries)
                    if cache is not None:
                        # decoder self-attention when dynamic decoding
                        key = 'self_{}'.format(key)
                        res = cache[key]
                        if isinstance(res, tf.TensorArray):
                            # inference-like decoding
                            # TODO(zhiting): This writing op may cause a bug
                            # on CPU--it looks like the two TensorArrays
                            # cache['self_keys'] and cache['self_values']
                            # will mix up starting from a certain step,
                            # causing shape mismatch. This op looks fine
                            # on GPU.
                            res = res.write(
                                res.size(), tf.squeeze(out, axis=[1]))
                            out = transpose_batch_time(res.stack())
                        else:
                            # normal decoding
                            res = tf.concat([res, out], axis=1)
                            out = res
                        cache[key] = res
                else:
                    # encoder-decoder attention
                    if cache is not None:
                        key = 'memory_{}'.format(key)
                        res = cache[key]
                        if isinstance(res, tf.TensorArray):
                            # inference-like decoding
                            size = res.size()
                            false_fn = lambda: transpose_batch_time(
                                res.stack())
                        else:
                            # normal decoding
                            size = tf.shape(res)[1]
                            false_fn = lambda: res
                        out = tf.cond(
                            tf.equal(size, 0),
                            true_fn=lambda: layer(memory),
                            false_fn=false_fn)
                    else:
                        out = layer(memory)
                return out
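
            # Q is always projected fresh from the current `queries`;
            # only K and V can come from `memory` or the decoding cache.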
            Q = self.Q_dense(queries)
            K = _update_and_return(self.K_dense, 'keys')
            V = _update_and_return(self.V_dense, 'values')

            Q_ = self._split_heads(Q)
            K_ = self._split_heads(K)
            V_ = self._split_heads(V)
            # [batch_size, num_heads, seq_length, memory_depth]
            key_depth_per_head = num_units // num_heads
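            # Scale queries by 1/sqrt(depth_per_head) so the dot-product
            # logits keep roughly unit variance, as in scaled dot-product
            # attention ("Attention Is All You Need").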
            Q_ *= key_depth_per_head**-0.5
            logits = tf.matmul(Q_, K_, transpose_b=True)
            if memory_attention_bias is not None:
                logits += memory_attention_bias
            weights = tf.nn.softmax(logits, name="attention_weights")
            weights = tf.layers.dropout(weights,
                                        rate=self._hparams.dropout_rate,
                                        training=is_train_mode(mode))
            outputs = tf.matmul(weights, V_)

            outputs = self._combine_heads(outputs)
            outputs = self.O_dense(outputs)
            # (batch_size, length_query, output_dim)

        if not self._built:
            self._add_internal_trainable_variables()
            self._built = True

        return outputs
    def _split_heads(self, x):
        """Splits channels (dimension 2) into multiple heads
        (becomes dimension 1).

        Must ensure `x.shape[-1]` is divisible by `num_heads`.
        """
        depth = shape_list(x)[-1]
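        # e.g., with the default num_heads=8 and num_units=512:
        # [batch, time, 512] -> [batch, time, 8, 64] -> [batch, 8, time, 64]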
        splitted_x = tf.reshape(x, [tf.shape(x)[0], tf.shape(x)[1],
                                    self._hparams.num_heads,
                                    depth // self._hparams.num_heads])
        return tf.transpose(splitted_x, [0, 2, 1, 3])
    def _combine_heads(self, x):
        """
        Args:
            x: A Tensor of shape `[batch, num_heads, seq_len, dim]`

        Returns:
            A Tensor of shape `[batch, seq_len, num_heads * dim]`
        """
        t = tf.transpose(x, [0, 2, 1, 3])  # [batch, seq_len, num_heads, dim]
        num_heads, dim = shape_list(t)[-2:]
        assert num_heads == self._hparams.num_heads
        return tf.reshape(t, [tf.shape(t)[0], tf.shape(t)[1],
                              num_heads * dim])