torch_mimicry.training¶
Trainer¶
Implementation of Trainer object for training GANs.
-
class Trainer(netD, netG, optD, optG, dataloader, num_steps, log_dir='./log', n_dis=1, lr_decay=None, device=None, netG_ckpt_file=None, netD_ckpt_file=None, print_steps=1, vis_steps=500, log_steps=50, save_steps=5000, flush_secs=30)[source]¶
Trainer object for constructing the GAN training pipeline.
-
netD¶ Torch discriminator model.
Type: Module
-
netG¶ Torch generator model.
Type: Module
-
optD¶ Torch optimizer object for discriminator.
Type: Optimizer
-
optG¶ Torch optimizer object for generator.
Type: Optimizer
-
dataloader¶ Torch object for loading data from a dataset object.
Type: DataLoader
-
device¶ Torch device object to send model/data to.
Type: Device
-
scheduler¶ GAN-training-specific learning rate scheduler object.
Type: LRScheduler
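The pipeline the Trainer constructs can be sketched as a plain loop: for each global step, take n_dis discriminator updates, then one generator update. The function and stub names below (train_gan, d_step, g_step) are illustrative stand-ins, not names from the library:

```python
import itertools

def train_gan(d_step, g_step, data_iter, num_steps, n_dis=2):
    """Sketch of the loop a GAN trainer runs: for each global step,
    take n_dis discriminator update steps, then one generator step."""
    for global_step in range(1, num_steps + 1):
        for _ in range(n_dis):
            d_step(next(data_iter))  # update D on a real batch
        g_step()                     # update G once per global step

# Stub update functions that just count calls, standing in for real
# optimizer steps on netD/netG.
calls = {"D": 0, "G": 0}
data_iter = itertools.cycle([None])
train_gan(lambda batch: calls.__setitem__("D", calls["D"] + 1),
          lambda: calls.__setitem__("G", calls["G"] + 1),
          data_iter, num_steps=3, n_dis=2)
print(calls)  # {'D': 6, 'G': 3}
```

With n_dis=5 (a common SNGAN setting), the discriminator would receive five batches per generator update instead.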
-
Logger¶
Implementation of the Logger object for performing training logging and visualisation.
-
class Logger(log_dir, num_steps, dataset_size, device, flush_secs=120, **kwargs)[source]¶
Writes summaries and visualises training progress.
-
device¶ Torch device object to send data to.
Type: Device
-
print_log(global_step, log_data, time_taken)[source]¶ Formats the string to print to stdout based on training information.
Returns: String to be printed to stdout.
Return type: str
-
vis_images(netG, global_step, num_images=64)[source]¶ Produces visualisations of G(z): one from a fixed noise vector and one from random noise.
Returns: None
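To illustrate what a print_log-style formatter produces, here is a hypothetical sketch; the exact format string, separator, and function name are assumptions, not the library's actual output:

```python
def format_log_line(global_step, num_steps, metrics, time_taken):
    """Hypothetical formatter in the spirit of Logger.print_log: renders
    the global step, each metric at its precision, and iteration time."""
    parts = ["INFO: [Step {}/{}]".format(global_step, num_steps)]
    for name, (value, precision) in metrics.items():
        # Each metric is printed at its own configured precision.
        parts.append("{}: {:.{}f}".format(name, value, precision))
    parts.append("({:.4f} sec/idx)".format(time_taken))
    return " | ".join(parts)

line = format_log_line(1000, 100000, {"errD": (0.693147, 4)}, 0.1234)
print(line)  # INFO: [Step 1000/100000] | errD: 0.6931 | (0.1234 sec/idx)
```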
-
MetricLog¶
MetricLog object for intelligently logging data to display them more intuitively.
-
class MetricLog(**kwargs)[source]¶
A dictionary-like object that logs data, with an extra dict mapping each metric to its group name, if any, and the precision to print it with.
-
metrics_dict¶ A dictionary mapping each metric name to a dict containing its value, precision, and the group it belongs to.
Type: dict
Type: dict
-
add_metric(name, value, group=None, precision=4)[source]¶ Logs a metric to the internal dict, with the additional option of grouping certain metrics together.
Returns: None
-
get_group_name(name)[source]¶ Obtains the group name of a particular metric. For example, errD and errG, which represent the discriminator and generator losses, could fall under a group name called "loss".
Parameters: name (str) – The name of the metric whose group name to retrieve.
Returns: A string representing the group name of the metric.
Return type: str
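A minimal sketch of the behaviour documented above — an illustrative stand-in (the class name MetricLogSketch is ours), not the library's implementation:

```python
class MetricLogSketch:
    """Dictionary-like metric store: keeps value, precision, and an
    optional group per metric name, as MetricLog is documented to do."""
    def __init__(self, **kwargs):
        self.metrics_dict = {}

    def add_metric(self, name, value, group=None, precision=4):
        # Log the metric and remember how to group and print it.
        self.metrics_dict[name] = {
            "value": value, "group": group, "precision": precision}

    def get_group_name(self, name):
        # e.g. both errD and errG can map to the group "loss".
        return self.metrics_dict[name]["group"]

log = MetricLogSketch()
log.add_metric("errD", 0.6931, group="loss")
log.add_metric("errG", 0.6931, group="loss")
print(log.get_group_name("errD"))  # loss
```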
-
Scheduler¶
Implementation of a specific learning rate scheduler for GANs.
-
class LRScheduler(lr_decay, optD, optG, num_steps, start_step=0, **kwargs)[source]¶
Learning rate scheduler for training GANs. Supports GAN-specific LR scheduling policies, such as the linear decay policy used in the SN-GAN paper, based on the original chainer implementation. However, one could safely ignore this class and instead use the official PyTorch scheduler wrappers around an optimizer for other scheduling policies.
-
optD¶ Torch optimizer object for discriminator.
Type: Optimizer
-
optG¶ Torch optimizer object for generator.
Type: Optimizer
-
linear_decay(optimizer, global_step, lr_value_range, lr_step_range)[source]¶ Performs linear decay of the optimizer learning rate based on the number of global steps taken. Follows SNGAN's chainer implementation of linear decay, as seen in the chainer references:
https://docs.chainer.org/en/stable/reference/generated/chainer.training.extensions.LinearShift.html
https://github.com/chainer/chainer/blob/v6.2.0/chainer/training/extensions/linear_shift.py#L66
Note: assumes that the optimizer has only one parameter group to update!
Parameters: - optimizer (Optimizer) – Torch optimizer object whose learning rate is updated.
- global_step (int) – The current global step of the training.
- lr_value_range (tuple) – A tuple of floats (x, y); the learning rate decreases from x to y.
- lr_step_range (tuple) – A tuple of ints (i, j); decay starts when global_step > i and ends at j.
Returns: Float representing the new updated learning rate.
Return type: float
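The decay policy itself reduces to linear interpolation with clamping, in the style of chainer's LinearShift. Below is a standalone sketch of just the value computation; the helper name is ours, and torch_mimicry's version additionally writes the result into the optimizer's parameter group:

```python
def linear_decay_value(global_step, lr_value_range, lr_step_range):
    """Linearly interpolate the learning rate from x to y as global_step
    moves from i to j, holding it constant outside that range."""
    lr_start, lr_end = lr_value_range
    step_start, step_end = lr_step_range
    if global_step <= step_start:
        return lr_start          # decay has not started yet
    if global_step >= step_end:
        return lr_end            # decay has finished
    frac = (global_step - step_start) / (step_end - step_start)
    return lr_start + frac * (lr_end - lr_start)

# Decay from 2e-4 to 0 over steps 0..100000, SNGAN-style.
print(linear_decay_value(50000, (2e-4, 0.0), (0, 100000)))  # 0.0001
```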