torch_mimicry.training

Trainer

Implementation of Trainer object for training GANs.

class Trainer(netD, netG, optD, optG, dataloader, num_steps, log_dir='./log', n_dis=1, lr_decay=None, device=None, netG_ckpt_file=None, netD_ckpt_file=None, print_steps=1, vis_steps=500, log_steps=50, save_steps=5000, flush_secs=30)[source]

Trainer object for constructing the GAN training pipeline.

netD

Torch discriminator model.

Type:Module
netG

Torch generator model.

Type:Module
optD

Torch optimizer object for discriminator.

Type:Optimizer
optG

Torch optimizer object for generator.

Type:Optimizer
dataloader

Torch object for loading data from a dataset object.

Type:DataLoader
num_steps

The number of training iterations.

Type:int
n_dis

Number of discriminator update steps per generator training step.

Type:int
lr_decay

The learning rate decay policy to use.

Type:str
log_dir

The path for storing logging information and checkpoints.

Type:str
device

Torch device object to send model/data to.

Type:Device
logger

Logger object for visualising training information.

Type:Logger
scheduler

GAN training specific learning rate scheduler object.

Type:LRScheduler
params

Dictionary of training hyperparameters.

Type:dict
netD_ckpt_file

Custom checkpoint file to restore discriminator from.

Type:str
netG_ckpt_file

Custom checkpoint file to restore generator from.

Type:str
print_steps

Number of training steps before printing training info to stdout.

Type:int
vis_steps

Number of training steps before visualising images with TensorBoard.

Type:int
flush_secs

Number of seconds before flushing summaries to disk.

Type:int
log_steps

Number of training steps before writing summaries to TensorBoard.

Type:int
save_steps

Number of training steps before checkpointing.

Type:int
train()[source]

Runs the training pipeline with all given parameters in Trainer.
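As a rough illustration of the update schedule that train() runs, the sketch below shows how n_dis discriminator updates are taken per generator update. This is a plain-Python stand-in with the actual model/optimizer calls stubbed out, assuming num_steps counts generator iterations; it is not the library's implementation.

```python
def training_schedule(num_steps, n_dis):
    """Sketch of the alternating GAN update schedule: for every generator
    step, the discriminator is updated n_dis times. Records the sequence
    of updates taken, purely for illustration (real updates are stubbed)."""
    updates = []
    for step in range(num_steps):
        for _ in range(n_dis):
            updates.append('D')   # one discriminator update (stubbed out)
        updates.append('G')       # one generator update (stubbed out)
    return updates

# With n_dis=5 (a common SN-GAN setting), each global step
# takes 5 discriminator updates followed by 1 generator update.
schedule = training_schedule(num_steps=2, n_dis=5)
```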

Logger

Implementation of the Logger object for performing training logging and visualisation.

class Logger(log_dir, num_steps, dataset_size, device, flush_secs=120, **kwargs)[source]

Writes summaries and visualises training progress.

log_dir

The path to store logging information.

Type:str
num_steps

Total number of training iterations.

Type:int
dataset_size

The number of examples in the dataset.

Type:int
device

Torch device object to send data to.

Type:Device
flush_secs

Number of seconds before flushing summaries to disk.

Type:int
writers

A dictionary of TensorBoard writers with keys as metric names.

Type:dict
num_epochs

The number of epochs, for extra information.

Type:int
close_writers()[source]

Closes all writers.

print_log(global_step, log_data, time_taken)[source]

Formats the string to print to stdout based on training information.

Parameters:
  • global_step (int) – Global step variable for syncing logs.
  • log_data (MetricLog) – Dict-like object to collect log data for TB writing.
  • time_taken (float) – Time taken for one training iteration.
Returns:

String to be printed to stdout.

Return type:

str

vis_images(netG, global_step, num_images=64)[source]

Produces visualisations of G(z): one batch from fixed noise and one from random noise.

Parameters:
  • netG (Module) – Generator model object for producing images.
  • global_step (int) – Global step variable for syncing logs.
  • num_images (int) – The number of images to visualise.
Returns:

None

write_summaries(log_data, global_step)[source]

Tasks the appropriate writers to write the summaries to TensorBoard. Creates additional writers for summary writing if there are new scalars to log in log_data.

Parameters:
  • log_data (MetricLog) – Dict-like object to collect log data for TB writing.
  • global_step (int) – Global step variable for syncing logs.
Returns:

None
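The writers attribute holds one writer per metric name, created lazily as new scalars appear in log_data. A minimal sketch of that lazy-creation pattern, using a hypothetical in-memory StubWriter in place of TensorBoard's SummaryWriter (the real class writes event files to disk):

```python
class StubWriter:
    """Stand-in for TensorBoard's SummaryWriter, recording scalars in memory."""
    def __init__(self):
        self.scalars = []

    def add_scalar(self, tag, value, global_step):
        self.scalars.append((tag, value, global_step))

def write_summaries(writers, log_data, global_step):
    """Lazily creates a writer per metric name, then writes each scalar."""
    for name, value in log_data.items():
        if name not in writers:
            writers[name] = StubWriter()  # new scalar seen: create its writer
        writers[name].add_scalar(name, value, global_step)

writers = {}
write_summaries(writers, {'errD': 0.7, 'errG': 1.2}, global_step=50)
write_summaries(writers, {'errD': 0.6, 'errG': 1.1, 'lr_D': 2e-4}, global_step=100)
```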

MetricLog

MetricLog object for logging metrics in a structured way so they can be displayed more intuitively.

class MetricLog(**kwargs)[source]

A dictionary-like object that logs data, and includes an extra dict that maps each metric to its group name, if any, and the corresponding precision to print out.

metrics_dict

A dictionary mapping to another dict containing the corresponding value, precision, and the group this metric belongs to.

Type:dict
add_metric(name, value, group=None, precision=4)[source]

Logs a metric to the internal dict, with an additional option of grouping certain metrics together.

Parameters:
  • name (str) – Name of metric to log.
  • value (Tensor/Float) – Value of the metric to log.
  • group (str) – Name of the group to classify different metrics together.
  • precision (int) – The number of decimal places used to represent the value.
Returns:

None

get_group_name(name)[source]

Obtains the group name of a particular metric. For example, errD and errG, which represent the discriminator and generator losses, could fall under a group named “loss”.

Parameters:name (str) – The name of the metric to retrieve group name.
Returns:A string representing the group name of the metric.
Return type:str
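To make the grouping idea behind metrics_dict concrete, here is a simplified stand-in for MetricLog (not the library's actual implementation): each metric stores its value, printing precision, and group, and items()/keys() give dict-like access.

```python
class MiniMetricLog:
    """Simplified sketch of MetricLog: stores value, precision, and group per metric."""
    def __init__(self):
        self.metrics_dict = {}

    def add_metric(self, name, value, group=None, precision=4):
        # Round to the requested number of decimal places for display.
        self.metrics_dict[name] = {
            'value': round(float(value), precision),
            'group': group,
            'precision': precision,
        }

    def get_group_name(self, name):
        return self.metrics_dict[name]['group']

    def items(self):
        return [(k, v['value']) for k, v in self.metrics_dict.items()]

    def keys(self):
        return list(self.metrics_dict)

log = MiniMetricLog()
log.add_metric('errD', 0.693147, group='loss')
log.add_metric('errG', 0.710987, group='loss', precision=2)
```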
items()[source]

Dict-like functionality for retrieving items.

keys()[source]

Dict-like functionality for retrieving keys.

Scheduler

Implementation of a specific learning rate scheduler for GANs.

class LRScheduler(lr_decay, optD, optG, num_steps, start_step=0, **kwargs)[source]

Learning rate scheduler for training GANs. Supports GAN-specific LR scheduling policies, such as the linear decay policy used in the SN-GAN paper, following the original chainer implementation. However, one could safely ignore this class and instead use the official PyTorch scheduler wrappers around an optimizer for other scheduling policies.

lr_decay

The learning rate decay policy to use.

Type:str
optD

Torch optimizer object for discriminator.

Type:Optimizer
optG

Torch optimizer object for generator.

Type:Optimizer
num_steps

The number of training iterations.

Type:int
lr_D

The initial learning rate of optD.

Type:float
lr_G

The initial learning rate of optG.

Type:float
linear_decay(optimizer, global_step, lr_value_range, lr_step_range)[source]

Performs linear decay of the optimizer learning rate based on the number of global steps taken. Follows SNGAN’s chainer implementation of linear decay, as seen in the chainer references: https://docs.chainer.org/en/stable/reference/generated/chainer.training.extensions.LinearShift.html https://github.com/chainer/chainer/blob/v6.2.0/chainer/training/extensions/linear_shift.py#L66

Note: assumes that the optimizer has only one parameter group to update!

Parameters:
  • optimizer (Optimizer) – Torch optimizer object to update learning rate.
  • global_step (int) – The current global step of the training.
  • lr_value_range (tuple) – A tuple of floats (x,y) to decrease from x to y.
  • lr_step_range (tuple) – A tuple of ints (i, j) to start decreasing when global_step > i, and until j.
Returns:

Float representing the new updated learning rate.

Return type:

float
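The decay itself can be sketched as follows. This is a reconstruction from the chainer LinearShift semantics cited above, not the library's exact code: before step i the learning rate stays at x, between i and j it is linearly interpolated from x down to y, and after j it stays at y.

```python
def linear_decay_value(global_step, lr_value_range, lr_step_range):
    """Linearly interpolate the learning rate from x to y between steps i and j.
    Sketch based on chainer's LinearShift; torch_mimicry's exact code may differ."""
    x, y = lr_value_range
    i, j = lr_step_range
    if global_step <= i:
        return x                        # decay has not started yet
    if global_step >= j:
        return y                        # decay window has ended
    frac = (global_step - i) / (j - i)  # fraction of the decay window elapsed
    return x + frac * (y - x)

# SN-GAN style: decay the learning rate from 2e-4 to 0 over a 100k-step run.
lr_mid = linear_decay_value(50000, (2e-4, 0.0), (0, 100000))
```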

step(log_data, global_step)[source]

Takes a step for updating learning rate and updates the input log_data with the current status.

Parameters:
  • log_data (MetricLog) – Object for logging the updated learning rate metric.
  • global_step (int) – The current global step of the training.
Returns:

MetricLog object containing the updated learning rate at the current global step.

Return type:

MetricLog