Source code for texar.tf.agents.seq_pg_agent

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Policy Gradient agent for sequence prediction.
"""

# pylint: disable=too-many-instance-attributes, too-many-arguments, no-member

import tensorflow as tf

from texar.tf.agents.seq_agent_base import SeqAgentBase
from texar.tf.core import optimization as opt
from texar.tf.losses.pg_losses import pg_loss_with_logits
from texar.tf.losses.rewards import discount_reward
from texar.tf.losses.entropy import sequence_entropy_with_logits

__all__ = [
    "SeqPGAgent"
]


class SeqPGAgent(SeqAgentBase):
    """Policy Gradient agent for sequence prediction.

    This is a wrapper of the **training process** that trains a model
    with policy gradient. The agent itself does not create new trainable
    variables.

    Args:
        samples: An `int` Tensor of shape `[batch_size, max_time]` containing
            sampled sequences from the model.
        logits: A float Tensor of shape `[batch_size, max_time, vocab_size]`
            containing the logits of samples from the model.
        sequence_length: A Tensor of shape `[batch_size]`.
            Time steps beyond the respective sequence lengths are masked out.
        trainable_variables (optional): Trainable variables of the model to
            update during training. If `None`, all trainable variables in the
            graph are used.
        learning_rate (optional): Learning rate for policy optimization. If
            not given, determine the learning rate from :attr:`hparams`.
            See :func:`~texar.tf.core.get_train_op` for more details.
        sess (optional): A tf session.
            Can be `None` here and set later with `agent.sess = session`.
        hparams (dict or HParams, optional): Hyperparameters. Missing
            hyperparameters will be set to default values. See
            :meth:`default_hparams` for the hyperparameter structure and
            default values.

    Example:

        .. code-block:: python

            ## Train a decoder with policy gradient
            decoder = BasicRNNDecoder(...)
            outputs, _, sequence_length = decoder(
                decoding_strategy='infer_sample', ...)

            sess = tf.Session()

            agent = SeqPGAgent(
                samples=outputs.sample_id,
                logits=outputs.logits,
                sequence_length=sequence_length,
                sess=sess)
            while training:
                # Generate samples
                vals = agent.get_samples()
                # Evaluate reward
                sample_text = tx.utils.map_ids_to_strs(vals['samples'], vocab)
                reward_bleu = []
                for y, y_ in zip(ground_truth, sample_text):
                    reward_bleu.append(tx.evals.sentence_bleu(y, y_))
                # Update
                agent.observe(reward=reward_bleu)
    """

    def __init__(self,
                 samples,
                 logits,
                 sequence_length,
                 trainable_variables=None,
                 learning_rate=None,
                 sess=None,
                 hparams=None):
        SeqAgentBase.__init__(self, hparams)

        self._lr = learning_rate

        # Tensors
        self._samples = samples
        self._logits = logits
        self._sequence_length = sequence_length
        self._trainable_variables = trainable_variables

        # Python values
        self._samples_py = None
        self._sequence_length_py = None
        self._rewards = None

        self._sess = sess

        # For session partial run
        self._partial_run_handle = None
        self._qvalue_inputs_fed = False

        self._build_graph()

    def _build_graph(self):
        with tf.variable_scope(self.variable_scope):
            self._qvalue_inputs = tf.placeholder(
                dtype=tf.float32,
                shape=[None, None],
                name='qvalue_inputs')

            self._pg_loss = self._get_pg_loss()

            self._train_op = self._get_train_op()

    def _get_pg_loss(self):
        loss_hparams = self._hparams.loss
        pg_loss = pg_loss_with_logits(
            actions=self._samples,
            logits=self._logits,
            sequence_length=self._sequence_length,
            advantages=self._qvalue_inputs,
            batched=True,
            average_across_batch=loss_hparams.average_across_batch,
            average_across_timesteps=loss_hparams.average_across_timesteps,
            sum_over_batch=loss_hparams.sum_over_batch,
            sum_over_timesteps=loss_hparams.sum_over_timesteps,
            time_major=loss_hparams.time_major)

        if self._hparams.entropy_weight > 0:
            entropy = self._get_entropy()
            pg_loss -= self._hparams.entropy_weight * entropy

        return pg_loss

    def _get_entropy(self):
        loss_hparams = self._hparams.loss
        return sequence_entropy_with_logits(
            self._logits,
            sequence_length=self._sequence_length,
            average_across_batch=loss_hparams.average_across_batch,
            average_across_timesteps=loss_hparams.average_across_timesteps,
            sum_over_batch=loss_hparams.sum_over_batch,
            sum_over_timesteps=loss_hparams.sum_over_timesteps,
            time_major=loss_hparams.time_major)

    def _get_train_op(self):
        train_op = opt.get_train_op(
            loss=self._pg_loss,
            variables=self._trainable_variables,
            learning_rate=self._lr,
            hparams=self._hparams.optimization.todict())
        return train_op
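
    # Note (added for exposition, not in the original source): with the
    # default `loss` hparams (mean over batch, sum over time steps),
    # `_get_pg_loss` above in effect computes
    #
    #     loss = - mean_b [ sum_t log pi(a_t | .) * Q_t ]
    #            - entropy_weight * H(pi)      # entropy term only if > 0
    #
    # where Q_t are the (discounted) per-step rewards fed through the
    # `qvalue_inputs` placeholder, so minimizing `loss` follows the policy
    # gradient while optionally encouraging a higher-entropy distribution.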

    @staticmethod
    def default_hparams():
        """Returns a dictionary of hyperparameters with default values:

        .. role:: python(code)
           :language: python

        .. code-block:: python

            {
                'discount_factor': 0.95,
                'normalize_reward': False,
                'entropy_weight': 0.,
                'loss': {
                    'average_across_batch': True,
                    'average_across_timesteps': False,
                    'sum_over_batch': False,
                    'sum_over_timesteps': True,
                    'time_major': False
                },
                'optimization': default_optimization_hparams(),
                'name': 'pg_agent',
            }

        Here:

        "discount_factor": float
            The discount factor of reward.

        "normalize_reward": bool
            Whether to normalize the discounted reward, by
            `(discounted_reward - mean) / std`. Here `mean` and `std` are
            over all time steps and all samples in the batch.

        "entropy_weight": float
            The weight of entropy loss of the sample distribution, to
            encourage maximizing the Shannon entropy. Set to 0 to disable
            the loss.

        "loss": dict
            Extra keyword arguments for
            :func:`~texar.tf.losses.pg_loss_with_logits`, including the
            reduce arguments (e.g., `average_across_batch`) and `time_major`.

        "optimization": dict
            Hyperparameters of optimization for updating the policy net.
            See :func:`~texar.tf.core.default_optimization_hparams` for
            details.

        "name": str
            Name of the agent.
        """
        return {
            'discount_factor': 0.95,
            'normalize_reward': False,
            'entropy_weight': 0.,
            'loss': {
                'average_across_batch': True,
                'average_across_timesteps': False,
                'sum_over_batch': False,
                'sum_over_timesteps': True,
                'time_major': False
            },
            'optimization': opt.default_optimization_hparams(),
            'name': 'pg_agent',
        }
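
    # Illustrative usage note (not part of the original source): the defaults
    # above can be partially overridden at construction time, e.g.
    #
    #     agent = SeqPGAgent(
    #         samples=outputs.sample_id,
    #         logits=outputs.logits,
    #         sequence_length=sequence_length,
    #         sess=sess,
    #         hparams={'discount_factor': 0.99, 'entropy_weight': 0.01})
    #
    # Unspecified keys keep their default values.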

    def _get_partial_run_feeds(self, feeds=None):
        if feeds is None:
            feeds = []
        feeds += [self._qvalue_inputs]
        return feeds

    def _setup_partial_run(self, fetches=None, feeds=None):
        fetches_ = [self._samples,
                    self._sequence_length,
                    self._pg_loss,
                    self._train_op]
        if fetches is not None:
            for fet in fetches:
                if fet not in fetches_:
                    fetches_.append(fet)

        feeds = self._get_partial_run_feeds(feeds)

        self._partial_run_handle = self._sess.partial_run_setup(
            fetches_, feeds=feeds)

        self._qvalue_inputs_fed = False

    def _check_extra_fetches(self, extra_fetches):
        fetch_values = None
        if extra_fetches is not None:
            fetch_values = list(extra_fetches.values())
        if fetch_values is not None:
            if self._samples in fetch_values:
                raise ValueError(
                    "`samples` must not be included in `extra_fetches`. "
                    "It is added automatically.")
            if self._sequence_length in fetch_values:
                raise ValueError(
                    "`sequence_length` must not be included in "
                    "`extra_fetches`. It is added automatically.")
            if "samples" in extra_fetches:
                raise ValueError(
                    "Key 'samples' is reserved and must not be used "
                    "in `extra_fetches`.")
            if "sequence_length" in extra_fetches:
                raise ValueError(
                    "Key 'sequence_length' is reserved and must not be used "
                    "in `extra_fetches`.")

    def get_samples(self, extra_fetches=None, feed_dict=None):
        """Returns sequence samples and extra results.

        Args:
            extra_fetches (dict, optional): Extra tensors to fetch values,
                besides `samples` and `sequence_length`. Same as the
                `fetches` argument of
                :tf_main:`tf.Session.run <Session#run>` and
                :tf_main:`partial_run <Session#partial_run>`.
            feed_dict (dict, optional): A `dict` that maps tensors to values.
                Note that all placeholder values used in :meth:`get_samples`
                and subsequent :meth:`observe` calls should be fed here.

        Returns:
            A `dict` with keys **"samples"** and **"sequence_length"**
            containing the fetched values of :attr:`samples` and
            :attr:`sequence_length`, as well as other fetched values
            as specified in :attr:`extra_fetches`.

        Example:

            .. code-block:: python

                extra_fetches = {'truth_ids': data_batch['text_ids']}
                vals = agent.get_samples(extra_fetches=extra_fetches)
                sample_text = tx.utils.map_ids_to_strs(vals['samples'], vocab)
                truth_text = tx.utils.map_ids_to_strs(vals['truth_ids'], vocab)
                reward = reward_fn_in_python(truth_text, sample_text)
        """
        if self._sess is None:
            raise ValueError("`sess` must be specified before sampling.")

        self._check_extra_fetches(extra_fetches)

        # Sets up partial_run
        fetch_values = None
        if extra_fetches is not None:
            fetch_values = list(extra_fetches.values())
        feeds = None
        if feed_dict is not None:
            feeds = list(feed_dict.keys())
        self._setup_partial_run(fetches=fetch_values, feeds=feeds)

        # Runs the sampling
        fetches = {
            "samples": self._samples,
            "sequence_length": self._sequence_length
        }
        if extra_fetches is not None:
            fetches.update(extra_fetches)

        feed_dict_ = feed_dict

        vals = self._sess.partial_run(
            self._partial_run_handle, fetches, feed_dict=feed_dict_)

        self._samples_py = vals['samples']
        self._sequence_length_py = vals['sequence_length']

        return vals

    def observe(self, reward, train_policy=True, compute_loss=True):
        """Observes the reward, and updates the policy or computes loss
        accordingly.

        Args:
            reward: A Python array/list of shape `[batch_size]` containing
                the reward for the samples generated in the last call of
                :meth:`get_samples`.
            train_policy (bool): Whether to update the policy model
                according to the reward.
            compute_loss (bool): If `train_policy` is False, whether to
                compute the policy gradient loss (without updating the
                policy).

        Returns:
            If `train_policy` or `compute_loss` is True, returns the loss
            (a Python float scalar). Otherwise returns `None`.
        """
        self._rewards = reward

        if train_policy:
            return self._train_policy()
        elif compute_loss:
            return self._evaluate_pg_loss()
        else:
            return None
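
    # Illustrative note (not part of the original source): `_get_qvalues`
    # below expands the per-sequence reward into per-step Q-values via
    # `discount_reward`. For example, with discount_factor=0.95, a sequence
    # of length 3 with reward r is attributed, per step, roughly
    #
    #     [r * 0.95**2, r * 0.95, r]   # reward discounted backward in time
    #
    # (optionally normalized when `normalize_reward` is True), and the result
    # is fed into the `qvalue_inputs` placeholder of shape
    # `[batch_size, max_time]`.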

    def _get_qvalues(self):
        qvalues = discount_reward(
            self._rewards,
            self._sequence_length_py,
            discount=self._hparams.discount_factor,
            normalize=self._hparams.normalize_reward)
        return qvalues

    def _evaluate_pg_loss(self):
        fetches = {
            "loss": self._pg_loss
        }

        feed_dict_ = None
        if not self._qvalue_inputs_fed:
            qvalues = self._get_qvalues()
            feed_dict_ = {self._qvalue_inputs: qvalues}

        vals = self._sess.partial_run(
            self._partial_run_handle, fetches, feed_dict=feed_dict_)
        self._qvalue_inputs_fed = True

        return vals['loss']

    def _train_policy(self):
        """Updates the policy.
        """
        fetches = {
            "loss": self._train_op,
        }

        feed_dict_ = None
        if not self._qvalue_inputs_fed:
            qvalues = self._get_qvalues()
            feed_dict_ = {self._qvalue_inputs: qvalues}

        vals = self._sess.partial_run(
            self._partial_run_handle, fetches, feed_dict=feed_dict_)
        self._qvalue_inputs_fed = True

        return vals['loss']

    @property
    def sess(self):
        """The tf session.
        """
        return self._sess

    @sess.setter
    def sess(self, sess):
        self._sess = sess

    @property
    def pg_loss(self):
        """The scalar tensor of policy gradient loss.
        """
        return self._pg_loss

    @property
    def sequence_length(self):
        """The tensor of sample sequence length, of shape `[batch_size]`.
        """
        return self._sequence_length

    @property
    def samples(self):
        """The tensor of sequence samples.
        """
        return self._samples

    @property
    def logits(self):
        """The tensor of sequence logits.
        """
        return self._logits
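

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): one way to drive the
# get_samples / observe cycle above. The names `agent`, `vocab`,
# `ground_truth_text`, `reward_fn`, and `data_feed_dict` are hypothetical
# placeholders supplied by the caller; only `tx.utils.map_ids_to_strs` is an
# existing Texar utility.
# ---------------------------------------------------------------------------
def _example_pg_step(agent, vocab, ground_truth_text, reward_fn,
                     data_feed_dict=None):
    """Runs one sample -> reward -> update step with a SeqPGAgent (sketch)."""
    import texar.tf as tx

    # Draw samples from the model. Any placeholders the model depends on must
    # be fed here so the subsequent `observe` call can reuse the partial run.
    vals = agent.get_samples(feed_dict=data_feed_dict)
    sample_text = tx.utils.map_ids_to_strs(vals['samples'], vocab)

    # Compute one scalar reward per sample, e.g. sentence BLEU of the sample
    # against the corresponding reference text.
    rewards = [reward_fn(truth, hyp)
               for truth, hyp in zip(ground_truth_text, sample_text)]

    # Update the policy; returns the policy gradient loss for this batch.
    return agent.observe(reward=rewards)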