Source code for texar.tf.modules.pretrained.xlnet

# Copyright 2019 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utils of XLNet Modules.
"""

import collections
import json
import os
import re

from abc import ABCMeta

import tensorflow as tf

from texar.tf.modules.pretrained.pretrained_base import PretrainedMixin

__all__ = [
    "PretrainedXLNetMixin",
]

_XLNET_PATH = "https://storage.googleapis.com/xlnet/released_models/"


[docs]class PretrainedXLNetMixin(PretrainedMixin): r"""A mixin class to support loading pre-trained checkpoints for modules that implement the XLNet model. The XLNet model was proposed in `XLNet: Generalized Autoregressive Pretraining for Language Understanding`_ by `Yang et al.` It is based on the Transformer-XL model, pre-trained on a large corpus using a language modeling objective that considers all permutations of the input sentence. The available XLNet models are as follows: * ``xlnet-based-cased``: 12-layer, 768-hidden, 12-heads. This model is trained on full data (different from the one in the paper). * ``xlnet-large-cased``: 24-layer, 1024-hidden, 16-heads. We provide the following XLNet classes: * :class:`~texar.torch.modules.XLNetEncoder` for text encoding. * :class:`~texar.torch.modules.XLNetDecoder` for text generation and decoding. * :class:`~texar.torch.modules.XLNetClassifier` for text classification and sequence tagging. * :class:`~texar.torch.modules.XLNetRegressor` for text regression. .. _`XLNet: Generalized Autoregressive Pretraining for Language Understanding`: http://arxiv.org/abs/1906.08237 """ __metaclass__ = ABCMeta _MODEL_NAME = "XLNet" _MODEL2URL = { 'xlnet-base-cased': _XLNET_PATH + "cased_L-12_H-768_A-12.zip", 'xlnet-large-cased': _XLNET_PATH + "cased_L-24_H-1024_A-16.zip", } @classmethod def _transform_config(cls, pretrained_model_name, cache_dir): info = list(os.walk(cache_dir)) root, _, files = info[0] config_path = None for file in files: if file.endswith('config.json'): config_path = os.path.join(root, file) if config_path is None: raise ValueError("Cannot find the config file in {}".format( cache_dir)) with open(config_path) as f: config_ckpt = json.loads(f.read()) configs = { "head_dim": config_ckpt["d_head"], "ffn_inner_dim": config_ckpt["d_inner"], "hidden_dim": config_ckpt["d_model"], "activation": config_ckpt["ff_activation"], "num_heads": config_ckpt["n_head"], "num_layers": config_ckpt["n_layer"], "vocab_size": config_ckpt["n_token"], "untie_r": config_ckpt["untie_r"] } return configs def _init_from_checkpoint(self, pretrained_model_name, cache_dir, scope_name, **kwargs): tvars = tf.trainable_variables() init_checkpoint = os.path.join(cache_dir, 'xlnet_model.ckpt') if init_checkpoint: assignment_map, initialized_variable_names = \ self._get_assignment_map_from_checkpoint( tvars, init_checkpoint, scope_name) tf.train.init_from_checkpoint(init_checkpoint, assignment_map) def _get_assignment_map_from_checkpoint(self, tvars, init_checkpoint, scope_name): r""" Compute the union of the current variables and checkpoint variables. Because of the variable scope of the original XLNet and Texar implementation, we need to build a assignment map to match the variables. """ assignment_map = {} initialized_variable_names = {} name_to_variable = collections.OrderedDict() for var in tvars: name = var.name m = re.match("^(.*):\\d+$", name) if m is not None: name = m.group(1) name_to_variable[name] = var init_vars = tf.train.list_variables(init_checkpoint) for check_name, _ in init_vars: check_name_scope = check_name.replace( 'model/transformer/', scope_name + '/') model_name = check_name_scope if check_name.startswith('model/lm_loss/bias'): model_name = scope_name + '/lm_loss/bias' elif check_name.startswith('model/transformer/mask_emb'): model_name = check_name_scope.replace( 'mask_emb/mask_emb', 'mask_emb') elif check_name.startswith('model/transformer/word_embedding'): model_name = scope_name + '/word_embedder/w' elif re.match('model/transformer/r_[r,s,w]_bias', check_name): model_name = check_name_scope elif re.match('model/transformer/seg_embed', check_name): model_name = check_name_scope elif re.match('model/transformer/layer_\\d+/rel_attn/[q,k,v,r,o]', check_name): model_name = check_name_scope elif re.match('model/transformer/layer_\\d+/rel_attn/LayerNorm', check_name): model_name = check_name_scope.replace('LayerNorm/', '') elif re.match('model/transformer/layer_\\d+/ff/layer_[1,2]', check_name): model_name = check_name_scope.replace('ff/layer_1', 'ff/dense') if model_name == check_name_scope: model_name = check_name_scope.replace( 'ff/layer_2', 'ff/dense_1') elif re.match('model/transformer/layer_\\d+/ff/LayerNorm', check_name): model_name = check_name_scope.replace('LayerNorm/', '') if model_name in name_to_variable.keys(): assignment_map[check_name] = model_name initialized_variable_names[model_name] = 1 initialized_variable_names[model_name + ":0"] = 1 else: tf.logging.info('model name:{} not exist'.format(model_name)) return assignment_map, initialized_variable_names