torch_mimicry.nets

torch_mimicry.nets.basemodel

Implementation of BaseModel.

class BaseModel[source]

Base model class providing basic functionality for checkpointing and restoration.

count_params()[source]

Computes the number of parameters in this model.

Args: None

Returns:Total number of weight parameters for this model, and total number of trainable parameters for this model.
Return type:tuple of (int, int)
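To show what separates the two counts, here is a pure-Python sketch. It stands in for a model's parameters with a hypothetical list of (shape, requires_grad) pairs. It is not the library's implementation. With a real module, the same numbers come from `sum(p.numel() for p in model.parameters())`, and for the second count you also filter on `p.requires_grad`.

```python
from math import prod

# Hypothetical stand-in for a model's parameters: (shape, requires_grad) pairs.
params = [
    ((10, 5), True),   # e.g. a linear layer's weight
    ((5,), True),      # its bias
    ((3, 3), False),   # a frozen weight that is excluded from the trainable count
]


def count_params(params):
    """Return (total, trainable) parameter counts."""
    total = sum(prod(shape) for shape, _ in params)
    trainable = sum(prod(shape) for shape, requires_grad in params if requires_grad)
    return total, trainable


total, trainable = count_params(params)
```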
restore_checkpoint(ckpt_file, optimizer=None)[source]

Restores the model weights from a .pth checkpoint file and, if an optimizer is given, restores its state as well.

Parameters:
  • ckpt_file (str) – A PyTorch pth file containing model weights.
  • optimizer (Optimizer) – An optimizer whose state is restored from the checkpoint.
Returns:

The global step at which the model was last checkpointed.

Return type:

int

save_checkpoint(directory, global_step, optimizer=None, name=None)[source]

Saves a checkpoint at a given global step during training. If an optimizer is given, its state is saved together with the model weights.

Parameters:
  • directory (str) – Path to save checkpoint to.
  • global_step (int) – The global step variable during training.
  • optimizer (Optimizer) – Optimizer whose state is saved alongside the model weights.
  • name (str) – The name to save the checkpoint file as.
Returns:

None
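The sketch below shows the save/restore cycle in plain PyTorch. It follows the general pattern and is not the library's code. The dictionary keys (`model_state_dict`, `optimizer_state_dict`, `global_step`) and the default file name are assumptions made for illustration. As with `restore_checkpoint` above, the restore helper returns the saved global step.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(model, directory, global_step, optimizer=None, name="ckpt.pth"):
    # Bundle model weights, optional optimizer state and the step counter (keys are assumed).
    os.makedirs(directory, exist_ok=True)
    ckpt = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        "global_step": global_step,
    }
    path = os.path.join(directory, name)
    torch.save(ckpt, path)
    return path


def restore_checkpoint(model, ckpt_file, optimizer=None):
    # Load weights (and optimizer state, if requested) and return the saved step.
    ckpt = torch.load(ckpt_file, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    if optimizer is not None and ckpt["optimizer_state_dict"] is not None:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["global_step"]


model = nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
with tempfile.TemporaryDirectory() as d:
    path = save_checkpoint(model, d, global_step=1000, optimizer=opt)
    restored = nn.Linear(4, 1)
    restored_opt = torch.optim.Adam(restored.parameters())
    step = restore_checkpoint(restored, path, restored_opt)
```

When the optimizer state is restored, its hyperparameters are restored too, so `restored_opt` ends up with the saved learning rate.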

torch_mimicry.nets.gan

Base Unconditional GAN

Implementation of Base GAN models.

class BaseDiscriminator(ndf, loss_type, **kwargs)[source]

Base class for a generic unconditional discriminator model.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
compute_gan_loss(output_real, output_fake)[source]

Computes GAN loss for discriminator.

Parameters:
  • output_real (Tensor) – A batch of output logits of shape (N, 1) from real images.
  • output_fake (Tensor) – A batch of output logits of shape (N, 1) from fake images.
Returns:

A batch of GAN losses for the discriminator.

Return type:

Tensor
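For reference, this pure-Python sketch writes out the discriminator losses that `loss_type` selects, computed on plain lists of logits. It assumes the standard definitions of the `ns`, `hinge` and `wasserstein` losses (these names come from the defaults used elsewhere in this document) and averages over the batch. It is not the library's implementation.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)


def mean(xs):
    return sum(xs) / len(xs)


def discriminator_loss(output_real, output_fake, loss_type):
    """Mean discriminator loss over a batch of real and fake logits."""
    if loss_type == "ns":
        # Binary cross-entropy with logits: real -> label 1, fake -> label 0.
        return mean([softplus(-r) for r in output_real]) + mean([softplus(f) for f in output_fake])
    if loss_type == "hinge":
        return mean([max(0.0, 1.0 - r) for r in output_real]) + mean([max(0.0, 1.0 + f) for f in output_fake])
    if loss_type == "wasserstein":
        # The critic maximises real scores and minimises fake scores.
        return mean(output_fake) - mean(output_real)
    raise ValueError(f"Unsupported loss_type: {loss_type}")
```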

compute_probs(output_real, output_fake)[source]

Computes probabilities from the logits of real and fake images.

Parameters:
  • output_real (Tensor) – A batch of output logits of shape (N, 1) from real images.
  • output_fake (Tensor) – A batch of output logits of shape (N, 1) from fake images.
Returns:

Average probabilities, over the batch, of real and fake images being classified as real.

Return type:

tuple
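The two averages can be sketched in pure Python. This assumes each probability is a sigmoid applied to the corresponding logit, with the results averaged over the batch. The sketch is an illustration, not the library's code.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def compute_probs(output_real, output_fake):
    """Average probability that real and fake images are classified as real."""
    D_x = sum(sigmoid(r) for r in output_real) / len(output_real)
    D_Gz = sum(sigmoid(f) for f in output_fake) / len(output_fake)
    return D_x, D_Gz


D_x, D_Gz = compute_probs([0.0, 0.0], [-100.0, 100.0])
```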

train_step(real_batch, netG, optD, log_data, device=None, global_step=None, **kwargs)[source]

Takes one training step for D.

Parameters:
  • real_batch (Tensor) – A batch of real images of shape (N, C, H, W).
  • netG (nn.Module) – Generator model for obtaining fake images.
  • optD (Optimizer) – Optimizer for updating discriminator’s parameters.
  • device (torch.device) – Device to use for running the model.
  • log_data (MetricLog) – An object to add custom metrics for visualisations.
  • global_step (int) – Variable to sync training, logging and checkpointing. Useful for making dynamic changes to the model during training.
Returns:

MetricLog object containing the updated logging variables after one training step.

Return type:

MetricLog

class BaseGenerator(nz, ngf, bottom_width, loss_type, **kwargs)[source]

Base class for a generic unconditional generator model.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
compute_gan_loss(output)[source]

Computes GAN loss for generator.

Parameters:output (Tensor) – A batch of output logits from the discriminator of shape (N, 1).
Returns:A batch of GAN losses for the generator.
Return type:Tensor
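The matching generator losses can be sketched the same way. This assumes the standard non-saturating, hinge and Wasserstein generator objectives, averaged over the batch. It is an illustration, not the library's implementation.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)


def generator_loss(output, loss_type):
    """Mean generator loss over a batch of discriminator logits on fake images."""
    if loss_type == "ns":
        # Non-saturating loss: -log(sigmoid(D(G(z)))) = softplus(-D(G(z))).
        return sum(softplus(-o) for o in output) / len(output)
    if loss_type in ("hinge", "wasserstein"):
        # Both reduce to maximising the discriminator's score on fake images.
        return -sum(output) / len(output)
    raise ValueError(f"Unsupported loss_type: {loss_type}")
```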
generate_images(num_images, device=None)[source]

Generates num_images random images.

Parameters:
  • num_images (int) – Number of images to generate.
  • device (torch.device) – Device to send images to.
Returns:

A batch of generated images.

Return type:

Tensor

train_step(real_batch, netD, optG, log_data, device=None, global_step=None, **kwargs)[source]

Takes one training step for G.

Parameters:
  • real_batch (Tensor) – A batch of real images of shape (N, C, H, W). Used for obtaining current batch size.
  • netD (nn.Module) – Discriminator model for obtaining losses.
  • optG (Optimizer) – Optimizer for updating generator’s parameters.
  • log_data (MetricLog) – An object to add custom metrics for visualisations.
  • device (torch.device) – Device to use for running the model.
  • global_step (int) – Variable to sync training, logging and checkpointing. Useful for making dynamic changes to the model during training.
Returns:

MetricLog object containing the updated logging variables after one training step.

Return type:

MetricLog
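The two `train_step` methods alternate in the usual GAN pattern. The sketch below shows one D step followed by one G step. It uses tiny hypothetical linear networks, the hinge loss, illustrative optimizer settings, and a plain dict in place of `MetricLog`. It is an illustration, not the library's code.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
nz, img_dim, batch_size = 8, 16, 4

# Hypothetical stand-ins for netG and netD operating on flattened "images".
netG = nn.Sequential(nn.Linear(nz, img_dim), nn.Tanh())
netD = nn.Linear(img_dim, 1)
optG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.0, 0.9))
optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.0, 0.9))
log_data = {}  # stand-in for a MetricLog object

real_batch = torch.rand(batch_size, img_dim) * 2 - 1

# D step: fake images are detached so no gradients flow into G.
optD.zero_grad()
fake_images = netG(torch.randn(batch_size, nz)).detach()
output_real, output_fake = netD(real_batch), netD(fake_images)
errD = torch.relu(1.0 - output_real).mean() + torch.relu(1.0 + output_fake).mean()
errD.backward()
optD.step()
log_data["errD"] = errD.item()

# G step: real_batch is only used for its batch size. Gradients also reach
# netD here, but only optG steps, and optD.zero_grad() clears them next time.
optG.zero_grad()
fake_images = netG(torch.randn(real_batch.shape[0], nz))
errG = -netD(fake_images).mean()
errG.backward()
optG.step()
log_data["errG"] = errG.item()
```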

Base Conditional GAN

Implementation of Base GAN models for a generic conditional GAN.

class BaseConditionalDiscriminator(num_classes, ndf, loss_type, **kwargs)[source]

Base class for a generic conditional discriminator model.

num_classes

Number of classes, more than 0 for conditional GANs.

Type:int
ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
train_step(real_batch, netG, optD, log_data, device=None, global_step=None, **kwargs)[source]

Takes one training step for D.

Parameters:
  • real_batch (Tensor) – A batch of real images of shape (N, C, H, W).
  • netG (nn.Module) – Generator model for obtaining fake images.
  • optD (Optimizer) – Optimizer for updating discriminator’s parameters.
  • device (torch.device) – Device to use for running the model.
  • log_data (MetricLog) – An object to add custom metrics for visualisations.
  • global_step (int) – Variable to sync training, logging and checkpointing. Useful for making dynamic changes to the model during training.
Returns:

MetricLog object containing the updated logging variables after one training step.

Return type:

MetricLog

class BaseConditionalGenerator(num_classes, nz, ngf, bottom_width, loss_type, **kwargs)[source]

Base class for a generic conditional generator model.

num_classes

Number of classes, more than 0 for conditional GANs.

Type:int
nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
generate_images(num_images, c=None, device=None)[source]

Generates images, optionally conditioned on a fixed class.

Parameters:
  • num_images (int) – The number of images to generate.
  • c (int) – The class of images to generate. If None, images of randomly sampled classes are generated.
  • device (torch.device) – The device to send the generated images to.
Returns:

A batch of generated images.

Return type:

Tensor
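This pure-Python sketch shows how the class argument `c` could determine the labels used for conditioning. The helper is hypothetical. It assumes that labels are drawn uniformly from `range(num_classes)` when `c` is None, and that a class outside that range is rejected.

```python
import random


def sample_labels(num_images, num_classes, c=None, rng=random):
    """Return class labels for a batch of generated images (hypothetical helper)."""
    if c is None:
        # Random classes: one uniform draw per image.
        return [rng.randrange(num_classes) for _ in range(num_images)]
    if not 0 <= c < num_classes:
        raise ValueError(f"c must be in [0, {num_classes}), got {c}")
    # Fixed class: every image is conditioned on the same label.
    return [c] * num_images


labels = sample_labels(8, num_classes=10, c=3)
```

The generator's forward(x, y) would then receive noise of shape (N, nz), with these labels passed as y.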

generate_images_with_labels(num_images, c=None, device=None)[source]

Generates images, optionally conditioned on a fixed class, and additionally returns their labels.

Parameters:
  • num_images (int) – The number of images to generate.
  • c (int) – The class of images to generate. If None, images of randomly sampled classes are generated.
  • device (torch.device) – The device to send the generated images to.
Returns:

Batch of generated images and their corresponding labels.

Return type:

tuple

train_step(real_batch, netD, optG, log_data, device=None, global_step=None, **kwargs)[source]

Takes one training step for G.

Parameters:
  • real_batch (Tensor) – A batch of real images of shape (N, C, H, W). Used for obtaining current batch size.
  • netD (nn.Module) – Discriminator model for obtaining losses.
  • optG (Optimizer) – Optimizer for updating generator’s parameters.
  • log_data (MetricLog) – An object to add custom metrics for visualisations.
  • device (torch.device) – Device to use for running the model.
  • global_step (int) – Variable to sync training, logging and checkpointing. Useful for making dynamic changes to the model during training.
Returns:

MetricLog object containing the updated logging variables after one training step.

Return type:

MetricLog

torch_mimicry.nets.dcgan

DCGAN CIFAR

Implementation of DCGAN based on Kurach et al., specifically for CIFAR-10. The main difference from dcgan_32 is that the generator uses sigmoid instead of tanh as its final activation.

To reproduce the scores, CIFAR-10 images should not be normalized to [-1, 1]. They should instead have values in [0, 1], which is the default when loading images as np arrays.
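The following pure-Python sketch shows how the two pixel ranges pair with the generator's final activation. Sigmoid outputs lie in (0, 1), which matches images scaled to [0, 1]. Tanh outputs lie in (-1, 1), which matches images normalized to [-1, 1]. The helper names are hypothetical.

```python
import math


def to_unit_range(pixel):
    # uint8 value in [0, 255] -> float in [0, 1], the range the sigmoid generator targets.
    return pixel / 255.0


def to_symmetric_range(x):
    # float in [0, 1] -> float in [-1, 1], the usual normalization for tanh outputs.
    return 2.0 * x - 1.0


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


# tanh is a rescaled sigmoid: tanh(x) = 2 * sigmoid(2x) - 1, so the two output
# ranges differ exactly by the [0, 1] -> [-1, 1] mapping above.
x = 0.7
assert abs(math.tanh(x) - to_symmetric_range(sigmoid(2 * x))) < 1e-12
```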

class DCGANDiscriminatorCIFAR(ndf=128, **kwargs)[source]

ResNet backbone discriminator for ResNet DCGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class DCGANGeneratorCIFAR(nz=128, ngf=256, bottom_width=4, **kwargs)[source]

ResNet backbone generator for ResNet DCGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

DCGAN 32

Implementation of DCGAN for image size 32.

class DCGANDiscriminator32(ndf=128, **kwargs)[source]

ResNet backbone discriminator for ResNet DCGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class DCGANGenerator32(nz=128, ngf=256, bottom_width=4, **kwargs)[source]

ResNet backbone generator for ResNet DCGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

DCGAN 48

Implementation of DCGAN for image size 48.

class DCGANDiscriminator48(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for ResNet DCGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class DCGANGenerator48(nz=128, ngf=512, bottom_width=6, **kwargs)[source]

ResNet backbone generator for ResNet DCGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor
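The `bottom_width` defaults in this module are 4 for the 32, 64 and 128 variants and 6 for the 48 variant. These line up with the target image sizes if each upsampling stage doubles the spatial width. That is an assumption about the architecture which this document does not state. Under it, 4 doubles three times to reach 32, and 6 doubles three times to reach 48. The hypothetical helper below checks the arithmetic.

```python
def num_doublings(bottom_width, image_size):
    """Number of 2x upsampling stages needed to go from bottom_width to image_size."""
    width, stages = bottom_width, 0
    while width < image_size:
        width *= 2
        stages += 1
    if width != image_size:
        raise ValueError(f"{image_size} is not {bottom_width} * 2**k")
    return stages


for bottom_width, size in [(4, 32), (6, 48), (4, 64), (4, 128)]:
    print(size, num_doublings(bottom_width, size))
```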

DCGAN 64

Implementation of DCGAN for image size 64.

class DCGANDiscriminator64(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for ResNet DCGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class DCGANGenerator64(nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for ResNet DCGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

DCGAN 128

Implementation of DCGAN for image size 128.

class DCGANDiscriminator128(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for ResNet DCGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class DCGANGenerator128(nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for ResNet DCGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

DCGAN Base

Base class definition of DCGAN.

class DCGANBaseDiscriminator(ndf, loss_type='ns', **kwargs)[source]

ResNet backbone discriminator for ResNet DCGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
class DCGANBaseGenerator(nz, ngf, bottom_width, loss_type='ns', **kwargs)[source]

ResNet backbone generator for ResNet DCGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str

torch_mimicry.nets.wgan_gp

WGAN-GP 32

Implementation of WGAN-GP for image size 32.

class WGANGPDiscriminator32(ndf=128, **kwargs)[source]

ResNet backbone discriminator for WGAN-GP.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
gp_scale

Lambda parameter for gradient penalty.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class WGANGPGenerator32(nz=128, ngf=256, bottom_width=4, **kwargs)[source]

ResNet backbone generator for WGAN-GP.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

WGAN-GP 48

Implementation of WGAN-GP for image size 48.

class WGANGPDiscriminator48(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for WGAN-GP.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
gp_scale

Lambda parameter for gradient penalty.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class WGANGPGenerator48(nz=128, ngf=512, bottom_width=6, **kwargs)[source]

ResNet backbone generator for WGAN-GP.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

WGAN-GP 64

Implementation of WGAN-GP for image size 64.

class WGANGPDiscriminator64(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for WGAN-GP.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
gp_scale

Lambda parameter for gradient penalty.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class WGANGPGenerator64(nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for WGAN-GP.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

WGAN-GP 128

Implementation of WGAN-GP for image size 128.

class WGANGPDiscriminator128(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for WGAN-GP.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
gp_scale

Lambda parameter for gradient penalty.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class WGANGPGenerator128(nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for WGAN-GP.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

WGAN-GP Base

Base class implementation of WGAN-GP.

class WGANGPBaseDiscriminator(ndf, loss_type='wasserstein', gp_scale=10.0, **kwargs)[source]

ResNet backbone discriminator for WGAN-GP.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
gp_scale

Lambda parameter for gradient penalty.

Type:float
compute_gradient_penalty_loss(real_images, fake_images, gp_scale=10.0)[source]

Computes the gradient penalty loss, based on: https://github.com/jalola/improved-wgan-pytorch/blob/master/gan_train.py

Parameters:
  • real_images (Tensor) – A batch of real images of shape (N, 3, H, W).
  • fake_images (Tensor) – A batch of fake images of shape (N, 3, H, W).
  • gp_scale (float) – Gradient penalty lambda parameter.
Returns:

Scalar gradient penalty loss.

Return type:

Tensor
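The gradient penalty evaluates the critic on random interpolates between real and fake images. It then penalizes how far the per-sample gradient norm deviates from 1. The PyTorch sketch below follows that standard WGAN-GP formulation. It is a minimal illustration, not the library's code. The function name and the linear test critic are hypothetical.

```python
import torch
import torch.nn as nn


def gradient_penalty(netD, real_images, fake_images, gp_scale=10.0):
    # One interpolation coefficient per sample, broadcast over (C, H, W).
    N = real_images.shape[0]
    alpha = torch.rand(N, 1, 1, 1, device=real_images.device)
    interpolates = (alpha * real_images + (1 - alpha) * fake_images).detach()
    interpolates.requires_grad_(True)

    output = netD(interpolates)
    # create_graph=True so the penalty itself can be backpropagated into netD.
    grads = torch.autograd.grad(
        outputs=output,
        inputs=interpolates,
        grad_outputs=torch.ones_like(output),
        create_graph=True,
    )[0]
    grad_norm = grads.reshape(N, -1).norm(2, dim=1)
    return gp_scale * ((grad_norm - 1) ** 2).mean()


# A linear critic D(x) = w . flatten(x) has gradient w everywhere, so with
# ||w|| = 2 the penalty is gp_scale * (2 - 1)**2 = 10 regardless of alpha.
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 1, bias=False))
with torch.no_grad():
    critic[1].weight.zero_()
    critic[1].weight[0, 0] = 2.0
real, fake = torch.randn(5, 3, 4, 4), torch.randn(5, 3, 4, 4)
gp = gradient_penalty(critic, real, fake)
```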

train_step(real_batch, netG, optD, log_data, device=None, global_step=None, **kwargs)[source]

Takes one training step for D.

Parameters:
  • real_batch (Tensor) – A batch of real images of shape (N, C, H, W).
  • netG (nn.Module) – Generator model for obtaining fake images.
  • optD (Optimizer) – Optimizer for updating discriminator’s parameters.
  • device (torch.device) – Device to use for running the model.
  • log_data (MetricLog) – An object to add custom metrics for visualisations.
  • global_step (int) – Variable to sync training, logging and checkpointing. Useful for making dynamic changes to the model during training.
Returns:

MetricLog object containing the updated logging variables after one training step.

Return type:

MetricLog

class WGANGPBaseGenerator(nz, ngf, bottom_width, loss_type='wasserstein', **kwargs)[source]

ResNet backbone generator for WGAN-GP.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
train_step(real_batch, netD, optG, log_data, device=None, global_step=None, **kwargs)[source]

Takes one training step for G.

Parameters:
  • real_batch (Tensor) – A batch of real images of shape (N, C, H, W). Used for obtaining current batch size.
  • netD (nn.Module) – Discriminator model for obtaining losses.
  • optG (Optimizer) – Optimizer for updating generator’s parameters.
  • log_data (MetricLog) – An object to add custom metrics for visualisations.
  • device (torch.device) – Device to use for running the model.
  • global_step (int) – Variable to sync training, logging and checkpointing. Useful for making dynamic changes to the model during training.
Returns:

MetricLog object containing the updated logging variables after one training step.

Return type:

MetricLog

torch_mimicry.nets.sngan

SNGAN 32

Implementation of SNGAN for image size 32.

class SNGANDiscriminator32(ndf=128, **kwargs)[source]

ResNet backbone discriminator for SNGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class SNGANGenerator32(nz=128, ngf=256, bottom_width=4, **kwargs)[source]

ResNet backbone generator for SNGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

SNGAN 48

Implementation of SNGAN for image size 48.

class SNGANDiscriminator48(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for SNGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class SNGANGenerator48(nz=128, ngf=512, bottom_width=6, **kwargs)[source]

ResNet backbone generator for SNGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

SNGAN 64

Implementation of SNGAN for image size 64.

class SNGANDiscriminator64(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for SNGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class SNGANGenerator64(nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for SNGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

SNGAN 128

Implementation of SNGAN for image size 128.

class SNGANDiscriminator128(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for SNGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1).
Return type:Tensor
class SNGANGenerator128(nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for SNGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

SNGAN Base

Base implementation of SNGAN with default variables.

class SNGANBaseDiscriminator(ndf, loss_type='hinge', **kwargs)[source]

ResNet backbone discriminator for SNGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
class SNGANBaseGenerator(nz, ngf, bottom_width, loss_type='hinge', **kwargs)[source]

ResNet backbone generator for SNGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str

torch_mimicry.nets.cgan_pd

CGAN-PD 32

Implementation of cGAN-PD for image size 32.

class CGANPDDiscriminator32(num_classes, ndf=128, **kwargs)[source]

ResNet backbone discriminator for cGAN-PD.

num_classes

Number of classes, more than 0 for conditional GANs.

Type:int
ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x, y=None)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits. The labels are additionally projected onto the image features to condition the output logit score.

Parameters:
  • x (Tensor) – A batch of images of shape (N, C, H, W).
  • y (Tensor) – A batch of labels of shape (N,).
Returns:

A batch of GAN logits of shape (N, 1).

Return type:

Tensor
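In a projection discriminator, the label enters through an inner product between a class embedding and the image features. That term is added to an unconditional score. The pure-Python sketch below assumes the formulation output = psi(phi(x)) + <embed(y), phi(x)> and uses hypothetical toy weights. It illustrates the idea and does not reproduce the library's layers.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def projection_logit(features, label, psi_weight, class_embeddings):
    """Unconditional linear score plus the projection of the label embedding onto the features."""
    unconditional = dot(psi_weight, features)
    projection = dot(class_embeddings[label], features)
    return unconditional + projection


features = [1.0, 2.0]            # phi(x): pooled features of one image
psi_weight = [0.5, -0.5]         # weight of the final linear layer (hypothetical)
class_embeddings = [[1.0, 0.0],  # embedding for class 0
                    [0.0, 1.0]]  # embedding for class 1
logit_class0 = projection_logit(features, 0, psi_weight, class_embeddings)  # -0.5 + 1.0 = 0.5
```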

class CGANPDGenerator32(num_classes, bottom_width=4, nz=128, ngf=256, **kwargs)[source]

ResNet backbone generator for cGAN-PD.

num_classes

Number of classes, more than 0 for conditional GANs.

Type:int
nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x, y=None)[source]

Feedforwards a batch of noise vectors into a batch of fake images, conditioning batch normalization on the labels of the images to be produced.

Parameters:
  • x (Tensor) – A batch of noise vectors of shape (N, nz).
  • y (Tensor) – A batch of labels of shape (N,) for conditional batch norm.
Returns:

A batch of fake images of shape (N, C, H, W).

Return type:

Tensor
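Conditional batch normalization normalizes features just as ordinary batch norm does. The difference is that it takes the scale and shift from per-class parameters indexed by each sample's label. Below is a pure-Python sketch over a single feature channel. It uses hypothetical per-class gamma and beta values and omits running statistics.

```python
import math


def conditional_batch_norm(values, labels, gamma, beta, eps=1e-5):
    """Normalize one channel over the batch, then apply the scale/shift of each sample's class."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # biased variance, as in batch norm
    normalized = [(v - mean) / math.sqrt(var + eps) for v in values]
    return [gamma[y] * h + beta[y] for h, y in zip(normalized, labels)]


out = conditional_batch_norm([1.0, 3.0], labels=[0, 1], gamma=[1.0, 2.0], beta=[0.0, 10.0])
```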

CGAN-PD 128

Implementation of cGAN-PD for image size 128.

class CGANPDDiscriminator128(num_classes, ndf=128, **kwargs)[source]

ResNet backbone discriminator for cGAN-PD.

num_classes

Number of classes, more than 0 for conditional GANs.

Type:int
ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x, y=None)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits. The labels are additionally projected onto the image features to condition the output logit score.

Parameters:
  • x (Tensor) – A batch of images of shape (N, C, H, W).
  • y (Tensor) – A batch of labels of shape (N,).
Returns:

A batch of GAN logits of shape (N, 1).

Return type:

Tensor

class CGANPDGenerator128(num_classes, nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for cGAN-PD.

num_classes

Number of classes, more than 0 for conditional GANs.

Type:int
nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
forward(x, y=None)[source]

Feedforwards a batch of noise vectors into a batch of fake images, conditioning batch normalization on the labels of the images to be produced.

Parameters:
  • x (Tensor) – A batch of noise vectors of shape (N, nz).
  • y (Tensor) – A batch of labels of shape (N,) for conditional batch norm.
Returns:

A batch of fake images of shape (N, C, H, W).

Return type:

Tensor

CGAN-PD Base

Base class definition of cGAN-PD.

class CGANPDBaseDiscriminator(num_classes, ndf, loss_type='hinge', **kwargs)[source]

ResNet backbone discriminator for cGAN-PD.

num_classes

Number of classes, more than 0 for conditional GANs.

Type:int
ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
class CGANPDBaseGenerator(num_classes, bottom_width, nz, ngf, loss_type='hinge', **kwargs)[source]

ResNet backbone generator for cGAN-PD.

num_classes

Number of classes, more than 0 for conditional GANs.

Type:int
nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str

torch_mimicry.nets.ssgan

SSGAN 32

Implementation of SSGAN for image size 32.

class SSGANDiscriminator32(ndf=128, **kwargs)[source]

ResNet backbone discriminator for SSGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
ss_loss_scale

Self-supervised loss scale for discriminator.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits and rotation class predictions.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1), and a batch of predicted rotation classes of shape (N, num_classes).
Return type:tuple of (Tensor, Tensor)
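The rotation task behind SSGAN's self-supervised loss works as follows. Each image is rotated by 0, 90, 180 or 270 degrees, and the discriminator is asked to predict which rotation was applied. Its second output scores that prediction. The pure-Python sketch below builds such a batch from a single-channel 2x2 "image" and uses hypothetical helper names. With tensors, `torch.rot90` performs the same kind of rotation.

```python
def rotate90(image):
    """Rotate a 2D list (H x W) by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*image)][::-1]


def make_rotation_batch(image):
    """Return the four rotated copies of an image and their rotation labels 0..3."""
    rotated, labels = [], []
    current = image
    for label in range(4):
        rotated.append(current)
        labels.append(label)
        current = rotate90(current)
    return rotated, labels


images, labels = make_rotation_batch([[1, 2], [3, 4]])
```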
class SSGANGenerator32(nz=128, ngf=256, bottom_width=4, **kwargs)[source]

ResNet backbone generator for SSGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
ss_loss_scale

Self-supervised loss scale for generator.

Type:float
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor
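`bottom_width` is the spatial width of the first feature map that the generator builds from the noise vector. The generator then upsamples this map to the final image size. The sketch below shows how many upsampling stages each default on this page implies. It assumes that every generator block exactly doubles the width, which is an assumption about the ResNet backbone and is not stated in these docs. `num_upsampling_blocks` is a hypothetical helper.

```python
def num_upsampling_blocks(bottom_width, image_size):
    """Count the 2x upsampling stages needed to grow bottom_width to image_size.

    Assumes (hypothetically) that every generator block exactly doubles the width.
    """
    ratio, remainder = divmod(image_size, bottom_width)
    # The ratio must be a positive power of two for pure doubling to work.
    if remainder or ratio < 1 or ratio & (ratio - 1):
        raise ValueError("image_size must be bottom_width times a power of two")
    return ratio.bit_length() - 1


# Default (bottom_width, image size) pairs of the SSGAN generators on this page.
for bottom_width, size in [(4, 32), (6, 48), (4, 64), (4, 128)]:
    print(size, num_upsampling_blocks(bottom_width, size))
# 32 3
# 48 3
# 64 4
# 128 5
```

This arithmetic explains why the 48-pixel models default to `bottom_width=6`. Starting from 4 would require a non-integer number of doublings.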

SSGAN 48

Implementation of SSGAN for image size 48.

class SSGANDiscriminator48(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for SSGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
ss_loss_scale

Self-supervised loss scale for discriminator.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits and a batch of rotation class predictions.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1), and a batch of predicted rotation classes of shape (N, num_classes).
Return type:(Tensor, Tensor)
class SSGANGenerator48(nz=128, ngf=512, bottom_width=6, **kwargs)[source]

ResNet backbone generator for SSGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
ss_loss_scale

Self-supervised loss scale for generator.

Type:float
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

SSGAN 64

Implementation of SSGAN for image size 64.

class SSGANDiscriminator64(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for SSGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
ss_loss_scale

Self-supervised loss scale for discriminator.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits and a batch of rotation class predictions.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1), and a batch of predicted rotation classes of shape (N, num_classes).
Return type:(Tensor, Tensor)
class SSGANGenerator64(nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for SSGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
ss_loss_scale

Self-supervised loss scale for generator.

Type:float
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

SSGAN 128

Implementation of SSGAN for image size 128.

class SSGANDiscriminator128(ndf=1024, **kwargs)[source]

ResNet backbone discriminator for SSGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
ss_loss_scale

Self-supervised loss scale for discriminator.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits and a batch of rotation class predictions.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1), and a batch of predicted rotation classes of shape (N, num_classes).
Return type:(Tensor, Tensor)
class SSGANGenerator128(nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for SSGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
ss_loss_scale

Self-supervised loss scale for generator.

Type:float
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

SSGAN Base

Implementation of Base SSGAN models.

class SSGANBaseDiscriminator(ndf, loss_type='hinge', ss_loss_scale=1.0, **kwargs)[source]

ResNet backbone discriminator for SSGAN.

ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
ss_loss_scale

Self-supervised loss scale for discriminator.

Type:float
compute_ss_loss(images, scale)[source]

Computes the self-supervised (SS) rotation-prediction loss.

Parameters:
  • images (Tensor) – A batch of non-rotated, upright images.
  • scale (float) – The parameter to scale SS loss by.
Returns:

Scalar tensor representing the SS loss.

Return type:

Tensor
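`compute_ss_loss` rotates the upright images and then scores the discriminator's rotation predictions. The sketch below covers only the scoring half. It assumes the loss is the mean softmax cross-entropy over the four rotation classes, multiplied by `scale`. `rotation_logits` stands in for the discriminator's class outputs, and `cross_entropy` and `ss_loss` are hypothetical helpers, not torch_mimicry functions.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for one sample, using log-sum-exp for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def ss_loss(rotation_logits, rotation_labels, scale):
    """Mean rotation-prediction cross-entropy, multiplied by scale."""
    total = sum(cross_entropy(z, y) for z, y in zip(rotation_logits, rotation_labels))
    return scale * total / len(rotation_labels)


# An uninformed classifier (all-zero logits over 4 rotations) scores log(4) per sample.
print(ss_loss([[0.0] * 4, [0.0] * 4], [0, 3], scale=1.0))  # ~1.386
```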

train_step(real_batch, netG, optD, log_data, device=None, global_step=None, **kwargs)[source]

Takes one training step for D.

Parameters:
  • real_batch (Tensor) – A batch of real images of shape (N, C, H, W).
  • netG (nn.Module) – Generator model for obtaining fake images.
  • optD (Optimizer) – Optimizer for updating discriminator’s parameters.
  • log_data (MetricLog) – An object to add custom metrics for visualisations.
  • device (torch.device) – Device to use for running the model.
  • global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to model amidst training.
Returns:

Returns MetricLog object containing updated logging variables after 1 training step.

Return type:

MetricLog
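With the default `loss_type='hinge'`, the adversarial part of the discriminator step is the standard hinge loss. SSGAN then adds the self-supervised term from `compute_ss_loss`, weighted by `ss_loss_scale`. The sketch below assumes, as in the SSGAN paper, that the discriminator's rotation loss is computed on real images only. The function names are hypothetical.

```python
def hinge_loss_d(output_real, output_fake):
    """Discriminator hinge loss: mean(relu(1 - D(x))) + mean(relu(1 + D(G(z))))."""
    real_term = sum(max(0.0, 1.0 - r) for r in output_real) / len(output_real)
    fake_term = sum(max(0.0, 1.0 + f) for f in output_fake) / len(output_fake)
    return real_term + fake_term


def ssgan_discriminator_loss(output_real, output_fake, ss_loss_real):
    """Adversarial hinge loss plus the (already scaled) SS loss on real images."""
    return hinge_loss_d(output_real, output_fake) + ss_loss_real


# Logits beyond the margin (real > 1, fake < -1) contribute nothing.
print(hinge_loss_d([2.0, 0.5], [-2.0, 0.0]))  # 0.75
```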

class SSGANBaseGenerator(nz, ngf, bottom_width, loss_type='hinge', ss_loss_scale=0.2, **kwargs)[source]

ResNet backbone generator for SSGAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
ss_loss_scale

Self-supervised loss scale for generator.

Type:float
train_step(real_batch, netD, optG, log_data, device=None, global_step=None, **kwargs)[source]

Takes one training step for G.

Parameters:
  • real_batch (Tensor) – A batch of real images of shape (N, C, H, W). Used for obtaining current batch size.
  • netD (nn.Module) – Discriminator model for obtaining losses.
  • optG (Optimizer) – Optimizer for updating generator’s parameters.
  • log_data (MetricLog) – An object to add custom metrics for visualisations.
  • device (torch.device) – Device to use for running the model.
  • global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to model amidst training.
Returns:

Returns MetricLog object containing updated logging variables after 1 training step.

Return type:

MetricLog
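For the generator step, the SSGAN objective adds two terms. The first is the hinge generator loss on fake images. The second is the rotation loss evaluated on those same fake images, weighted by `ss_loss_scale`, which defaults to 0.2 here versus 1.0 for the discriminator. The sketch below assumes the two terms are simply summed. `raw_ss_loss_fake` stands for the unscaled rotation loss, and all names are hypothetical.

```python
def hinge_loss_g(output_fake):
    """Generator hinge loss: -mean(D(G(z)))."""
    return -sum(output_fake) / len(output_fake)


def ssgan_generator_loss(output_fake, raw_ss_loss_fake, ss_loss_scale=0.2):
    """Hinge generator loss plus the rotation loss on fake images, weighted by ss_loss_scale."""
    return hinge_loss_g(output_fake) + ss_loss_scale * raw_ss_loss_fake


print(ssgan_generator_loss([1.0, 3.0], raw_ss_loss_fake=1.0))  # approximately -1.8
```

Under this loss, the generator is rewarded when the discriminator finds its images both real-looking and easy to classify by rotation.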

torch_mimicry.nets.infomax_gan

InfoMax-GAN 32

Implementation of InfoMax-GAN for image size 32.

class InfoMaxGANDiscriminator32(nrkhs=1024, ndf=128, **kwargs)[source]

ResNet backbone discriminator for InfoMax-GAN.

nrkhs

The RKHS dimension R to project the local and global features to.

Type:int
ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
infomax_loss_scale

The beta parameter used for scaling the discriminator infomax loss.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits, local features of the images, and global features of the images.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1), a batch of local features of shape (N, ndf, H>>2, W>>2), and a batch of global features of shape (N, ndf).
Return type:(Tensor, Tensor, Tensor)
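The local-feature shape uses `H>>2`, which is a right shift by two bits, i.e. integer division by 4. The small sketch below simply evaluates the three documented output shapes for a given input size. It follows the docstring above rather than inspecting the real model, and `infomax_d_output_shapes` is a hypothetical helper.

```python
def infomax_d_output_shapes(n, ndf, height, width):
    """Output shapes of the InfoMax-GAN discriminator forward, per the docstring above.

    height >> 2 equals height // 4 for non-negative integers.
    """
    logits = (n, 1)
    local_feat = (n, ndf, height >> 2, width >> 2)
    global_feat = (n, ndf)
    return logits, local_feat, global_feat


print(infomax_d_output_shapes(n=64, ndf=128, height=32, width=32))
# ((64, 1), (64, 128, 8, 8), (64, 128))
```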
class InfoMaxGANGenerator32(nz=128, ngf=256, bottom_width=4, **kwargs)[source]

ResNet backbone generator for InfoMax-GAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
infomax_loss_scale

The alpha parameter used for scaling the generator infomax loss.

Type:float
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

InfoMax-GAN 48

Implementation of InfoMax-GAN for image size 48.

class InfoMaxGANDiscriminator48(ndf=1024, nrkhs=1024, **kwargs)[source]

ResNet backbone discriminator for InfoMax-GAN.

nrkhs

The RKHS dimension R to project the local and global features to.

Type:int
ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
infomax_loss_scale

The beta parameter used for scaling the discriminator infomax loss.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits, local features of the images, and global features of the images.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1), a batch of local features of shape (N, ndf, H>>2, W>>2), and a batch of global features of shape (N, ndf).
Return type:(Tensor, Tensor, Tensor)
class InfoMaxGANGenerator48(nz=128, ngf=512, bottom_width=6, **kwargs)[source]

ResNet backbone generator for InfoMax-GAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
infomax_loss_scale

The alpha parameter used for scaling the generator infomax loss.

Type:float
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

InfoMax-GAN 64

Implementation of InfoMax-GAN for image size 64.

class InfoMaxGANDiscriminator64(nrkhs=1024, ndf=1024, **kwargs)[source]

ResNet backbone discriminator for InfoMax-GAN.

nrkhs

The RKHS dimension R to project the local and global features to.

Type:int
ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
infomax_loss_scale

The beta parameter used for scaling the discriminator infomax loss.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits, local features of the images, and global features of the images.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1), a batch of local features of shape (N, ndf, H>>2, W>>2), and a batch of global features of shape (N, ndf).
Return type:(Tensor, Tensor, Tensor)
class InfoMaxGANGenerator64(nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for InfoMax-GAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
infomax_loss_scale

The alpha parameter used for scaling the generator infomax loss.

Type:float
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

InfoMax-GAN 128

Implementation of InfoMax-GAN for image size 128.

class InfoMaxGANDiscriminator128(nrkhs=1024, ndf=1024, **kwargs)[source]

ResNet backbone discriminator for InfoMax-GAN.

nrkhs

The RKHS dimension R to project the local and global features to.

Type:int
ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
infomax_loss_scale

The beta parameter used for scaling the discriminator infomax loss.

Type:float
forward(x)[source]

Feedforwards a batch of real/fake images and produces a batch of GAN logits, local features of the images, and global features of the images.

Parameters:x (Tensor) – A batch of images of shape (N, C, H, W).
Returns:A batch of GAN logits of shape (N, 1), a batch of local features of shape (N, ndf, H>>2, W>>2), and a batch of global features of shape (N, ndf).
Return type:(Tensor, Tensor, Tensor)
class InfoMaxGANGenerator128(nz=128, ngf=1024, bottom_width=4, **kwargs)[source]

ResNet backbone generator for InfoMax-GAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
infomax_loss_scale

The alpha parameter used for scaling the generator infomax loss.

Type:float
forward(x)[source]

Feedforwards a batch of noise vectors into a batch of fake images.

Parameters:x (Tensor) – A batch of noise vectors of shape (N, nz).
Returns:A batch of fake images of shape (N, C, H, W).
Return type:Tensor

InfoMax-GAN Base

Implementation of InfoMax-GAN base model.

class BaseDiscriminator(nrkhs, ndf, loss_type='hinge', infomax_loss_scale=0.2, **kwargs)[source]

ResNet backbone discriminator for InfoMax-GAN.

nrkhs

The RKHS dimension R to project the local and global features to.

Type:int
ndf

Variable controlling discriminator feature map sizes.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
infomax_loss_scale

The beta parameter used for scaling the discriminator infomax loss.

Type:float
compute_infomax_loss(local_feat, global_feat, scale)[source]

Given the local and global features of a real or fake image, computes the average dot-product score between each local feature and the global feature, which is then used to obtain the InfoNCE loss.

Parameters:
  • local_feat (Tensor) – A batch of local features.
  • global_feat (Tensor) – A batch of global features.
  • scale (float) – The scaling hyperparameter for the infomax loss.
Returns:Scalar Tensor representing the scaled infomax loss.
Return type:Tensor
infonce_loss(l, m)[source]

InfoNCE loss for local and global feature maps as used in DIM: https://github.com/rdevon/DIM/blob/master/cortex_DIM/functions/dim_losses.py

Parameters:
  • l (Tensor) – Local feature map of shape (N, ndf, H*W).
  • m (Tensor) – Global feature vector of shape (N, ndf, 1).
Returns:

Scalar loss Tensor.

Return type:

Tensor
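InfoNCE scores each matching (local, global) pair against mismatched pairs drawn from other images in the batch. The sketch below is simplified and unbatched, written in pure Python. It follows that positive/negative structure, but it is not a transcription of the linked DIM code, which works on batched tensors and handles same-image pairs by masking. `infonce_sketch` and `dot` are hypothetical names. Per its docstring, `compute_infomax_loss` returns a scaled version of a loss of this kind.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def infonce_sketch(local_feats, global_feats):
    """Simplified InfoNCE between local and global features.

    local_feats:  N samples, each a list of L local vectors of length D.
    global_feats: N global vectors of length D.
    For sample n and location i the positive score is l[n][i] . m[n]; the
    negatives are m[n] . l[k][j] for every other sample k != n and location j.
    Returns the mean of -log softmax(positive) over all (n, i).
    """
    n_samples = len(local_feats)
    losses = []
    for n in range(n_samples):
        negatives = [dot(global_feats[n], l_kj)
                     for k in range(n_samples) if k != n
                     for l_kj in local_feats[k]]
        for l_ni in local_feats[n]:
            scores = [dot(l_ni, global_feats[n])] + negatives
            m = max(scores)
            log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
            losses.append(log_sum_exp - scores[0])
    return sum(losses) / len(losses)


local_feats = [[[1.0, 0.0]], [[0.0, 1.0]]]
print(infonce_sketch(local_feats, [[1.0, 0.0], [0.0, 1.0]]))  # ~0.313 (matched)
print(infonce_sketch(local_feats, [[0.0, 1.0], [1.0, 0.0]]))  # ~1.313 (swapped)
```

The loss falls when each global vector agrees with its own image's local features more than with other images' features.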

project_features(local_feat, global_feat)[source]

Projects local and global features to the RKHS dimension nrkhs.

train_step(real_batch, netG, optD, log_data, device=None, global_step=None, **kwargs)[source]

Takes one training step for D.

Parameters:
  • real_batch (Tensor) – A batch of real images of shape (N, C, H, W).
  • netG (nn.Module) – Generator model for obtaining fake images.
  • optD (Optimizer) – Optimizer for updating discriminator’s parameters.
  • log_data (MetricLog) – An object to add custom metrics for visualisations.
  • device (torch.device) – Device to use for running the model.
  • global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to model amidst training.
Returns:

Returns MetricLog object containing updated logging variables after 1 training step.

Return type:

MetricLog

class InfoMaxGANBaseGenerator(nz, ngf, bottom_width, loss_type='hinge', infomax_loss_scale=0.2, **kwargs)[source]

ResNet backbone generator for InfoMax-GAN.

nz

Noise dimension for upsampling.

Type:int
ngf

Variable controlling generator feature map sizes.

Type:int
bottom_width

Starting width for upsampling generator output to an image.

Type:int
loss_type

Name of loss to use for GAN loss.

Type:str
infomax_loss_scale

The alpha parameter used for scaling the generator infomax loss.

Type:float
train_step(real_batch, netD, optG, log_data, device=None, global_step=None, **kwargs)[source]

Takes one training step for G.

Parameters:
  • real_batch (Tensor) – A batch of real images of shape (N, C, H, W). Used for obtaining current batch size.
  • netD (nn.Module) – Discriminator model for obtaining losses.
  • optG (Optimizer) – Optimizer for updating generator’s parameters.
  • log_data (MetricLog) – An object to add custom metrics for visualisations.
  • device (torch.device) – Device to use for running the model.
  • global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to model amidst training.
Returns:

Returns MetricLog object containing updated logging variables after 1 training step.

Return type:

MetricLog