# Source code for texar.tf.module_base

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Base class for modules.
"""

import re

import tensorflow as tf

from texar.tf.utils.exceptions import TexarError
from texar.tf.hyperparams import HParams

__all__ = [
    "ModuleBase"
]


class ModuleBase(object):
    """Base class inherited by modules that create Variables and are
    configurable through hyperparameters.

    A Texar module inheriting :class:`~texar.tf.ModuleBase` has the following
    key features:

        - **Convenient variable re-use**: A module instance creates \
        its own sets of variables, and automatically re-uses its variables on \
        subsequent calls. Hence TF variable/name scope is \
        transparent to users. For example:

            .. code-block:: python

                encoder = UnidirectionalRNNEncoder(hparams) # create instance
                output_1 = encoder(inputs_1) # variables are created
                output_2 = encoder(inputs_2) # variables are re-used

                print(encoder.trainable_variables) # access trainable variables
                # [ ... ]

        - **Configurable through hyperparameters**: Each module defines \
        allowed hyperparameters and default values. Hyperparameters not \
        specified by users will take default values.

        - **Callable**: As the above example, a module instance is "called" \
        with input tensors and returns output tensors. Every call of a module \
        will add ops to the Graph to perform the module's logic.

    Args:
        hparams (dict, optional): Hyperparameters of the module. See
            :meth:`default_hparams` for the structure and default values.


    .. document private functions
    .. automethod:: _build
    """

    def __init__(self, hparams=None):
        if not hasattr(self, '_hparams'):
            # Fill in any hyperparameters the user did not specify with the
            # defaults declared by (the subclass's) `default_hparams()`.
            self._hparams = HParams(hparams, self.default_hparams())
        else:
            # Probably already parsed by subclasses. We rely on subclass
            # implementations to get this right.
            # As a sanity check, we require `hparams` to be `None` in this case.
            if hparams is not None:
                raise ValueError(
                    "`self._hparams` already exists. Argument `hparams` "
                    "must be set to `None` in this case.")

        # Wrap `_build` in a template so that every call of the module runs
        # in the same variable scope and re-uses the variables created by the
        # first call. `create_scope_now_=True` creates the scope immediately,
        # so `variable_scope` (and hence the unique name) is available right
        # after construction, before the module is ever called.
        self._template = tf.make_template(self._hparams.name, self._build,
                                          create_scope_now_=True)
        # The last component of the scope name, e.g. "module_1" when the
        # requested name "module" was already taken.
        self._unique_name = self.variable_scope.name.split("/")[-1]
        self._trainable_variables = []
        # NOTE(review): nothing in this class sets `_built` to True; it is
        # presumably set by subclass `_build` implementations -- verify.
        self._built = False

    @staticmethod
    def default_hparams():
        """Returns a `dict` of hyperparameters of the module with default
        values. Used to replace the missing values of input `hparams`
        during module construction.

        .. code-block:: python

            {
                "name": "module"
            }
        """
        return {
            "name": "module"
        }

    def _build(self, *args, **kwargs):
        """Subclass must implement this method to build the logic.

        Args:
            *args: Arguments.
            **kwargs: Keyword arguments.

        Returns:
            Output Tensor(s).
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """Executes the module logic defined in the `_build` method.

        Args:
            *args: Arguments of the `_build` method.
            **kwargs: Keyword arguments of the `_build` method.

        Returns:
            The output of the `_build` method.
        """
        return self._template(*args, **kwargs)

    def _add_internal_trainable_variables(self):  # pylint: disable=invalid-name
        """Collects trainable variables constructed internally in this module.

        This is typically called at the end of `_build()` where all necessary
        trainable variables have been constructed.
        """
        scope_name = self.variable_scope.name
        # Escape to handle possible "." characters in the name.
        # Append a slash to the end to avoid searching scopes that have this
        # scope name as a prefix.
        scope_name = re.escape(scope_name) + "/"
        internal_trainable_variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope_name)
        self._add_trainable_variable(internal_trainable_variables)

    def _add_trainable_variable(self, variable):
        """Adds a trainable variable to the trainable variable list of the
        module.

        Args:
            variable: a (possibly nested list/tuple of) trainable variable(s)
                constructed either internally in the module or constructed
                outside but used inside the module. A variable object that is
                already registered is not added again.
        """
        if isinstance(variable, (list, tuple)):
            for var in variable:
                self._add_trainable_variable(var)
        # De-duplicate by object identity rather than with `in` / `==`:
        # `in` falls back to `__eq__`, which for tensor-like objects may be an
        # element-wise comparison (e.g. when TF tensor equality is enabled)
        # instead of an object comparison, and could then raise or wrongly
        # drop a distinct variable.
        elif not any(variable is existing
                     for existing in self._trainable_variables):
            self._trainable_variables.append(variable)

    @property
    def variable_scope(self):
        """The variable scope of the module.
        """
        return self._template.variable_scope

    @property
    def name(self):
        """The uniquified name of the module.
        """
        return self._unique_name

    @property
    def trainable_variables(self):
        """The list of trainable variables of the module.

        Raises:
            TexarError: if the module has not been built yet.
        """
        if not self._built:
            raise TexarError(
                "Attempting to access trainable_variables before module %s "
                "was fully built. The module is built once it is called, "
                "e.g., with `%s(...)`" % (self.name, self.name))
        return self._trainable_variables

    @property
    def hparams(self):
        """An :class:`~texar.tf.HParams` instance. The hyperparameters
        of the module.
        """
        return self._hparams