Source code for texar.tf.agents.agent_utils

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Various agent utilities.
"""

# pylint: disable=too-many-arguments, too-few-public-methods, no-member
# pylint: disable=invalid-name, wrong-import-position

import numpy as np

gym_utils = None
try:
    from texar.tf.agents import agent_gym_utils as gym_utils
except ImportError:
    pass

__all__ = [
    "Space",
    "EnvConfig"
]


[docs]class Space(object): """Observation and action spaces. Describes valid actions and observations. Similar to :gym:`gym.Space <#spaces>`. Args: shape (optional): Shape of the space, a tuple. If not given, infers from :attr:`low` and :attr:`high`. low (optional): Lower bound (inclusive) of each dimension of the space. Must have shape as specified by :attr:`shape`, and of the same shape with with :attr:`high` (if given). If `None`, set to `-inf` for each dimension. high (optional): Upper bound (inclusive) of each dimension of the space. Must have shape as specified by :attr:`shape`, and of the same shape with with :attr:`low` (if given). If `None`, set to `inf` for each dimension. dtype (optional): Data type of elements in the space. If not given, infers from :attr:`low` (if given) or set to `float`. Example: .. code-block:: python s = Space(low=0, high=10, dtype=np.int32) #s.contains(2) == True #s.contains(10) == True #s.contains(11) == False #s.shape == () s2 = Space(shape=(2,2), high=np.ones([2,2]), dtype=np.float) #s2.low == [[-inf, -inf], [-inf, -inf]] #s2.high == [[1., 1.], [1., 1.]] """ def __init__(self, shape=None, low=None, high=None, dtype=None): if low is None: low = -float('inf') if high is None: high = float('inf') if shape is None: low = np.asarray(low) high = np.asarray(high) if low.shape != high.shape: raise ValueError('`low` and `high` must have the same shape.') shape = low.shape else: shape = tuple(shape) if np.isscalar(low): low = low + np.zeros(shape, dtype=dtype) if np.isscalar(high): high = high + np.zeros(shape, dtype=dtype) if shape != low.shape or shape != high.shape: raise ValueError( 'Shape inconsistent: shape={}, low.shape={}, high.shape={}' .format(shape, low.shape, high.shape)) if dtype is None: dtype = low.dtype dtype = np.dtype(dtype) low = low.astype(dtype) high = high.astype(dtype) self._shape = shape self._low = low self._high = high self._dtype = dtype
[docs] def contains(self, x): """Checks if x is contained in the space. Returns a `bool`. """ x = np.asarray(x) dtype_match = True if self._dtype.kind in np.typecodes['AllInteger']: if x.dtype.kind not in np.typecodes['AllInteger']: dtype_match = False shape_match = x.shape == self._shape low_match = (x >= self._low).all() high_match = (x <= self._high).all() return dtype_match and shape_match and low_match and high_match
@property def shape(self): """Shape of the space. """ return self._shape @property def low(self): """Lower bound of the space. """ return self._low @property def high(self): """Upper bound of the space. """ return self._high @property def dtype(self): """Data type of the element. """ return self._dtype
[docs]class EnvConfig(object): """Configurations of an environment. Args: action_space: An instance of :class:`~texar.tf.agents.Space` or :gym:`gym.Space <#spaces>`, the action space. observ_space: An instance of :class:`~texar.tf.agents.Space` or :gym:`gym.Space <#spaces>`, the observation space. reward_range: A tuple corresponding to the min and max possible rewards, e.g., `reward_range=(-1.0, 1.0)`. """ def __init__(self, action_space, observ_space, reward_range): if gym_utils: action_space = gym_utils.convert_gym_space(action_space) observ_space = gym_utils.convert_gym_space(observ_space) self.action_space = action_space self.action_dtype = action_space.dtype self.action_shape = action_space.shape self.observ_space = observ_space self.observ_dtype = observ_space.dtype self.observ_shape = observ_space.shape self.reward_range = reward_range