# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility functions related to data types.
"""
# pylint: disable=invalid-name, no-member, protected-access
import six
import numpy as np
import tensorflow as tf
# Names exported by ``from <this module> import *``.
__all__ = [
    "get_tf_dtype",
    "is_callable",
    "is_str",
    "is_placeholder",
    "maybe_hparams_to_dict",
    "compat_as_text",
]
def get_tf_dtype(dtype):  # pylint: disable=too-many-return-statements
    """Returns equivalent tf dtype.

    Args:
        dtype: A str, python numeric or string type, numpy data type, or
            tf dtype.

    Returns:
        The corresponding tf dtype.

    Raises:
        ValueError: If :attr:`dtype` has no supported tf equivalent.
    """
    # The deprecated NumPy aliases ``np.float_`` (== ``np.float64``) and
    # ``np.str`` (== builtin ``str``) are deliberately not referenced: both
    # targets are already listed, and the aliases no longer exist in recent
    # NumPy releases, where touching them raises AttributeError.
    if dtype in {'float', 'float32', 'tf.float32', float,
                 np.float32, tf.float32}:
        return tf.float32
    elif dtype in {'float64', 'tf.float64', np.float64, tf.float64}:
        return tf.float64
    elif dtype in {'float16', 'tf.float16', np.float16, tf.float16}:
        return tf.float16
    elif dtype in {'int', 'int32', 'tf.int32', int, np.int32, tf.int32}:
        return tf.int32
    elif dtype in {'int64', 'tf.int64', np.int64, tf.int64}:
        return tf.int64
    elif dtype in {'int16', 'tf.int16', np.int16, tf.int16}:
        return tf.int16
    elif dtype in {'bool', 'tf.bool', bool, np.bool_, tf.bool}:
        return tf.bool
    elif dtype in {'string', 'str', 'tf.string', str, tf.string}:
        return tf.string
    # Text type: ``unicode`` on Python 2; on Python 3 ``six.text_type`` is
    # ``str`` (already handled above), so only the 'unicode' name is new.
    elif dtype in {'unicode', six.text_type}:
        return tf.string
    raise ValueError(
        "Unsupported conversion from type {} to tf dtype".format(str(dtype)))
def is_callable(x):
    """Return `True` if :attr:`x` is callable.

    Delegates to the builtin :func:`callable`, which exists on Python 2 and
    on Python 3.2+. No ``hasattr(x, '__call__')`` fallback is therefore
    needed. Not wrapping the call in a blanket ``except BaseException`` means
    KeyboardInterrupt/SystemExit are never swallowed.
    """
    return callable(x)
def is_str(x):
    """Checks whether :attr:`x` is a text string.

    Returns `True` when :attr:`x` is an instance of one of the types in
    ``six.string_types`` (``basestring`` on Python 2, ``str`` on Python 3),
    and `False` otherwise.
    """
    accepted_types = six.string_types
    matched = isinstance(x, accepted_types)
    return matched
def is_placeholder(x):
    """Returns `True` if :attr:`x` is a :tf_main:`tf.placeholder <placeholder>`
    or :tf_main:`tf.placeholder_with_default <placeholder_with_default>`.

    The check reads the type of the op that produced :attr:`x`
    (``x.op.type``). Objects whose op cannot be read (e.g. plain Python
    values or numpy arrays) yield `False`.
    """
    try:
        # Tensors expose their producing op via the public ``op`` property.
        op_type = x.op.type
    except Exception:  # pylint: disable=broad-except
        # Best-effort predicate: anything without an inspectable op is not a
        # placeholder. Catching `Exception` (not `BaseException`) lets
        # KeyboardInterrupt/SystemExit propagate.
        return False
    return op_type in ('Placeholder', 'PlaceholderWithDefault')
def maybe_hparams_to_dict(hparams):
    """Normalizes :attr:`hparams` to a plain `dict` where needed.

    `None` and `dict` instances (including subclasses) are passed through
    unchanged. Any other object, e.g. an instance of
    :class:`~texar.tf.HParams`, is converted by calling its ``todict()``
    method.
    """
    passthrough = hparams is None or isinstance(hparams, dict)
    if passthrough:
        return hparams
    return hparams.todict()
def _maybe_list_to_array(str_list, dtype_as):
if isinstance(dtype_as, (list, tuple)):
return type(dtype_as)(str_list)
elif isinstance(dtype_as, np.ndarray):
return np.array(str_list)
else:
return str_list
def compat_as_text(str_):
    """Converts strings into `unicode` (Python 2) or `str` (Python 3).

    Args:
        str_: A string or other data types convertible to string, or an
            `n`-D numpy array (including 0-d) or (possibly nested) list of
            such elements.

    Returns:
        The converted strings of the same structure/shape as :attr:`str_`.
    """
    def _recur_convert(s):
        if isinstance(s, np.ndarray) and s.ndim == 0:
            # A 0-d array cannot be iterated; convert its scalar and wrap it
            # back into a 0-d array so the output shape matches the input.
            return np.array(_recur_convert(s.item()))
        if isinstance(s, (list, tuple, np.ndarray)):
            s_ = [_recur_convert(si) for si in s]
            return _maybe_list_to_array(s_, s)
        try:
            return tf.compat.as_text(s)
        except TypeError:
            # Not str/bytes: fall back to the value's string representation.
            return tf.compat.as_text(str(s))

    return _recur_convert(str_)