Source code for texar.tf.utils.dtypes

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions related to data types.
"""

# pylint: disable=invalid-name, no-member, protected-access

import six
import numpy as np

import tensorflow as tf

# Public names of this module, i.e. what `from <module> import *` exports.
# The helper `_maybe_list_to_array` is deliberately not listed.
__all__ = [
    "get_tf_dtype",
    "is_callable",
    "is_str",
    "is_placeholder",
    "maybe_hparams_to_dict",
    "compat_as_text"
]


def get_tf_dtype(dtype):  # pylint: disable=too-many-return-statements
    """Returns equivalent tf dtype.

    Args:
        dtype: A str, python numeric or string type, numpy data type, or
            tf dtype.

    Returns:
        The corresponding tf dtype.

    Raises:
        ValueError: If :attr:`dtype` does not match any supported type.
    """
    # NOTE: `np.float_` and `np.str` are intentionally not listed below.
    # They were plain aliases of `np.float64` and the built-in `str` (both
    # already covered), and they no longer exist in recent NumPy releases,
    # where merely referencing them raises AttributeError.
    if dtype in {'float', 'float32', 'tf.float32', float,
                 np.float32, tf.float32}:
        return tf.float32
    elif dtype in {'float64', 'tf.float64', np.float64, tf.float64}:
        return tf.float64
    elif dtype in {'float16', 'tf.float16', np.float16, tf.float16}:
        return tf.float16
    elif dtype in {'int', 'int32', 'tf.int32', int, np.int32, tf.int32}:
        return tf.int32
    elif dtype in {'int64', 'tf.int64', np.int64, tf.int64}:
        return tf.int64
    elif dtype in {'int16', 'tf.int16', np.int16, tf.int16}:
        return tf.int16
    elif dtype in {'bool', 'tf.bool', bool, np.bool_, tf.bool}:
        return tf.bool
    elif dtype in {'string', 'str', 'tf.string', str, tf.string}:
        return tf.string
    try:
        # `unicode` only exists on Python 2; on Python 3 building the set
        # raises NameError and we fall through to the ValueError below.
        # Membership (`in`) is required here: comparing `dtype` to the set
        # itself with `==` could never match a real dtype.
        if dtype in {'unicode', unicode}:  # pylint: disable=undefined-variable
            return tf.string
    except NameError:
        pass
    raise ValueError(
        "Unsupported conversion from type {} to tf dtype".format(str(dtype)))
def is_callable(x):
    """Return `True` if :attr:`x` is callable.

    Args:
        x: Any object.

    Returns:
        `True` if :attr:`x` can be called, `False` otherwise.
    """
    try:
        return callable(x)
    except NameError:
        # The `callable` builtin is missing on some interpreters
        # (Python 3.0/3.1); fall back to checking for `__call__`.
        # Only NameError is handled so that KeyboardInterrupt/SystemExit
        # are never swallowed.
        return hasattr(x, '__call__')
def is_str(x):
    """Tells whether :attr:`x` is a string.

    Args:
        x: Any object.

    Returns:
        `True` if :attr:`x` is an instance of one of `six.string_types`
        (str or unicode), `False` otherwise.
    """
    accepted_types = six.string_types
    return isinstance(x, accepted_types)
def is_placeholder(x):
    """Returns `True` if :attr:`x` is a :tf_main:`tf.placeholder <placeholder>`
    or :tf_main:`tf.placeholder_with_default <placeholder_with_default>`.

    Args:
        x: Any object, typically a tensor.

    Returns:
        `True` if the type of the op that produces :attr:`x` is
        `'Placeholder'` or `'PlaceholderWithDefault'`. `False` otherwise,
        including when :attr:`x` exposes no such op.
    """
    try:
        # A tensor exposes the operation that produces it as `.op`; the
        # previously used `._ops` attribute does not exist, so the lookup
        # always failed and every input was reported as a non-placeholder.
        return x.op.type in ['Placeholder', 'PlaceholderWithDefault']
    except Exception:  # pylint: disable=broad-except
        # Best effort: anything without a usable `.op.type` is simply not
        # a placeholder. `Exception` (not `BaseException`) keeps
        # KeyboardInterrupt/SystemExit propagating.
        return False
def maybe_hparams_to_dict(hparams):
    """Normalizes :attr:`hparams` into a `dict` when possible.

    Args:
        hparams: `None`, a `dict`, or an object providing a `todict()`
            method (e.g., an instance of :class:`~texar.tf.HParams`).

    Returns:
        `None` or the very same `dict` if :attr:`hparams` is already one of
        those; otherwise the result of `hparams.todict()`.
    """
    # `None` and plain dicts are passed through untouched.
    if hparams is None or isinstance(hparams, dict):
        return hparams
    return hparams.todict()
def _maybe_list_to_array(str_list, dtype_as):
    """Re-wraps :attr:`str_list` in the container kind of :attr:`dtype_as`.

    Args:
        str_list: The list to convert.
        dtype_as: The object whose container type is imitated.

    Returns:
        `type(dtype_as)(str_list)` if :attr:`dtype_as` is a list or tuple,
        `np.array(str_list)` if it is a numpy array, and :attr:`str_list`
        unchanged in every other case.
    """
    if isinstance(dtype_as, (list, tuple)):
        container = type(dtype_as)
        return container(str_list)
    if isinstance(dtype_as, np.ndarray):
        return np.array(str_list)
    return str_list
def compat_as_text(str_):
    """Converts strings into `unicode` (Python 2) or `str` (Python 3).

    Args:
        str_: A string or other data types convertible to string, or an
            `n`-D numpy array or (possibly nested) list of such elements.

    Returns:
        The converted strings of the same structure/shape as :attr:`str_`.
    """
    def _to_text(item):
        # Leaves: let TF convert directly; anything it rejects with a
        # TypeError is stringified with `str()` first and retried.
        if not isinstance(item, (list, tuple, np.ndarray)):
            try:
                return tf.compat.as_text(item)
            except TypeError:
                return tf.compat.as_text(str(item))
        # Containers: convert every element recursively, then rebuild the
        # same kind of container (list / tuple / numpy array).
        converted = [_to_text(elem) for elem in item]
        return _maybe_list_to_array(converted, item)

    return _to_text(str_)