torch_mimicry.nets
torch_mimicry.nets.basemodel
Implementation of BaseModel.
class BaseModel
BaseModel with basic functionality for checkpointing and restoration.
count_params()
Computes the number of parameters in this model.
Returns: The total number of weight parameters and the number of trainable parameters for this model.
Return type: tuple(int, int)
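The two counts can be illustrated with a minimal plain-PyTorch sketch. This is not the library's implementation. The helper reuses the name `count_params` only for readability, and the toy model is hypothetical.

```python
import torch.nn as nn


def count_params(model: nn.Module):
    """Return (total, trainable) parameter counts of a module."""
    num_total_params = sum(p.numel() for p in model.parameters())
    num_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    return num_total_params, num_trainable_params


# Hypothetical toy model: Linear(4, 3) has 12 weights + 3 biases and
# Linear(3, 1) has 3 weights + 1 bias, for 19 parameters in total.
model = nn.Sequential(nn.Linear(4, 3), nn.Linear(3, 1))
model[1].weight.requires_grad_(False)  # freeze 3 weights

total, trainable = count_params(model)
print(total, trainable)  # 19 16
```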
restore_checkpoint(ckpt_file, optimizer=None)
Restores model weights from a pth checkpoint file. If an optimizer is given, its state is restored as well.
Parameters: - ckpt_file (str) – Path to a PyTorch pth file containing model weights.
- optimizer (Optimizer) – A freshly constructed optimizer whose state is restored from the checkpoint.
Returns: Global step at which the model was last checkpointed.
Return type: int
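A minimal sketch of the save/restore round trip follows, written in plain PyTorch. The checkpoint keys (`model_state_dict`, `optimizer_state_dict`, `global_step`) and the file name are assumptions for illustration and may not match what torch_mimicry writes to disk.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, model, optimizer, global_step):
    # Hypothetical checkpoint layout used only for this sketch.
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "global_step": global_step,
    }, path)


def restore_checkpoint(path, model, optimizer=None):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state_dict"])
    if optimizer is not None:
        # Restore optimizer state (e.g. Adam moments) only when one is given.
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["global_step"]


model = nn.Linear(2, 1)
opt = optim.Adam(model.parameters(), lr=2e-4, betas=(0.0, 0.9))
path = os.path.join(tempfile.mkdtemp(), "netD_1000_steps.pth")
save_checkpoint(path, model, opt, global_step=1000)

# A freshly constructed model and optimizer receive the saved state.
restored = nn.Linear(2, 1)
restored_opt = optim.Adam(restored.parameters(), lr=2e-4, betas=(0.0, 0.9))
step = restore_checkpoint(path, restored, restored_opt)
```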
torch_mimicry.nets.gan
Base Unconditional GAN
Implementation of base GAN models.
class BaseDiscriminator(ndf, loss_type, **kwargs)
Base class for a generic unconditional discriminator model.
compute_gan_loss(output_real, output_fake)
Computes the GAN loss for the discriminator.
Parameters: - output_real (Tensor) – A batch of output logits of shape (N, 1) from real images.
- output_fake (Tensor) – A batch of output logits of shape (N, 1) from fake images.
Returns: The GAN loss for the discriminator, errD.
Return type: Tensor
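The loss depends on `loss_type`. As a hedged illustration, the sketch below implements three standard discriminator losses on (N, 1) logits in plain PyTorch. The `dis_loss` helper and its `loss_type` strings are hypothetical, and the library's own names and reductions may differ.

```python
import torch
import torch.nn.functional as F


def dis_loss(output_real, output_fake, loss_type="hinge"):
    """Scalar discriminator loss from logits of shape (N, 1).

    The loss_type strings are illustrative assumptions.
    """
    if loss_type == "minimax":
        # Binary cross-entropy: real logits -> label 1, fake logits -> label 0.
        real = F.binary_cross_entropy_with_logits(
            output_real, torch.ones_like(output_real))
        fake = F.binary_cross_entropy_with_logits(
            output_fake, torch.zeros_like(output_fake))
        return real + fake
    if loss_type == "hinge":
        return F.relu(1.0 - output_real).mean() + F.relu(1.0 + output_fake).mean()
    if loss_type == "wasserstein":
        # The critic maximises real scores and minimises fake scores.
        return output_fake.mean() - output_real.mean()
    raise ValueError(f"Unknown loss_type: {loss_type}")


output_real = torch.tensor([[2.0], [0.5]])
output_fake = torch.tensor([[-2.0], [0.0]])
print(dis_loss(output_real, output_fake, "hinge").item())  # 0.75
```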
compute_probs(output_real, output_fake)
Computes probabilities from the logits of real and fake images.
Parameters: - output_real (Tensor) – A batch of output logits of shape (N, 1) from real images.
- output_fake (Tensor) – A batch of output logits of shape (N, 1) from fake images.
Returns: The average probabilities, over the batch, that real and fake images are considered real.
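Assuming the probabilities come from applying a sigmoid to the logits and averaging over the batch, the computation reduces to a few lines. The names `D_x` and `D_Gz` are hypothetical.

```python
import torch


def compute_probs(output_real, output_fake):
    """Average probability (over the batch) that real and fake images are judged real."""
    D_x = torch.sigmoid(output_real).mean().item()
    D_Gz = torch.sigmoid(output_fake).mean().item()
    return D_x, D_Gz


# A logit of 0 maps to probability 0.5; very negative logits map to ~0.
D_x, D_Gz = compute_probs(torch.zeros(4, 1), torch.full((4, 1), -20.0))
```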
train_step(real_batch, netG, optD, log_data, device=None, global_step=None, **kwargs)
Takes one training step for D.
Parameters: - real_batch (Tensor) – A batch of real images of shape (N, C, H, W).
- netG (nn.Module) – Generator model for obtaining fake images.
- optD (Optimizer) – Optimizer for updating the discriminator’s parameters.
- log_data (MetricLog) – An object to add custom metrics for visualisations.
- device (torch.device) – Device to use for running the model.
- global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to the model during training.
Returns: MetricLog object containing updated logging variables after one training step.
Return type: MetricLog
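The sketch below makes the sequence of operations concrete by performing one discriminator step in plain PyTorch. The toy networks, the hinge loss, and the plain dict standing in for MetricLog are all assumptions for illustration, not torch_mimicry's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)
nz = 8
# Hypothetical toy networks operating on flat 16-dimensional "images".
netG = nn.Sequential(nn.Linear(nz, 16), nn.Tanh())
netD = nn.Linear(16, 1)
optD = optim.Adam(netD.parameters(), lr=2e-4, betas=(0.0, 0.9))


def d_train_step(real_batch, log_data):
    optD.zero_grad()
    # Generate fakes without tracking gradients: this step only updates D.
    with torch.no_grad():
        fake_images = netG(torch.randn(real_batch.shape[0], nz))
    output_real = netD(real_batch)
    output_fake = netD(fake_images)
    # Hinge loss, chosen here for illustration.
    errD = F.relu(1.0 - output_real).mean() + F.relu(1.0 + output_fake).mean()
    errD.backward()
    optD.step()
    # A plain dict stands in for MetricLog in this sketch.
    log_data["errD"] = errD.item()
    return log_data


log = d_train_step(torch.randn(4, 16), {})
```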
class BaseGenerator(nz, ngf, bottom_width, loss_type, **kwargs)
Base class for a generic unconditional generator model.
compute_gan_loss(output)
Computes the GAN loss for the generator.
Parameters: output (Tensor) – A batch of output logits from the discriminator of shape (N, 1).
Returns: The GAN loss for the generator.
Return type: Tensor
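A hedged sketch of two common generator losses follows. The first is the non-saturating loss. The second is the hinge/Wasserstein generator loss, which is the negated mean critic score. The helper and its `loss_type` strings are hypothetical.

```python
import torch
import torch.nn.functional as F


def gen_loss(output, loss_type="hinge"):
    """Scalar generator loss from discriminator logits of shape (N, 1).

    The loss_type strings are illustrative assumptions.
    """
    if loss_type == "ns":
        # Non-saturating loss: push D(G(z)) towards the "real" label.
        return F.binary_cross_entropy_with_logits(output, torch.ones_like(output))
    if loss_type in ("hinge", "wasserstein"):
        # Both hinge and Wasserstein generators maximise the critic score.
        return -output.mean()
    raise ValueError(f"Unknown loss_type: {loss_type}")


print(gen_loss(torch.tensor([[1.0], [3.0]])).item())  # -2.0
```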
generate_images(num_images, device=None)
Generates num_images images from randomly sampled noise.
Parameters: - num_images (int) – Number of images to generate.
- device (torch.device) – Device to send images to.
Returns: A batch of generated images.
Return type: Tensor
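Here is a minimal sketch of random generation, assuming the noise is drawn from a standard normal of dimension `nz`. The toy generator and the use of `torch.no_grad()` are choices made for this example only.

```python
import torch
import torch.nn as nn


def generate_images(netG, num_images, nz, device=None):
    """Sample num_images noise vectors from N(0, I) and map them through netG."""
    noise = torch.randn((num_images, nz), device=device)
    # no_grad is an inference-time choice made for this sketch.
    with torch.no_grad():
        return netG(noise)


# Hypothetical toy generator producing 3x4x4 "images" from nz=8 noise.
nz = 8
netG = nn.Sequential(nn.Linear(nz, 3 * 4 * 4), nn.Tanh(), nn.Unflatten(1, (3, 4, 4)))
images = generate_images(netG, num_images=5, nz=nz)
print(images.shape)  # torch.Size([5, 3, 4, 4])
```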
train_step(real_batch, netD, optG, log_data, device=None, global_step=None, **kwargs)
Takes one training step for G.
Parameters: - real_batch (Tensor) – A batch of real images of shape (N, C, H, W). Used for obtaining the current batch size.
- netD (nn.Module) – Discriminator model for obtaining losses.
- optG (Optimizer) – Optimizer for updating the generator’s parameters.
- log_data (MetricLog) – An object to add custom metrics for visualisations.
- device (torch.device) – Device to use for running the model.
- global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to the model during training.
Returns: MetricLog object containing updated logging variables after one training step.
Return type: MetricLog
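A matching generator step is sketched below, again in plain PyTorch with hypothetical toy networks and a plain dict instead of MetricLog. `real_batch` only supplies the batch size. The generator loss is computed through the discriminator, and only G's optimizer steps.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
nz = 8
# Hypothetical toy networks on flat 16-dimensional "images".
netG = nn.Sequential(nn.Linear(nz, 16), nn.Tanh())
netD = nn.Linear(16, 1)
optG = optim.Adam(netG.parameters(), lr=2e-4, betas=(0.0, 0.9))


def g_train_step(real_batch, log_data):
    optG.zero_grad()
    # real_batch is used only to read off the current batch size.
    noise = torch.randn(real_batch.shape[0], nz)
    errG = -netD(netG(noise)).mean()  # hinge/Wasserstein-style generator loss
    errG.backward()
    optG.step()  # only G's parameters are updated
    log_data["errG"] = errG.item()
    return log_data


w_before = netG[0].weight.detach().clone()
d_before = netD.weight.detach().clone()
log = g_train_step(torch.randn(4, 16), {})
```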
Base Conditional GAN
Implementation of base GAN models for a generic conditional GAN.
class BaseConditionalDiscriminator(num_classes, ndf, loss_type, **kwargs)
Base class for a generic conditional discriminator model.
train_step(real_batch, netG, optD, log_data, device=None, global_step=None, **kwargs)
Takes one training step for D.
Parameters: - real_batch (Tensor) – A batch of real images of shape (N, C, H, W).
- netG (nn.Module) – Generator model for obtaining fake images.
- optD (Optimizer) – Optimizer for updating the discriminator’s parameters.
- log_data (MetricLog) – An object to add custom metrics for visualisations.
- device (torch.device) – Device to use for running the model.
- global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to the model during training.
Returns: MetricLog object containing updated logging variables after one training step.
Return type: MetricLog
class BaseConditionalGenerator(num_classes, nz, ngf, bottom_width, loss_type, **kwargs)
Base class for a generic conditional generator model.
generate_images(num_images, c=None, device=None)
Generates images, optionally conditioned on a fixed class.
Parameters: - num_images (int) – Number of images to generate.
- c (int) – Class to condition on. If None, classes are sampled randomly.
- device (torch.device) – Device to send images to.
Returns: A batch of generated images.
Return type: Tensor
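generate_images_with_labels(num_images, c=None, device=None)
Generates images, optionally conditioned on a fixed class, and additionally returns their labels.
Parameters: - num_images (int) – Number of images to generate.
- c (int) – Class to condition on. If None, classes are sampled randomly.
- device (torch.device) – Device to send images to.
Returns: A batch of generated images and their corresponding labels.
Return type: tuple(Tensor, Tensor)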
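The conditioning behaviour can be sketched as follows. The sketch assumes `c` is an integer class index and that, when `c` is None, one label per image is drawn uniformly at random. `ToyCondGenerator` is hypothetical. A `generate_images`-style function would return only `images`.

```python
import torch
import torch.nn as nn


class ToyCondGenerator(nn.Module):
    """Hypothetical conditional generator: adds a class embedding to the noise."""

    def __init__(self, nz, num_classes):
        super().__init__()
        self.embed = nn.Embedding(num_classes, nz)
        self.fc = nn.Linear(nz, 3 * 4 * 4)

    def forward(self, x, y):
        return self.fc(x + self.embed(y)).view(-1, 3, 4, 4)


def generate_images_with_labels(netG, num_images, nz, num_classes, c=None, device=None):
    if c is not None and not 0 <= c < num_classes:
        raise ValueError(f"c must be in [0, {num_classes}), got {c}")
    noise = torch.randn((num_images, nz), device=device)
    if c is None:
        # No fixed class: sample one label per image uniformly at random.
        labels = torch.randint(0, num_classes, (num_images,), device=device)
    else:
        # Fixed class: every image is conditioned on class c.
        labels = torch.full((num_images,), c, dtype=torch.long, device=device)
    with torch.no_grad():
        images = netG(noise, labels)
    return images, labels


netG = ToyCondGenerator(nz=8, num_classes=10)
images, labels = generate_images_with_labels(netG, 6, nz=8, num_classes=10, c=2)
```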
train_step(real_batch, netD, optG, log_data, device=None, global_step=None, **kwargs)
Takes one training step for G.
Parameters: - real_batch (Tensor) – A batch of real images of shape (N, C, H, W). Used for obtaining the current batch size.
- netD (nn.Module) – Discriminator model for obtaining losses.
- optG (Optimizer) – Optimizer for updating the generator’s parameters.
- log_data (MetricLog) – An object to add custom metrics for visualisations.
- device (torch.device) – Device to use for running the model.
- global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to the model during training.
Returns: MetricLog object containing updated logging variables after one training step.
Return type: MetricLog
torch_mimicry.nets.dcgan
DCGAN CIFAR
Implementation of DCGAN based on Kurach et al., specifically for CIFAR-10. The main difference from dcgan_32 is that the generator uses sigmoid as its final activation instead of tanh.
To reproduce scores, CIFAR-10 images should not be normalized to [-1, 1]. They should instead take values in [0, 1], which is the default when loading images as np arrays.
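A small numeric illustration of the two value ranges follows, using hypothetical random tensors in plain PyTorch. Data normalized to [-1, 1] pairs naturally with a tanh output, while a sigmoid output stays in (0, 1) like unnormalized data.

```python
import torch

# Hypothetical batch of CIFAR-10-sized images with values in [0, 1].
x01 = torch.rand(2, 3, 32, 32)

# The common [-1, 1] normalisation, usually paired with a tanh generator output.
x_pm1 = x01 * 2.0 - 1.0

# A sigmoid output layer produces values in (0, 1), matching unnormalised data.
fake01 = torch.sigmoid(torch.randn(2, 3, 32, 32))

# Mapping [-1, 1] data back to [0, 1].
x_back = (x_pm1 + 1.0) / 2.0
```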
class DCGANDiscriminatorCIFAR(ndf=128, **kwargs)
ResNet backbone discriminator for ResNet DCGAN.
DCGAN 32
Implementation of DCGAN for image size 32.
class DCGANDiscriminator32(ndf=128, **kwargs)
ResNet backbone discriminator for ResNet DCGAN.
DCGAN 48
Implementation of DCGAN for image size 48.
class DCGANDiscriminator48(ndf=1024, **kwargs)
ResNet backbone discriminator for ResNet DCGAN.
DCGAN 64
Implementation of DCGAN for image size 64.
class DCGANDiscriminator64(ndf=1024, **kwargs)
ResNet backbone discriminator for ResNet DCGAN.
DCGAN 128
Implementation of DCGAN for image size 128.
class DCGANDiscriminator128(ndf=1024, **kwargs)
ResNet backbone discriminator for ResNet DCGAN.
torch_mimicry.nets.wgan_gp
WGAN-GP 32
Implementation of WGAN-GP for image size 32.
WGAN-GP 48
Implementation of WGAN-GP for image size 48.
WGAN-GP 64
Implementation of WGAN-GP for image size 64.
WGAN-GP 128
Implementation of WGAN-GP for image size 128.
class WGANGPDiscriminator128(ndf=1024, **kwargs)
ResNet backbone discriminator for WGAN-GP.
WGAN-GP Base
Base class implementation of WGAN-GP.
class WGANGPBaseDiscriminator(ndf, loss_type='wasserstein', gp_scale=10.0, **kwargs)
ResNet backbone discriminator for WGAN-GP.
compute_gradient_penalty_loss(real_images, fake_images, gp_scale=10.0)
Computes the gradient penalty loss, based on: https://github.com/jalola/improved-wgan-pytorch/blob/master/gan_train.py
Parameters: - real_images (Tensor) – A batch of real images of shape (N, 3, H, W).
- fake_images (Tensor) – A batch of fake images of shape (N, 3, H, W).
- gp_scale (float) – Gradient penalty lambda parameter.
Returns: Scalar gradient penalty loss.
Return type: Tensor
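The gradient penalty can be sketched with the standard WGAN-GP recipe. Interpolate between real and fake images, take the critic's gradient with respect to the interpolates, and penalize the squared deviation of its L2 norm from 1, scaled by `gp_scale`. This is a minimal plain-PyTorch sketch, not a copy of the linked code, and the toy critic is hypothetical.

```python
import torch
import torch.autograd as autograd
import torch.nn as nn


def gradient_penalty(netD, real_images, fake_images, gp_scale=10.0):
    N = real_images.shape[0]
    # One interpolation coefficient per sample, broadcast over (C, H, W).
    alpha = torch.rand(N, 1, 1, 1, device=real_images.device)
    interpolates = alpha * real_images + (1 - alpha) * fake_images
    interpolates = interpolates.detach().requires_grad_(True)
    disc_interpolates = netD(interpolates)
    # Gradient of the critic output with respect to its input.
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )[0]
    gradients = gradients.reshape(N, -1)
    return gp_scale * ((gradients.norm(2, dim=1) - 1) ** 2).mean()


# Hypothetical linear critic whose weight has unit L2 norm, so every
# input gradient has norm exactly 1 and the penalty is 0.
netD = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 1, bias=False))
with torch.no_grad():
    netD[1].weight.zero_()
    netD[1].weight[0, 0] = 1.0

real = torch.randn(4, 3, 4, 4)
fake = torch.randn(4, 3, 4, 4)
gp = gradient_penalty(netD, real, fake)
print(gp.item())  # 0.0
```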
train_step(real_batch, netG, optD, log_data, device=None, global_step=None, **kwargs)
Takes one training step for D.
Parameters: - real_batch (Tensor) – A batch of real images of shape (N, C, H, W).
- netG (nn.Module) – Generator model for obtaining fake images.
- optD (Optimizer) – Optimizer for updating the discriminator’s parameters.
- log_data (MetricLog) – An object to add custom metrics for visualisations.
- device (torch.device) – Device to use for running the model.
- global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to the model during training.
Returns: MetricLog object containing updated logging variables after one training step.
Return type: MetricLog
class WGANGPBaseGenerator(nz, ngf, bottom_width, loss_type='wasserstein', **kwargs)
ResNet backbone generator for WGAN-GP.
train_step(real_batch, netD, optG, log_data, device=None, global_step=None, **kwargs)
Takes one training step for G.
Parameters: - real_batch (Tensor) – A batch of real images of shape (N, C, H, W). Used for obtaining the current batch size.
- netD (nn.Module) – Discriminator model for obtaining losses.
- optG (Optimizer) – Optimizer for updating the generator’s parameters.
- log_data (MetricLog) – An object to add custom metrics for visualisations.
- device (torch.device) – Device to use for running the model.
- global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to the model during training.
Returns: MetricLog object containing updated logging variables after one training step.
Return type: MetricLog
torch_mimicry.nets.sngan
SNGAN 32
Implementation of SNGAN for image size 32.
SNGAN 48
Implementation of SNGAN for image size 48.
class SNGANDiscriminator48(ndf=1024, **kwargs)
ResNet backbone discriminator for SNGAN.
Attributes: - ndf (int) – Variable controlling discriminator feature map sizes.
- loss_type (str) – Name of loss to use for GAN loss.
SNGAN 64
Implementation of SNGAN for image size 64.
SNGAN 128
Implementation of SNGAN for image size 128.
SNGAN Base
Base implementation of SNGAN with default variables.
class SNGANBaseDiscriminator(ndf, loss_type='hinge', **kwargs)
ResNet backbone discriminator for SNGAN.
torch_mimicry.nets.cgan_pd
CGAN-PD 32
Implementation of cGAN-PD for image size 32.
class CGANPDDiscriminator32(num_classes, ndf=128, **kwargs)
ResNet backbone discriminator for cGAN-PD.
forward(x, y=None)
Feedforwards a batch of real/fake images and produces a batch of GAN logits. The labels are additionally projected to condition the output logit score.
Parameters: - x (Tensor) – A batch of images of shape (N, C, H, W).
- y (Tensor) – A batch of labels of shape (N,).
Returns: A batch of GAN logits of shape (N, 1).
Return type: Tensor
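In a projection discriminator (Miyato and Koyama), the label enters through an inner product between a learned class embedding and the pooled image features. This term is added to an unconditional logit. The sketch below shows only that output head in plain PyTorch. It assumes the ResNet backbone has already produced pooled features `h` of shape (N, ndf). `ProjectionHead`, `l_out` and `l_y` are hypothetical names.

```python
import torch
import torch.nn as nn


class ProjectionHead(nn.Module):
    """Hypothetical projection head on pooled backbone features h of shape (N, ndf)."""

    def __init__(self, ndf, num_classes):
        super().__init__()
        self.l_out = nn.Linear(ndf, 1)             # unconditional logit
        self.l_y = nn.Embedding(num_classes, ndf)  # class embedding

    def forward(self, h, y):
        output = self.l_out(h)  # (N, 1)
        # Project the labels: add the inner product <embed(y), h> per sample.
        output = output + torch.sum(self.l_y(y) * h, dim=1, keepdim=True)
        return output


head = ProjectionHead(ndf=16, num_classes=10)
h = torch.randn(4, 16)
y = torch.tensor([0, 3, 3, 9])
logits = head(h, y)
```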
class CGANPDGenerator32(num_classes, bottom_width=4, nz=128, ngf=256, **kwargs)
ResNet backbone generator for cGAN-PD.
forward(x, y=None)
Feedforwards a batch of noise vectors into a batch of fake images, conditioning the batch norm layers on the labels of the images to be produced.
Parameters: - x (Tensor) – A batch of noise vectors of shape (N, nz).
- y (Tensor) – A batch of labels of shape (N,) for conditional batch norm.
Returns: A batch of fake images of shape (N, C, H, W).
Return type: Tensor
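Conditional batch norm is commonly implemented as batch normalization without its own affine parameters, followed by a per-class scale (gamma) and shift (beta) looked up from an embedding. The following is a hedged plain-PyTorch sketch of that idea, not the library's layer. The class name and the initialisation to gamma = 1, beta = 0 are assumptions.

```python
import torch
import torch.nn as nn


class ConditionalBatchNorm2d(nn.Module):
    """Hypothetical conditional BN: per-class gamma/beta from an embedding."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        # Normalisation without its own affine parameters...
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        # ...because per-class gamma and beta come from this embedding.
        self.embed = nn.Embedding(num_classes, num_features * 2)
        with torch.no_grad():
            self.embed.weight[:, :num_features].fill_(1.0)  # gamma starts at 1
            self.embed.weight[:, num_features:].zero_()     # beta starts at 0

    def forward(self, x, y):
        out = self.bn(x)
        gamma, beta = self.embed(y).chunk(2, dim=1)  # each (N, num_features)
        return gamma[:, :, None, None] * out + beta[:, :, None, None]


torch.manual_seed(0)
cbn = ConditionalBatchNorm2d(num_features=4, num_classes=10)
x = torch.randn(8, 4, 5, 5)
y = torch.randint(0, 10, (8,))
out = cbn(x, y)  # at initialisation this equals plain batch norm of x
```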
CGAN-PD 128
Implementation of cGAN-PD for image size 128.
class CGANPDDiscriminator128(num_classes, ndf=128, **kwargs)
ResNet backbone discriminator for cGAN-PD.
forward(x, y=None)
Feedforwards a batch of real/fake images and produces a batch of GAN logits. The labels are additionally projected to condition the output logit score.
Parameters: - x (Tensor) – A batch of images of shape (N, C, H, W).
- y (Tensor) – A batch of labels of shape (N,).
Returns: A batch of GAN logits of shape (N, 1).
Return type: Tensor
class CGANPDGenerator128(num_classes, nz=128, ngf=1024, bottom_width=4, **kwargs)
ResNet backbone generator for cGAN-PD.
forward(x, y=None)
Feedforwards a batch of noise vectors into a batch of fake images, conditioning the batch norm layers on the labels of the images to be produced.
Parameters: - x (Tensor) – A batch of noise vectors of shape (N, nz).
- y (Tensor) – A batch of labels of shape (N,) for conditional batch norm.
Returns: A batch of fake images of shape (N, C, H, W).
Return type: Tensor
CGAN-PD Base
Base class definition of cGAN-PD.
class CGANPDBaseDiscriminator(num_classes, ndf, loss_type='hinge', **kwargs)
ResNet backbone discriminator for cGAN-PD.
torch_mimicry.nets.ssgan
SSGAN 32
Implementation of SSGAN for image size 32.
class SSGANDiscriminator32(ndf=128, **kwargs)
ResNet backbone discriminator for SSGAN.
forward(x)
Feedforwards a batch of real/fake images and produces a batch of GAN logits and a batch of rotation class predictions.
Parameters: x (Tensor) – A batch of images of shape (N, C, H, W).
Returns: A batch of GAN logits of shape (N, 1), and a batch of predicted rotation classes of shape (N, num_classes).
Return type: tuple(Tensor, Tensor)
SSGAN 48
Implementation of SSGAN for image size 48.
class SSGANDiscriminator48(ndf=1024, **kwargs)
ResNet backbone discriminator for SSGAN.
forward(x)
Feedforwards a batch of real/fake images and produces a batch of GAN logits and a batch of rotation class predictions.
Parameters: x (Tensor) – A batch of images of shape (N, C, H, W).
Returns: A batch of GAN logits of shape (N, 1), and a batch of predicted rotation classes of shape (N, num_classes).
Return type: tuple(Tensor, Tensor)
SSGAN 64
Implementation of SSGAN for image size 64.
class SSGANDiscriminator64(ndf=1024, **kwargs)
ResNet backbone discriminator for SSGAN.
forward(x)
Feedforwards a batch of real/fake images and produces a batch of GAN logits and a batch of rotation class predictions.
Parameters: x (Tensor) – A batch of images of shape (N, C, H, W).
Returns: A batch of GAN logits of shape (N, 1), and a batch of predicted rotation classes of shape (N, num_classes).
Return type: tuple(Tensor, Tensor)
SSGAN 128
Implementation of SSGAN for image size 128.
class SSGANDiscriminator128(ndf=1024, **kwargs)
ResNet backbone discriminator for SSGAN.
forward(x)
Feedforwards a batch of real/fake images and produces a batch of GAN logits and a batch of rotation class predictions.
Parameters: x (Tensor) – A batch of images of shape (N, C, H, W).
Returns: A batch of GAN logits of shape (N, 1), and a batch of predicted rotation classes of shape (N, num_classes).
Return type: tuple(Tensor, Tensor)
SSGAN Base
Implementation of base SSGAN models.
class SSGANBaseDiscriminator(ndf, loss_type='hinge', ss_loss_scale=1.0, **kwargs)
ResNet backbone discriminator for SSGAN.
compute_ss_loss(images, scale)
Computes the self-supervised (SS) rotation loss.
Parameters: - images (Tensor) – A batch of non-rotated, upright images.
- scale (float) – The parameter to scale the SS loss by.
Returns: Scalar tensor representing the SS loss.
Return type: Tensor
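The SS loss can be sketched as a cross-entropy over the rotation logits, scaled by `scale`. This assumes, as in the SSGAN paper, that the self-supervised task is predicting which of four rotations (0, 90, 180 or 270 degrees) was applied. The sketch also assumes square images and a discriminator that returns `(gan_logits, rotation_logits)`, consistent with the `forward()` entries in this section. `rotate_batch` and the toy discriminator are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def rotate_batch(images):
    """Stack 0/90/180/270-degree rotations of square images (N, C, H, H) with labels 0..3."""
    rotated = torch.cat([torch.rot90(images, k, dims=(2, 3)) for k in range(4)], dim=0)
    labels = torch.arange(4, device=images.device).repeat_interleave(images.shape[0])
    return rotated, labels


class ToySSDiscriminator(nn.Module):
    """Hypothetical discriminator returning (gan_logits, rotation_logits)."""

    def __init__(self, in_features):
        super().__init__()
        self.flatten = nn.Flatten()
        self.gan_head = nn.Linear(in_features, 1)
        self.rot_head = nn.Linear(in_features, 4)

    def forward(self, x):
        h = self.flatten(x)
        return self.gan_head(h), self.rot_head(h)


def compute_ss_loss(netD, images, scale):
    rotated, labels = rotate_batch(images)
    _, rotation_logits = netD(rotated)
    # 4-way classification of the applied rotation, scaled by `scale`.
    return scale * F.cross_entropy(rotation_logits, labels)


torch.manual_seed(0)
images = torch.randn(2, 3, 4, 4)
rotated, labels = rotate_batch(images)
loss = compute_ss_loss(ToySSDiscriminator(3 * 4 * 4), images, scale=1.0)
```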
train_step(real_batch, netG, optD, log_data, device=None, global_step=None, **kwargs)
Takes one training step for D.
Parameters: - real_batch (Tensor) – A batch of real images of shape (N, C, H, W).
- netG (nn.Module) – Generator model for obtaining fake images.
- optD (Optimizer) – Optimizer for updating the discriminator’s parameters.
- log_data (MetricLog) – An object to add custom metrics for visualisations.
- device (torch.device) – Device to use for running the model.
- global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to the model during training.
Returns: MetricLog object containing updated logging variables after one training step.
Return type: MetricLog
class SSGANBaseGenerator(nz, ngf, bottom_width, loss_type='hinge', ss_loss_scale=0.2, **kwargs)
ResNet backbone generator for SSGAN.
train_step(real_batch, netD, optG, log_data, device=None, global_step=None, **kwargs)
Takes one training step for G.
Parameters: - real_batch (Tensor) – A batch of real images of shape (N, C, H, W). Used for obtaining the current batch size.
- netD (nn.Module) – Discriminator model for obtaining losses.
- optG (Optimizer) – Optimizer for updating the generator’s parameters.
- log_data (MetricLog) – An object to add custom metrics for visualisations.
- device (torch.device) – Device to use for running the model.
- global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to the model during training.
Returns: MetricLog object containing updated logging variables after one training step.
Return type: MetricLog
torch_mimicry.nets.infomax_gan
InfoMax-GAN 32
Implementation of InfoMax-GAN for image size 32.
class InfoMaxGANDiscriminator32(nrkhs=1024, ndf=128, **kwargs)
ResNet backbone discriminator for InfoMax-GAN.
forward(x)
Feedforwards a batch of real/fake images and produces a batch of GAN logits, local features of the images, and global features of the images.
Parameters: x (Tensor) – A batch of images of shape (N, C, H, W).
Returns: A batch of GAN logits of shape (N, 1), a batch of local features of shape (N, ndf, H>>2, W>>2), and a batch of global features of shape (N, ndf).
Return type: tuple(Tensor, Tensor, Tensor)
InfoMax-GAN 48
Implementation of InfoMax-GAN for image size 48.
class InfoMaxGANDiscriminator48(ndf=1024, nrkhs=1024, **kwargs)
ResNet backbone discriminator for InfoMax-GAN.
forward(x)
Feedforwards a batch of real/fake images and produces a batch of GAN logits, local features of the images, and global features of the images.
Parameters: x (Tensor) – A batch of images of shape (N, C, H, W).
Returns: A batch of GAN logits of shape (N, 1), a batch of local features of shape (N, ndf, H>>2, W>>2), and a batch of global features of shape (N, ndf).
Return type: tuple(Tensor, Tensor, Tensor)
InfoMax-GAN 64
Implementation of InfoMax-GAN for image size 64.
class InfoMaxGANDiscriminator64(nrkhs=1024, ndf=1024, **kwargs)
ResNet backbone discriminator for InfoMax-GAN.
forward(x)
Feedforwards a batch of real/fake images and produces a batch of GAN logits, local features of the images, and global features of the images.
Parameters: x (Tensor) – A batch of images of shape (N, C, H, W).
Returns: A batch of GAN logits of shape (N, 1), a batch of local features of shape (N, ndf, H>>2, W>>2), and a batch of global features of shape (N, ndf).
Return type: tuple(Tensor, Tensor, Tensor)
InfoMax-GAN 128
Implementation of InfoMax-GAN for image size 128.
class InfoMaxGANDiscriminator128(nrkhs=1024, ndf=1024, **kwargs)
ResNet backbone discriminator for InfoMax-GAN.
forward(x)
Feedforwards a batch of real/fake images and produces a batch of GAN logits, local features of the images, and global features of the images.
Parameters: x (Tensor) – A batch of images of shape (N, C, H, W).
Returns: A batch of GAN logits of shape (N, 1), a batch of local features of shape (N, ndf, H>>2, W>>2), and a batch of global features of shape (N, ndf).
Return type: tuple(Tensor, Tensor, Tensor)
InfoMax-GAN Base
Implementation of the InfoMax-GAN base model.
class BaseDiscriminator(nrkhs, ndf, loss_type='hinge', infomax_loss_scale=0.2, **kwargs)
ResNet backbone discriminator for InfoMax-GAN.
compute_infomax_loss(local_feat, global_feat, scale)
Given the local and global features of a real or fake image, computes the average dot-product score between each local feature and the global feature. This score is then used to obtain the InfoNCE loss.
Parameters: - local_feat (Tensor) – A batch of local features.
- global_feat (Tensor) – A batch of global features.
- scale (float) – The scaling hyperparameter for the InfoMax loss.
Returns: Scalar Tensor representing the scaled InfoMax loss.
Return type: Tensor
infonce_loss(l, m)
InfoNCE loss for local and global feature maps, as used in DIM: https://github.com/rdevon/DIM/blob/master/cortex_DIM/functions/dim_losses.py
Parameters: - l (Tensor) – Local feature map of shape (N, ndf, H*W).
- m (Tensor) – Global feature vector of shape (N, ndf, 1).
Returns: Scalar loss Tensor.
Return type: Tensor
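For intuition, here is a simplified InfoNCE in plain PyTorch that uses the shapes listed for `l` and `m`. Each local feature of image i should score highest against image i's own global vector, with the other images' global vectors acting as negatives. This is not a line-for-line port of the DIM function linked above, whose negative sampling differs in detail, and the function body is an assumption. Scaling the result by `scale` gives a loss in the spirit of `compute_infomax_loss`.

```python
import torch
import torch.nn.functional as F


def infonce_loss(l, m):
    """Simplified InfoNCE between local features l (N, ndf, H*W) and global features m (N, ndf, 1).

    For every local feature of image i, the positive is the global vector of
    image i and the negatives are the global vectors of the other images.
    """
    N, _, n_locals = l.shape
    g = m.squeeze(2)  # (N, ndf)
    # scores[i, j, k] = <l[i, :, j], g[k]>
    scores = torch.einsum("idj,kd->ijk", l, g)  # (N, n_locals, N)
    # The correct "class" for every local feature of image i is i itself.
    targets = torch.arange(N, device=l.device).unsqueeze(1).expand(N, n_locals)
    return F.cross_entropy(scores.reshape(N * n_locals, N), targets.reshape(-1))


N, ndf, n_locals = 3, 3, 4
# Perfectly aligned features: image i's local and global features both point along e_i.
m_aligned = 10.0 * torch.eye(N).unsqueeze(2)                              # (3, 3, 1)
l_aligned = (10.0 * torch.eye(N)).unsqueeze(2).expand(N, ndf, n_locals)  # (3, 3, 4)
aligned_loss = infonce_loss(l_aligned, m_aligned)

torch.manual_seed(0)
random_loss = infonce_loss(torch.randn(N, ndf, n_locals), torch.randn(N, ndf, 1))
```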
train_step(real_batch, netG, optD, log_data, device=None, global_step=None, **kwargs)
Takes one training step for D.
Parameters: - real_batch (Tensor) – A batch of real images of shape (N, C, H, W).
- netG (nn.Module) – Generator model for obtaining fake images.
- optD (Optimizer) – Optimizer for updating the discriminator’s parameters.
- log_data (MetricLog) – An object to add custom metrics for visualisations.
- device (torch.device) – Device to use for running the model.
- global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to the model during training.
Returns: MetricLog object containing updated logging variables after one training step.
Return type: MetricLog
class InfoMaxGANBaseGenerator(nz, ngf, bottom_width, loss_type='hinge', infomax_loss_scale=0.2, **kwargs)
ResNet backbone generator for InfoMax-GAN.
train_step(real_batch, netD, optG, log_data, device=None, global_step=None, **kwargs)
Takes one training step for G.
Parameters: - real_batch (Tensor) – A batch of real images of shape (N, C, H, W). Used for obtaining the current batch size.
- netD (nn.Module) – Discriminator model for obtaining losses.
- optG (Optimizer) – Optimizer for updating the generator’s parameters.
- log_data (MetricLog) – An object to add custom metrics for visualisations.
- device (torch.device) – Device to use for running the model.
- global_step (int) – Variable to sync training, logging and checkpointing. Useful for dynamic changes to the model during training.
Returns: MetricLog object containing updated logging variables after one training step.
Return type: MetricLog