# Source code for texar.tf.utils.average_recorder

# Copyright 2018 The Texar Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for maintaining moving averages.
"""

from collections import deque

# pylint: disable=invalid-name

# Public names exported by `from ... import *`.
__all__ = ["_SingleAverageRecorder", "AverageRecorder"]


class _SingleAverageRecorder(object):
    """Maintains the moving average (i.e., the average of the latest N records)
    of a single metric.

    Args:
        size (int, optional): The window size of moving average. If `None`,
            the average of all added records is maintained.
        name (str, optional): name of the recorder. Used when printing.
    """

    def __init__(self, size=None, name=None):
        if size is not None and size <= 0:
            raise ValueError("`size` must be > 0 or `None`.")
        self._size = size
        self._q = deque([])
        self._w = deque([])
        self._sum = 0.
        self._w_sum = 0
        self._name = name

    def add(self, record, weight=None):
        """Appends a new record.

        Args:
            record: A scalar; the new record to append.
            weight (optional): A scalar, weight of the new record for
                calculating a weighted average. If `None`, weight is set to `1`.
                For example, :attr:`weight` can be set to batch size and
                :attr:`record` the average value of certain metric on the batch
                in order to calculate the average metric value on a whole
                dataset.

        Returns:
            The (moving) average after appending the record.
        """
        w = weight if weight is not None else 1
        self._w_sum += w
        self._sum += record * w

        if self._size is not None:
            if len(self._q) == self._size:
                w_pop = self._w.popleft()
                self._sum -= self._q.popleft() * w_pop
                self._w_sum -= w_pop
            self._q.append(record)
            self._w.append(w)

        return self.avg()

    def avg(self):
        """Returns the (moving) average.
        """
        if self._w_sum == 0:
            return 0.
        return self._sum / self._w_sum

    def reset(self):
        """Cleans all records.
        """
        self._q.clear()
        self._w.clear()
        self._sum = 0.
        self._w_sum = 0

    def to_str(self, precision=None):
        """Returns a string of the average value.

        Args:
            precision (int, optional): The number of decimal places to keep in
                the returned string. E.g., for an average value of `0.1234`,
                :attr:`precision = 2` leads to `'0.12'`.

        Returns:
            A string of the average value. If :meth:`name` is given, the
            string is of the format like `'name: 0.1234'`, otherwise
            the string is of the format like `'0.1234'`.
        """
        prec_str = "{}"
        if precision is not None:
            prec_str = "{:.%df}" % precision

        avg_str = prec_str.format(self.avg())
        if self._name is not None:
            avg_str = "{}: {}".format(self._name, avg_str)

        return avg_str

    @property
    def name(self):
        """The name of the recorder.
        """
        return self.name


class AverageRecorder(object):
    """Maintains the moving averages (i.e., the average of the latest N
    records) of (possibly multiple) fields.

    Fields are determined by the first call of :meth:`add`.

    Args:
        size (int, optional): The window size of moving average. If `None`,
            the average of all added records is maintained.

    Raises:
        ValueError: If :attr:`size` is not `None` and is not positive.

    Example:

        .. code-block:: python

            ## Use to maintain moving average of training loss
            avg_rec = AverageRecorder(size=10) # average over latest 10 records
            while training:
                loss_0, loss_1  = ...
                avg_rec.add([loss_0, loss_1])
                # avg_rec.avg() == [0.12343452, 0.567800323]
                # avg_rec.avg(0) == 0.12343452
                # avg_rec.to_str(precision=2, ) == '0.12 0.57'

            ## Use to maintain average of test metrics on the whole test set
            avg_rec = AverageRecorder() # average over ALL records
            while test:
                metric_0, metric_1 = ...
                avg_rec.add({'m0': metric_0, 'm1': metric_1}) # dict is allowed
            print(avg_rec.to_str(precision=4, delimiter=' , '))
            # 'm0: 0.1234 , m1: 0.5678'
            #
            # avg_rec.avg() == {'m0': 0.12343452, 'm1': 0.567800323}
            # avg_rec.avg('m0') == 0.12343452
    """

    def __init__(self, size=None):
        if size is not None and size <= 0:
            raise ValueError("`size` must be > 0 or `None`.")
        self._size = size
        # Maps field key -> _SingleAverageRecorder; created on first `add`.
        self._recorders = None
        # Key of the single field when records are plain scalars.
        self._default_metric_name = "metric"
        # Type of the first record passed to `add`; fixed thereafter.
        self._record_type = None

    def _is_dict_type(self):
        """Whether records are dicts (including dict subclasses)."""
        # Kept consistent with `_to_dict`, which uses `isinstance`.
        return (self._record_type is not None
                and issubclass(self._record_type, dict))

    def _is_seq_type(self):
        """Whether records are lists/tuples (including their subclasses)."""
        return (self._record_type is not None
                and issubclass(self._record_type, (list, tuple)))

    def _to_dict(self, record):
        """Normalizes a record into a dict keyed by field name/index."""
        if isinstance(record, dict):
            record_dict = record
        elif isinstance(record, (list, tuple)):
            # Sequence fields are keyed by their position.
            record_dict = dict(enumerate(record))
        else:
            record_dict = {self._default_metric_name: record}
        return record_dict

    def add(self, record, weight=None):
        """Appends a new record.

        :attr:`record` can be a `list`, `dict`, or a single scalar. The
        record type is determined at the first time :meth:`add` is called.
        All subsequent calls to :meth:`add` must have the same type of
        :attr:`record`.

        :attr:`record` in subsequent calls to :meth:`add` can contain only a
        subset of the fields of the first call to :meth:`add`.

        Example:

            .. code-block:: python

                recorder.add({'1': 0.2, '2': 0.2}) # 1st call to `add`
                x = recorder.add({'1': 0.4}) # 2nd call to `add`
                # x == {'1': 0.3, '2': 0.2}

        Args:
            record: A single scalar, a list of scalars, or a dict of scalars.
            weight (optional): A scalar, weight of the new record for
                calculating a weighted average. If `None`, weight is set to `1`.
                For example, :attr:`weight` can be set to batch size and
                :attr:`record` the average value of certain metrics on the
                batch in order to calculate the average metric values on a
                whole dataset.

        Returns:
            The (moving) average after appending the record, with the same
            type as :attr:`record`.

        Raises:
            ValueError: If the type of :attr:`record` differs from that of
                the first call.
            KeyError: If :attr:`record` contains a field that was not present
                in the first call. No field is updated in that case.
        """
        if self._record_type is None:
            self._record_type = type(record)
        elif self._record_type != type(record):
            raise ValueError('The type of `record` is not consistent. '
                             'Expect type `{}`'.format(self._record_type))

        record_dict = self._to_dict(record)
        if self._recorders is None:
            # Only dict records get named sub-recorders, so that `to_str`
            # prints 'name: value' for dicts and bare values otherwise.
            use_names = self._is_dict_type()
            self._recorders = {
                name: _SingleAverageRecorder(
                    self._size, name if use_names else None)
                for name in record_dict.keys()
            }

        # Validate every field before mutating anything, so a bad record
        # does not leave the recorders partially updated.
        unknown = [name for name in record_dict if name not in self._recorders]
        if unknown:
            raise KeyError(
                'Unknown field(s) {}; fields are fixed by the first call '
                'to `add`: {}'.format(unknown, list(self._recorders.keys())))

        for name, val in record_dict.items():
            self._recorders[name].add(val, weight=weight)

        return self.avg()

    def avg(self, id_or_name=None):
        """Returns the (moving) average.

        Args:
            id_or_name (optional): A list of or a single element.
                Each element is the index (if the record type is `list`) or
                name (if the record type is `dict`) of the field for which
                the average is calculated.
                If not given, the average of all fields is returned.

        Returns:
            The average value(s). If :attr:`id_or_name` is a single element
            (not a list), then returns the average value of the corresponding
            field. Otherwise, if :attr:`id_or_name` is a list of element(s),
            then returns average value(s) in the same type as :attr:`record`
            of :meth:`add`. Returns `0.` if nothing has been recorded yet.
        """
        if self._recorders is None:
            return 0.

        keys = id_or_name
        if keys is None:
            keys = list(self._recorders.keys())

        if not isinstance(keys, (list, tuple)):
            return self._recorders[keys].avg()

        # A dict both de-duplicates `keys` and keeps their first-seen order.
        avg = {key: self._recorders[key].avg() for key in keys}
        if self._is_seq_type():
            return self._record_type(list(avg.values()))
        if self._is_dict_type():
            return avg
        return avg[self._default_metric_name]

    def reset(self, id_or_name=None):
        """Resets the record.

        Args:
            id_or_name (optional): A list or a single element. Each element
                is the index (if the record type is `list`) or name (if the
                record type is `dict`) of the field to reset.
                If `None`, all fields are reset.
        """
        if self._recorders is None:
            # Nothing has been recorded yet, so there is nothing to reset.
            return

        keys = id_or_name
        if keys is None:
            keys = list(self._recorders.keys())
        elif not isinstance(keys, (list, tuple)):
            keys = [keys]

        for key in keys:
            self._recorders[key].reset()

    def to_str(self, precision=None, delimiter=' '):
        """Returns a string of the average values of the records.

        Args:
            precision (int, optional): The number of decimal places to keep in
                the returned string. E.g., for an average value of `0.1234`,
                :attr:`precision = 2` leads to `'0.12'`.
            delimiter (str): The delimiter string that separates between
                fields.

        Returns:
            A string of the average values. Returns an empty string if
            nothing has been recorded yet.

            If record is of type `dict`, the string is a concatenation of
            'field_name: average_value', delimited with :attr:`delimiter`.
            E.g., `'field_name_1: 0.1234 field_name_2: 0.5678 ...'`.

            Otherwise, the string is a concatenation of 'average_value'.
            E.g., `'0.1234 0.5678 ...'`
        """
        if self._recorders is None:
            return ''

        strs = {name: rec.to_str(precision=precision)
                for name, rec in self._recorders.items()}

        if self._is_seq_type():
            # Keys are the indexes 0..n-1; emit them in index order.
            str_list = [strs[i] for i in range(len(strs))]
        elif self._is_dict_type():
            str_list = list(strs.values())
        else:
            str_list = [strs[self._default_metric_name]]

        return delimiter.join(str_list)