Source code for

# Copyright 2018 The Texar Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.
Base class for episodic reinforcement learning agents.

import tensorflow as tf

from import AgentBase

# pylint: disable=too-many-instance-attributes

[docs]class EpisodicAgentBase(AgentBase): """Base class inherited by episodic RL agents. An agent is a wrapper of the **training process** that trains a model with RL algorithms. Agent itself does not create new trainable variables. An episodic RL agent typically provides 3 interfaces, namely, :meth:`reset`, :meth:`get_action` and :meth:`observe`, and is used as the following example. Example: .. code-block:: python env = SomeEnvironment(...) agent = PGAgent(...) while True: # Starts one episode agent.reset() observ = env.reset() while True: action = agent.get_action(observ) next_observ, reward, terminal = env.step(action) agent.observe(reward, terminal) observ = next_observ if terminal: break Args: env_config: An instance of :class:`` specifying action space, observation space, and reward range, etc. Use :func:`` to create an EnvConfig from a gym environment. hparams (dict or HParams, optional): Hyperparameters. Missing hyperparamerter will be set to default values. See :meth:`default_hparams` for the hyperparameter sturcture and default values. """ def __init__(self, env_config, hparams=None): AgentBase.__init__(self, hparams) self._env_config = env_config self._reset_tmplt_fn = tf.make_template( "{}_reset".format(, self._reset) self._observe_tmplt_fn = tf.make_template( "{}_observe".format(, self._observe) self._get_action_tmplt_fn = tf.make_template( "{}_get_action".format(, self._get_action)
[docs] @staticmethod def default_hparams(): """Returns a dictionary of hyperparameters with default values. .. code-block:: python { "name": "agent" } """ return { 'name': 'agent' }
[docs] def reset(self): """Resets the states to begin new episode. """ self._reset_tmplt_fn()
def _reset(self): raise NotImplementedError
[docs] def observe(self, reward, terminal, train_policy=True, feed_dict=None): """Observes experience from environment. Args: reward: Reward of the action. The configuration (e.g., shape) of the reward is defined in :attr:`env_config`. terminal (bool): Whether the episode is terminated. train_policy (bool): Wether to update the policy for this step. feed_dict (dict, optional): Any stuffs fed to running the training operator. """ return self._observe_tmplt_fn(reward, terminal, train_policy, feed_dict)
def _observe(self, reward, terminal, train_policy, feed_dict): raise NotImplementedError
[docs] def get_action(self, observ, feed_dict=None): """Gets action according to observation. Args: observ: Observation from the environment. Returns: action from the policy. """ return self._get_action_tmplt_fn(observ, feed_dict)
def _get_action(self, observ, feed_dict): raise NotImplementedError @property def env_config(self): """Environment configuration. """ return self._env_config