# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions related to mode.
"""
import tensorflow as tf
from texar.tf import context
# Public names of this module: the mode-query helpers (graph-level Tensor
# variants and python-bool `_py` variants) plus `switch_dropout`.
__all__ = [
    "maybe_global_mode",
    "is_train_mode",
    "is_eval_mode",
    "is_predict_mode",
    "is_train_mode_py",
    "is_eval_mode_py",
    "is_predict_mode_py",
    "switch_dropout"
]
def maybe_global_mode(mode):
    """Resolves :attr:`mode`, falling back to the global mode.

    Args:
        mode: A mode value (e.g. one of
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`), or `None`.

    Returns:
        :attr:`mode` itself when it is not `None`; otherwise the result of
        :func:`texar.tf.global_mode`.
    """
    return context.global_mode() if mode is None else mode
def is_train_mode(mode):
    """Returns a bool Tensor telling whether the mode is TRAIN.

    Args:
        mode: A string or string Tensor taking values in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, or `None`.
            When `None`, the global mode (:func:`texar.tf.global_mode`) is
            checked instead.

    Returns:
        A bool Tensor.
    """
    if mode is not None:
        return tf.equal(mode, tf.estimator.ModeKeys.TRAIN)
    # No explicit mode: defer to the global-mode check in texar.tf.context.
    return context.global_mode_train()
def is_eval_mode(mode):
    """Returns a bool Tensor telling whether the mode is EVAL.

    Args:
        mode: A string or string Tensor taking values in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, or `None`.
            When `None`, the global mode (:func:`texar.tf.global_mode`) is
            checked instead.

    Returns:
        A bool Tensor.
    """
    if mode is not None:
        return tf.equal(mode, tf.estimator.ModeKeys.EVAL)
    # No explicit mode: defer to the global-mode check in texar.tf.context.
    return context.global_mode_eval()
def is_predict_mode(mode):
    """Returns a bool Tensor telling whether the mode is PREDICT.

    Args:
        mode: A string or string Tensor taking values in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`, or `None`.
            When `None`, the global mode (:func:`texar.tf.global_mode`) is
            checked instead.

    Returns:
        A bool Tensor.
    """
    if mode is not None:
        return tf.equal(mode, tf.estimator.ModeKeys.PREDICT)
    # No explicit mode: defer to the global-mode check in texar.tf.context.
    return context.global_mode_predict()
def is_train_mode_py(mode, default=True):
    """Tells, as a python bool, whether :attr:`mode` is TRAIN.

    Args:
        mode: A string taking value in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
            or `None`.
        default (bool): Value returned when :attr:`mode` is `None`.
            Defaults to `True`.

    Returns:
        A python boolean (or :attr:`default` when :attr:`mode` is `None`).

    Raises:
        ValueError: If :attr:`mode` is not `None` and is not one of
            `context.valid_modes()`.
    """
    if mode is not None:
        # Reject unrecognized strings before comparing.
        if mode not in context.valid_modes():
            raise ValueError('Unknown mode: {}'.format(mode))
        return mode == tf.estimator.ModeKeys.TRAIN
    return default
def is_eval_mode_py(mode, default=False):
    """Tells, as a python bool, whether :attr:`mode` is EVAL.

    Args:
        mode: A string taking value in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
            or `None`.
        default (bool): Value returned when :attr:`mode` is `None`.
            Defaults to `False`.

    Returns:
        A python boolean (or :attr:`default` when :attr:`mode` is `None`).

    Raises:
        ValueError: If :attr:`mode` is not `None` and is not one of
            `context.valid_modes()`.
    """
    if mode is not None:
        # Reject unrecognized strings before comparing.
        if mode not in context.valid_modes():
            raise ValueError('Unknown mode: {}'.format(mode))
        return mode == tf.estimator.ModeKeys.EVAL
    return default
def is_predict_mode_py(mode, default=False):
    """Tells, as a python bool, whether :attr:`mode` is PREDICT.

    Args:
        mode: A string taking value in
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`,
            or `None`.
        default (bool): Value returned when :attr:`mode` is `None`.
            Defaults to `False`.

    Returns:
        A python boolean (or :attr:`default` when :attr:`mode` is `None`).

    Raises:
        ValueError: If :attr:`mode` is not `None` and is not one of
            `context.valid_modes()`.
    """
    if mode is not None:
        # Reject unrecognized strings before comparing.
        if mode not in context.valid_modes():
            raise ValueError('Unknown mode: {}'.format(mode))
        return mode == tf.estimator.ModeKeys.PREDICT
    return default
def switch_dropout(dropout_keep_prob, mode=None):
    """Disables dropout outside of training mode.

    Args:
        dropout_keep_prob: Keep probability to use in `TRAIN` mode.
            Presumably a python float or float32 Tensor, since it is
            combined with a float32 Tensor below.
        mode (optional): A Tensor taking values of
            :tf_main:`tf.estimator.ModeKeys <estimator/ModeKeys>`.
            Dropout is active only when :attr:`mode` is `TRAIN`.
            If `None`, the mode is taken from :func:`texar.tf.global_mode`.

    Returns:
        A float32 Tensor equal to :attr:`dropout_keep_prob` in `TRAIN` mode
        and `1.0` in every other mode.
    """
    # Fraction of units that would be dropped while training.
    drop_prob = 1. - dropout_keep_prob
    # 1.0 when training, 0.0 otherwise. Gating arithmetically keeps the
    # result a single Tensor expression, with no tf.cond.
    train_gate = tf.cast(is_train_mode(mode), tf.float32)
    return 1. - drop_prob * train_gate