Source code for texar.tf.agents.ac_agent

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Actor-critic agent.
"""

import tensorflow as tf
import numpy as np

from texar.tf.agents.episodic_agent_base import EpisodicAgentBase
from texar.tf.utils import utils

# pylint: disable=too-many-instance-attributes, protected-access
# pylint: disable=too-many-arguments

__all__ = [
    "ActorCriticAgent"
]


class ActorCriticAgent(EpisodicAgentBase):
    """Actor-critic agent for the episodic setting.

    An actor-critic algorithm consists of several components:

    - **Actor** is the policy to optimize. As a temporary implementation,
      here by default we use a :class:`~texar.tf.agents.PGAgent` instance
      that wraps a `policy net` and provides proper interfaces to perform
      the role of an actor.
    - **Critic** that provides learning signals to the actor. Again, as
      a temporary implementation, here by default we use a
      :class:`~texar.tf.agents.DQNAgent` instance that wraps a `Q net` and
      provides proper interfaces to perform the role of a critic.

    Args:
        env_config: An instance of :class:`~texar.tf.agents.EnvConfig`
            specifying the action space, observation space, reward range,
            etc. Use :func:`~texar.tf.agents.get_gym_env_config` to create
            an EnvConfig from a gym environment.
        sess (optional): A tf session. Can be `None` here and set later
            with `agent.sess = session`.
        actor (optional): An instance of :class:`~texar.tf.agents.PGAgent`
            that performs as actor in the algorithm. If not provided, an
            actor is created based on :attr:`hparams`.
        actor_kwargs (dict, optional): Keyword arguments for the actor
            constructor. Note that the `hparams` argument for the actor
            constructor is specified in the "actor_hparams" field of
            :attr:`hparams` and should not be included in `actor_kwargs`.
            Ignored if :attr:`actor` is given.
        critic (optional): An instance of :class:`~texar.tf.agents.DQNAgent`
            that performs as critic in the algorithm. If not provided, a
            critic is created based on :attr:`hparams`.
        critic_kwargs (dict, optional): Keyword arguments for the critic
            constructor. Note that the `hparams` argument for the critic
            constructor is specified in the "critic_hparams" field of
            :attr:`hparams` and should not be included in `critic_kwargs`.
            Ignored if :attr:`critic` is given.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure and
            default values.
    """

    def __init__(self,
                 env_config,
                 sess=None,
                 actor=None,
                 actor_kwargs=None,
                 critic=None,
                 critic_kwargs=None,
                 hparams=None):
        EpisodicAgentBase.__init__(self, env_config=env_config,
                                   hparams=hparams)

        self._sess = sess
        self._num_actions = self._env_config.action_space.high - \
                            self._env_config.action_space.low

        with tf.variable_scope(self.variable_scope):
            # Build the actor (a policy-gradient agent by default) unless
            # one is passed in.
            if actor is None:
                kwargs = utils.get_instance_kwargs(
                    actor_kwargs, self._hparams.actor_hparams)
                kwargs.update(dict(env_config=env_config, sess=sess))
                actor = utils.get_instance(
                    class_or_name=self._hparams.actor_type,
                    kwargs=kwargs,
                    module_paths=['texar.tf.agents', 'texar.tf.custom'])
            self._actor = actor

            # Build the critic (a DQN agent by default) unless one is
            # passed in.
            if critic is None:
                kwargs = utils.get_instance_kwargs(
                    critic_kwargs, self._hparams.critic_hparams)
                kwargs.update(dict(env_config=env_config, sess=sess))
                critic = utils.get_instance(
                    class_or_name=self._hparams.critic_type,
                    kwargs=kwargs,
                    module_paths=['texar.tf.agents', 'texar.tf.custom'])
            self._critic = critic

            if self._actor._discount_factor != self._critic._discount_factor:
                raise ValueError('discount_factor of the actor and the '
                                 'critic must be the same.')
            self._discount_factor = self._actor._discount_factor

            self._observs = []
            self._actions = []
            self._rewards = []
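
    # A minimal construction sketch (an illustrative assumption, not part
    # of this module: it presumes `gym` is installed and uses `CartPole-v0`
    # purely as an example environment):
    #
    #     import gym
    #     from texar.tf.agents import ActorCriticAgent, get_gym_env_config
    #
    #     env = gym.make('CartPole-v0')
    #     agent = ActorCriticAgent(env_config=get_gym_env_config(env))
    #
    # The tf session can be attached afterwards via `agent.sess = session`
    # (see the `sess` property below).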

    @staticmethod
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values:

        .. role:: python(code)
           :language: python

        .. code-block:: python

            {
                'actor_type': 'PGAgent',
                'actor_hparams': None,
                'critic_type': 'DQNAgent',
                'critic_hparams': None,
                'name': 'actor_critic_agent'
            }

        Here:

        "actor_type": str or class or instance
            The actor. Can be a class, its name or module path, or a class
            instance. If a class name is given, the class must be from
            module :mod:`texar.tf.agents` or :mod:`texar.tf.custom`.
            Ignored if an `actor` is given to the agent constructor.

        "actor_hparams": dict, optional
            Hyperparameters for the actor class. With the
            :attr:`actor_kwargs` argument to the constructor, an actor is
            created with
            :python:`actor_class(**actor_kwargs, hparams=actor_hparams)`.

        "critic_type": str or class or instance
            The critic. Can be a class, its name or module path, or a class
            instance. If a class name is given, the class must be from
            module :mod:`texar.tf.agents` or :mod:`texar.tf.custom`.
            Ignored if a `critic` is given to the agent constructor.

        "critic_hparams": dict, optional
            Hyperparameters for the critic class. With the
            :attr:`critic_kwargs` argument to the constructor, a critic is
            created with
            :python:`critic_class(**critic_kwargs, hparams=critic_hparams)`.

        "name": str
            Name of the agent.
        """
        return {
            'actor_type': 'PGAgent',
            'actor_hparams': None,
            'critic_type': 'DQNAgent',
            'critic_hparams': None,
            'name': 'actor_critic_agent'
        }
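
    # A hedged sketch of overriding the defaults above. Note that the
    # constructor requires the actor's and critic's discount factors to
    # match, so both are overridden together here; the `discount_factor`
    # field is assumed to be an hparam of the default PGAgent/DQNAgent
    # classes:
    #
    #     agent = ActorCriticAgent(
    #         env_config=env_config,
    #         hparams={'actor_hparams': {'discount_factor': 0.95},
    #                  'critic_hparams': {'discount_factor': 0.95}})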

    def _reset(self):
        self._actor._reset()
        self._critic._reset()

    def _observe(self, reward, terminal, train_policy, feed_dict):
        # Train the actor on the stored (observation, action) pair, then
        # let the critic record the transition and train its Q-net.
        self._train_actor(
            observ=self._observ,
            action=self._action,
            feed_dict=feed_dict)
        self._critic._observe(reward, terminal, train_policy, feed_dict)

    def _train_actor(self, observ, action, feed_dict):
        # The advantage of the taken action is its target-network Q-value
        # minus the mean Q-value over all actions, used as a baseline.
        qvalues = self._critic._qvalues_from_target(observ=observ)
        advantage = qvalues[0][action] - np.mean(qvalues)

        # TODO(bowen): should be a function to customize?
        feed_dict_ = {
            self._actor._observ_inputs: [observ],
            self._actor._action_inputs: [action],
            self._actor._advantage_inputs: [advantage]
        }
        feed_dict_.update(feed_dict)

        self._actor._train_policy(feed_dict=feed_dict_)
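
    # A worked illustration of the advantage signal in `_train_actor` (the
    # numbers are made up): with a 2-action space, target-network Q-values
    # [[1.0, 3.0]], and action=1, the advantage is
    # 3.0 - mean([1.0, 3.0]) = 1.0, so the actor's probability of action 1
    # is reinforced; for action=0 it would be 1.0 - 2.0 = -1.0 and the
    # action would be discouraged.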

    def get_action(self, observ, feed_dict=None):
        """Returns an action for the observation from the actor, and syncs
        the (observation, action) pair to the critic.
        """
        self._observ = observ
        self._action = self._actor.get_action(observ, feed_dict=feed_dict)
        self._critic._update_observ_action(self._observ, self._action)
        return self._action

    @property
    def sess(self):
        """The tf session.
        """
        return self._sess

    @sess.setter
    def sess(self, session):
        self._sess = session
        self._actor._sess = session
        self._critic._sess = session
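

# A minimal end-to-end sketch of the episodic interface. This is an
# illustrative assumption rather than part of the library: it presumes
# `gym` is installed, uses `CartPole-v0` purely as an example environment,
# and relies on the public `reset`/`get_action`/`observe` interface
# inherited from `EpisodicAgentBase`.
if __name__ == '__main__':
    import gym

    from texar.tf.agents import get_gym_env_config

    env = gym.make('CartPole-v0')
    agent = ActorCriticAgent(env_config=get_gym_env_config(env))

    with tf.Session() as sess:
        agent.sess = sess
        sess.run(tf.global_variables_initializer())
        for _ in range(100):  # episodes
            observ = env.reset()
            agent.reset()
            terminal = False
            while not terminal:
                action = agent.get_action(observ)
                observ, reward, terminal, _ = env.step(action)
                agent.observe(reward, terminal)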