Source code for torch_mimicry.training.metric_log

"""
MetricLog object for logging metric data so that it can be displayed more intuitively.
"""


class MetricLog:
    """
    A dictionary-like object that logs data, and includes an extra dict to map
    the metrics to its group name, if any, and the corresponding precision to
    print out.

    Indexing (``log[name]``) returns the stored value rounded to the metric's
    precision, while ``keys()``/``items()`` expose the raw underlying records.
    Membership tests (``name in log``) and iteration (``for name in log``)
    operate on the metric names.

    Attributes:
        metrics_dict (dict): A dictionary mapping each metric name to another
            dict containing the corresponding value, precision, and the group
            this metric belongs to.
    """
    def __init__(self, **kwargs):
        # NOTE: **kwargs is accepted but not used; it is kept so existing
        # callers that pass keyword arguments keep working.
        self.metrics_dict = {}

    def add_metric(self, name, value, group=None, precision=4):
        """
        Logs metric to internal dict, but with an additional option
        of grouping certain metrics together.

        Logging a name that already exists overwrites the previous record.

        Args:
            name (str): Name of metric to log.
            value (Tensor/Float): Value of the metric to log. Objects exposing
                an ``item()`` method are converted by calling it.
            group (str): Name of the group to classify different metrics
                together.
            precision (int): The number of digits after the decimal point
                used when the value is read back through indexing.

        Returns:
            None
        """
        # Grab tensor values only: anything without an ``item`` attribute is
        # stored unchanged.
        try:
            value = value.item()
        except AttributeError:
            pass

        self.metrics_dict[name] = dict(value=value,
                                       group=group,
                                       precision=precision)

    def __getitem__(self, key):
        """
        Returns the value logged under ``key``, rounded to its precision.

        Raises:
            KeyError: If no metric named ``key`` has been logged.
        """
        record = self.metrics_dict[key]
        return round(record['value'], record['precision'])

    def __contains__(self, key):
        """
        Dict like functionality for testing whether a metric has been logged.

        Without this method Python would fall back to calling
        ``__getitem__`` with 0, 1, 2, ... and fail with ``KeyError``.
        """
        return key in self.metrics_dict

    def __iter__(self):
        """
        Dict like functionality for iterating over the logged metric names.

        Defined explicitly for the same reason as ``__contains__``: the
        implicit ``__getitem__``-based iteration would raise ``KeyError``.
        """
        return iter(self.metrics_dict)

    def get_group_name(self, name):
        """
        Obtains the group name of a particular metric. For example, errD and
        errG which represents the discriminator/generator losses could fall
        under a group name called "loss".

        Args:
            name (str): The name of the metric to retrieve group name.

        Returns:
            str: A string representing the group name of the metric, or None
                if the metric was logged without a group.

        Raises:
            KeyError: If no metric named ``name`` has been logged.
        """
        return self.metrics_dict[name]['group']

    def keys(self):
        """
        Dict like functionality for retrieving keys.
        """
        return self.metrics_dict.keys()

    def items(self):
        """
        Dict like functionality for retrieving items.

        Each item pairs a metric name with its raw record dict (``value``,
        ``group``, ``precision``); values are not rounded here.
        """
        return self.metrics_dict.items()