# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
BERT classifiers.
"""
import tensorflow as tf

from texar.tf.core.layers import get_layer
from texar.tf.hyperparams import HParams
from texar.tf.modules.classifiers.classifier_base import ClassifierBase
from texar.tf.modules.encoders.bert_encoder import BERTEncoder
from texar.tf.modules.pretrained.bert import PretrainedBERTMixin
from texar.tf.utils.mode import is_train_mode
from texar.tf.utils.utils import dict_fetch
# pylint: disable=too-many-arguments, invalid-name, no-member,
# pylint: disable=too-many-branches, too-many-locals, too-many-statements
# Names exported by ``from <module> import *``.
__all__ = [
    "BERTClassifier",
]
class BERTClassifier(ClassifierBase, PretrainedBERTMixin):
    r"""Classifier based on BERT modules. Please see
    :class:`~texar.tf.modules.PretrainedBERTMixin` for a brief description
    of BERT.

    This is a combination of the
    :class:`~texar.tf.modules.BERTEncoder` with a classification
    layer. Both step-wise classification and sequence-level classification
    are supported, specified in :attr:`hparams`.

    Arguments are the same as in
    :class:`~texar.tf.modules.BERTEncoder`.

    Args:
        pretrained_model_name (optional): a `str`, the name
            of pre-trained model (e.g., ``bert-base-uncased``). Please refer
            to :class:`~texar.tf.modules.PretrainedBERTMixin` for
            all supported models.
            If `None`, the model name in :attr:`hparams` is used.
        cache_dir (optional): the path to a folder in which the
            pre-trained models will be cached. If `None` (default),
            a default directory (``texar_data`` folder under user's home
            directory) will be used.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure
            and default values.

    .. document private functions
    .. automethod:: _build
    """

    _ENCODER_CLASS = BERTEncoder

    def __init__(self,
                 pretrained_model_name=None,
                 cache_dir=None,
                 hparams=None):
        super(BERTClassifier, self).__init__(hparams=hparams)

        with tf.variable_scope(self.variable_scope):
            # Creates the underlying encoder. Only the entries of `hparams`
            # that are also encoder hyperparameters are forwarded to it.
            encoder_hparams = dict_fetch(
                hparams, BERTEncoder.default_hparams())
            if encoder_hparams is not None:
                # Do not forward the classifier's name to the encoder
                # (presumably so the encoder falls back to its own default
                # name -- confirm against BERTEncoder).
                encoder_hparams['name'] = None
            self._encoder = BERTEncoder(
                pretrained_model_name=pretrained_model_name,
                cache_dir=cache_dir,
                hparams=encoder_hparams)

            # Creates a dropout layer
            drop_kwargs = {"rate": self._hparams.dropout}
            layer_hparams = {"type": "Dropout", "kwargs": drop_kwargs}
            self._dropout_layer = get_layer(hparams=layer_hparams)

            # Creates an additional classification layer if needed
            self._num_classes = self._hparams.num_classes
            if self._num_classes <= 0:
                self._logit_layer = None
            else:
                logit_kwargs = self._hparams.logit_layer_kwargs
                if logit_kwargs is None:
                    logit_kwargs = {}
                elif not isinstance(logit_kwargs, HParams):
                    raise ValueError(
                        "hparams['logit_layer_kwargs'] must be a dict.")
                else:
                    logit_kwargs = logit_kwargs.todict()
                # "units" is always dictated by `num_classes`, overriding
                # any user-provided value.
                logit_kwargs["units"] = self._num_classes
                if 'name' not in logit_kwargs:
                    logit_kwargs['name'] = "logit_layer"

                layer_hparams = {"type": "Dense", "kwargs": logit_kwargs}
                self._logit_layer = get_layer(hparams=layer_hparams)

    @staticmethod
    def default_hparams():
        r"""Returns a dictionary of hyperparameters with default values.

        .. code-block:: python

            {
                # (1) Same hyperparameters as in BERTEncoder
                ...
                # (2) Additional hyperparameters
                "num_classes": 2,
                "logit_layer_kwargs": None,
                "clas_strategy": "cls_time",
                "max_seq_length": None,
                "dropout": 0.1,
                "name": "bert_classifier"
            }

        Here:

        1. Same hyperparameters as in
           :class:`~texar.tf.modules.BERTEncoder`.
           See the :meth:`~texar.tf.modules.BERTEncoder.default_hparams`.
           An instance of BERTEncoder is created for feature extraction.

        2. Additional hyperparameters:

            `"num_classes"`: int
                Number of classes:

                - If **> 0**, an additional :tf_main:`Dense <layers/Dense>`
                  layer is appended to the encoder to compute the logits
                  over classes.
                - If **<= 0**, no dense layer is appended. The number of
                  classes is assumed to be the final dense layer size of
                  the encoder.

            `"logit_layer_kwargs"`: dict
                Keyword arguments for the logit Dense layer constructor,
                except for argument "units" which is set to `num_classes`.
                Ignored if no extra logit layer is appended.

            `"clas_strategy"`: str
                The classification strategy, one of:

                - **cls_time**: Sequence-level classification based on the
                  output of the first time step (which is the `CLS` token).
                  Each sequence has a class.
                - **all_time**: Sequence-level classification based on
                  the output of all time steps. Each sequence has a class.
                - **time_wise**: Step-wise classification, i.e., make
                  classification for each time step based on its output.

            `"max_seq_length"`: int, optional
                Maximum possible length of input sequences. Required if
                `clas_strategy` is `all_time`.

            `"dropout"`: float
                The dropout rate of the BERT encoder output. Dropout is
                applied right before the extra logit layer, and only if
                that layer exists (i.e., `num_classes` > 0).

            `"name"`: str
                Name of the classifier.
        """
        hparams = BERTEncoder.default_hparams()
        hparams.update({
            "num_classes": 2,
            "logit_layer_kwargs": None,
            "clas_strategy": "cls_time",
            "max_seq_length": None,
            "dropout": 0.1,
            "name": "bert_classifier"
        })
        return hparams

    def _build(self,
               inputs,
               sequence_length=None,
               segment_ids=None,
               mode=None,
               **kwargs):
        r"""Feeds the inputs through the network and makes classification.

        The arguments are the same as in
        :class:`~texar.tf.modules.BERTEncoder`.

        Args:
            inputs: A 2D Tensor of shape `[batch_size, max_time]`,
                containing the token ids of tokens in input sequences.
            sequence_length (optional): A 1D Tensor of shape `[batch_size]`.
                Input tokens beyond respective sequence lengths are masked
                out automatically.
            segment_ids (optional): A 2D Tensor of shape
                `[batch_size, max_time]`, containing the segment ids
                of tokens in input sequences. If `None` (default), a tensor
                with all elements set to zero is used.
            mode (optional): A tensor taking value in
                :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
                including `TRAIN`, `EVAL`, and `PREDICT`. Used to toggle
                dropout.
                If `None` (default), :func:`texar.tf.global_mode` is used.
            **kwargs: Keyword arguments. Currently unused by this method.

        Returns:
            A tuple `(logits, pred)`, containing the logits over classes and
            the predictions, respectively.

            - If "clas_strategy"=="cls_time" or "all_time"

                - If "num_classes"==1, `logits` and `pred` are both of
                  shape `[batch_size]`
                - If "num_classes">1, `logits` is of shape
                  `[batch_size, num_classes]` and `pred` is of shape
                  `[batch_size]`.

            - If "clas_strategy"=="time_wise",

                - If "num_classes"==1, `logits` and `pred` are both of
                  shape `[batch_size, max_time]`
                - If "num_classes">1, `logits` is of shape
                  `[batch_size, max_time, num_classes]` and `pred` is of
                  shape `[batch_size, max_time]`.

        Raises:
            ValueError: If "clas_strategy" is not one of the supported
                strategies, or if it is "all_time" while "max_seq_length"
                is not set.
        """
        enc_outputs, pooled_output = self._encoder(inputs, sequence_length,
                                                   segment_ids, mode)

        # Compute logits
        stra = self._hparams.clas_strategy
        if stra == 'time_wise':
            logits = enc_outputs
        elif stra == 'cls_time':
            logits = pooled_output
        elif stra == 'all_time':
            max_seq_length = self._hparams.max_seq_length
            if max_seq_length is None:
                # Fail with a clear message instead of an opaque TypeError
                # from `None - Tensor` below.
                raise ValueError(
                    "hparams['max_seq_length'] must be set when "
                    "hparams['clas_strategy'] is 'all_time'.")
            # Pad `enc_outputs` to have max_seq_length before flatten
            length_diff = max_seq_length - tf.shape(inputs)[1]
            length_diff = tf.reshape(length_diff, [1, 1])
            # Set `paddings = [[0, 0], [0, length_diff], [0, 0]]` by
            # zero-padding the 1x1 `length_diff` into a 3x2 matrix.
            paddings = tf.pad(length_diff, paddings=[[1, 1], [1, 0]])
            logit_input = tf.pad(enc_outputs, paddings=paddings)
            # NOTE(review): `hidden_size` is read from the classifier's own
            # hparams; confirm it matches the encoder's output width when a
            # pre-trained configuration overrides the encoder's hparams.
            logit_input_dim = self._hparams.hidden_size * max_seq_length
            logits = tf.reshape(logit_input, [-1, logit_input_dim])
        else:
            raise ValueError('Unknown classification strategy: {}'.format(stra))

        if self._logit_layer is not None:
            # The dropout layer expects a boolean `training` flag, not a
            # mode key; `is_train_mode` converts `mode` (or the global mode
            # when `mode` is None) into a bool tensor.
            logits = self._dropout_layer(
                logits, training=is_train_mode(mode))
            logits = self._logit_layer(logits)

        # Compute predictions
        num_classes = self._hparams.num_classes
        is_binary = num_classes == 1
        # Without an extra logit layer, fall back to the static size of the
        # last dimension to detect the single-logit (binary) case.
        is_binary = is_binary or (num_classes <= 0 and logits.shape[-1] == 1)

        if stra == 'time_wise':
            if is_binary:
                pred = tf.squeeze(tf.greater(logits, 0), -1)
                logits = tf.squeeze(logits, -1)
            else:
                pred = tf.argmax(logits, axis=-1)
        else:
            if is_binary:
                pred = tf.greater(logits, 0)
                logits = tf.reshape(logits, [-1])
            else:
                pred = tf.argmax(logits, axis=-1)
            pred = tf.reshape(pred, [-1])

        pred = tf.cast(pred, tf.int64)

        if not self._built:
            # Register trainable variables only once, on the first build.
            self._add_internal_trainable_variables()
            if self._logit_layer is not None:
                self._add_trainable_variable(
                    self._logit_layer.trainable_variables)
            self._built = True

        return logits, pred