Source code for torch_mimicry.nets.wgan_gp.wgan_gp_base

"""
Base class implementation of WGAN-GP.
"""
import torch
from torch import autograd

from torch_mimicry.nets.gan import gan


class WGANGPBaseGenerator(gan.BaseGenerator):
    r"""
    ResNet backbone generator for WGAN-GP.

    Attributes:
        nz (int): Noise dimension for upsampling.
        ngf (int): Variable controlling generator feature map sizes.
        bottom_width (int): Starting width for upsampling generator output to an image.
        loss_type (str): Name of loss to use for GAN loss.
    """
    def __init__(self, nz, ngf, bottom_width, loss_type='wasserstein', **kwargs):
        super().__init__(nz=nz,
                         ngf=ngf,
                         bottom_width=bottom_width,
                         loss_type=loss_type,
                         **kwargs)
    def train_step(self,
                   real_batch,
                   netD,
                   optG,
                   log_data,
                   device=None,
                   global_step=None,
                   **kwargs):
        r"""
        Takes one training step for G.

        Args:
            real_batch (Tensor): A batch of real images of shape (N, C, H, W).
                Used for obtaining current batch size.
            netD (nn.Module): Discriminator model for obtaining losses.
            optG (Optimizer): Optimizer for updating generator's parameters.
            log_data (MetricLog): An object to add custom metrics for visualisations.
            device (torch.device): Device to use for running the model.
            global_step (int): Variable to sync training, logging and checkpointing.
                Useful for dynamic changes to model amidst training.

        Returns:
            MetricLog: Returns MetricLog object containing updated logging variables
                after 1 training step.
        """
        self.zero_grad()

        # Get only the batch size from the real batch.
        batch_size = real_batch[0].shape[0]

        # Produce fake images.
        fake_images = self.generate_images(num_images=batch_size,
                                           device=device)

        # Compute output logits of D, treating the fake images as real.
        output = netD(fake_images)

        # Compute loss.
        errG = self.compute_gan_loss(output)

        # Backprop and update gradients.
        errG.backward()
        optG.step()

        # Log statistics.
        log_data.add_metric('errG', errG, group='loss')

        return log_data
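
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): a minimal concrete generator
# built on WGANGPBaseGenerator. The library's real models use an upsampling
# ResNet backbone in place of the single linear layer below; the name
# `ToyWGANGPGenerator` and the fixed 32x32 output size are assumptions made
# purely for illustration. The sketch assumes the base class's
# generate_images() (used in train_step above) samples noise of shape (N, nz)
# and passes it to forward().
# ---------------------------------------------------------------------------
from torch import nn


class ToyWGANGPGenerator(WGANGPBaseGenerator):
    def __init__(self, nz=128, ngf=256, bottom_width=4, **kwargs):
        super().__init__(nz=nz,
                         ngf=ngf,
                         bottom_width=bottom_width,
                         **kwargs)

        # Toy stand-in for the ResNet backbone: map noise straight to pixels.
        self.main = nn.Linear(nz, 3 * 32 * 32)

    def forward(self, x):
        # x: noise of shape (N, nz) -> images of shape (N, 3, 32, 32) in [-1, 1].
        return torch.tanh(self.main(x)).view(-1, 3, 32, 32)
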
class WGANGPBaseDiscriminator(gan.BaseDiscriminator):
    r"""
    ResNet backbone discriminator for WGAN-GP.

    Attributes:
        ndf (int): Variable controlling discriminator feature map sizes.
        loss_type (str): Name of loss to use for GAN loss.
        gp_scale (float): Lambda parameter for gradient penalty.
    """
    def __init__(self, ndf, loss_type='wasserstein', gp_scale=10.0, **kwargs):
        super().__init__(ndf=ndf, loss_type=loss_type, **kwargs)
        self.gp_scale = gp_scale
    def train_step(self,
                   real_batch,
                   netG,
                   optD,
                   log_data,
                   device=None,
                   global_step=None,
                   **kwargs):
        r"""
        Takes one training step for D.

        Args:
            real_batch (Tensor): A batch of real images of shape (N, C, H, W).
            netG (nn.Module): Generator model for obtaining fake images.
            optD (Optimizer): Optimizer for updating discriminator's parameters.
            device (torch.device): Device to use for running the model.
            log_data (MetricLog): An object to add custom metrics for visualisations.
            global_step (int): Variable to sync training, logging and checkpointing.
                Useful for dynamic changes to model amidst training.

        Returns:
            MetricLog: Returns MetricLog object containing updated logging variables
                after 1 training step.
        """
        self.zero_grad()

        # Unpack real images from the batch.
        real_images, _ = real_batch
        batch_size = real_images.shape[0]  # Match batch sizes for last iter.

        # Produce fake images.
        fake_images = netG.generate_images(num_images=batch_size,
                                           device=device).detach()

        # Produce logits for real and fake images.
        output_real = self.forward(real_images)
        output_fake = self.forward(fake_images)

        # Compute losses.
        errD = self.compute_gan_loss(output_real=output_real,
                                     output_fake=output_fake)
        errD_GP = self.compute_gradient_penalty_loss(real_images=real_images,
                                                     fake_images=fake_images,
                                                     gp_scale=self.gp_scale)

        # Backprop and update gradients.
        errD_total = errD + errD_GP
        errD_total.backward()
        optD.step()

        # Compute probabilities.
        D_x, D_Gz = self.compute_probs(output_real=output_real,
                                       output_fake=output_fake)

        # Log statistics.
        log_data.add_metric('errD', errD, group='loss')
        log_data.add_metric('D(x)', D_x, group='prob')
        log_data.add_metric('D(G(z))', D_Gz, group='prob')

        return log_data
    def compute_gradient_penalty_loss(self,
                                      real_images,
                                      fake_images,
                                      gp_scale=10.0):
        r"""
        Computes the gradient penalty loss, based on:
        https://github.com/jalola/improved-wgan-pytorch/blob/master/gan_train.py

        Args:
            real_images (Tensor): A batch of real images of shape (N, 3, H, W).
            fake_images (Tensor): A batch of fake images of shape (N, 3, H, W).
            gp_scale (float): Gradient penalty lambda parameter.

        Returns:
            Tensor: Scalar gradient penalty loss.
        """
        # Obtain parameters.
        N, _, H, W = real_images.shape
        device = real_images.device

        # Randomly sample some alpha between 0 and 1 for interpolation,
        # where alpha is of the same shape for elementwise multiplication.
        alpha = torch.rand(N, 1)
        alpha = alpha.expand(N, int(real_images.nelement() / N)).contiguous()
        alpha = alpha.view(N, 3, H, W)
        alpha = alpha.to(device)

        # Obtain interpolates on the line between real/fake images.
        interpolates = alpha * real_images.detach() \
            + ((1 - alpha) * fake_images.detach())
        interpolates = interpolates.to(device)
        interpolates.requires_grad_(True)

        # Get gradients of D's output with respect to the interpolates.
        disc_interpolates = self.forward(interpolates)
        gradients = autograd.grad(outputs=disc_interpolates,
                                  inputs=interpolates,
                                  grad_outputs=torch.ones(
                                      disc_interpolates.size()).to(device),
                                  create_graph=True,
                                  retain_graph=True,
                                  only_inputs=True)[0]
        gradients = gradients.view(gradients.size(0), -1)

        # Compute the gradient penalty: gp_scale * E[(||grad||_2 - 1)^2].
        gradient_penalty = (
            (gradients.norm(2, dim=1) - 1)**2).mean() * gp_scale

        return gradient_penalty
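
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): a toy critic built on
# WGANGPBaseDiscriminator, plus an alternating WGAN-GP training loop driven by
# the train_step() methods above. `ToyWGANGPDiscriminator`, `_StubLog`, the
# random data and the hyperparameters are assumptions for illustration only;
# in practice you would use the library's ResNet discriminators, its MetricLog
# object and a real dataloader.
# ---------------------------------------------------------------------------
class ToyWGANGPDiscriminator(WGANGPBaseDiscriminator):
    def __init__(self, ndf=128, **kwargs):
        super().__init__(ndf=ndf, **kwargs)

        # Toy stand-in for the ResNet backbone: a single linear critic.
        self.main = nn.Linear(3 * 32 * 32, 1)

    def forward(self, x):
        # x: images of shape (N, 3, H, W) -> unbounded critic scores (N, 1).
        return self.main(x.view(x.shape[0], -1))


class _StubLog:
    """Minimal stand-in for the MetricLog object expected by train_step()."""
    def add_metric(self, name, value, group=None):
        pass


if __name__ == '__main__':
    device = torch.device('cpu')
    netG = ToyWGANGPGenerator().to(device)
    netD = ToyWGANGPDiscriminator().to(device)
    optG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.0, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.0, 0.9))
    log_data = _StubLog()

    n_dis = 5  # WGAN-GP typically takes several D steps per G step.
    for _ in range(2):
        for _ in range(n_dis):
            # Stand-in "real" batch of (images, labels); only images are used.
            real_batch = (torch.randn(8, 3, 32, 32), torch.zeros(8))
            netD.train_step(real_batch, netG, optD, log_data, device=device)
        netG.train_step(real_batch, netD, optG, log_data, device=device)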
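
# ---------------------------------------------------------------------------
# Sanity check (illustrative): for a purely linear critic D(x) = w.x + b, the
# gradient of D with respect to its input is w at every point, so the penalty
# above reduces to gp_scale * (||w||_2 - 1)^2 regardless of which interpolates
# are sampled. This mirrors the WGAN-GP objective term
# lambda * E[(||grad D(x_hat)||_2 - 1)^2]. `_LinearCritic` and
# `_check_gradient_penalty` are hypothetical helpers written for this check.
# ---------------------------------------------------------------------------
class _LinearCritic(WGANGPBaseDiscriminator):
    def __init__(self, ndf=128, **kwargs):
        super().__init__(ndf=ndf, **kwargs)
        self.linear = nn.Linear(3 * 32 * 32, 1)

    def forward(self, x):
        return self.linear(x.view(x.shape[0], -1))


def _check_gradient_penalty():
    netD = _LinearCritic()
    real = torch.randn(4, 3, 32, 32)
    fake = torch.randn(4, 3, 32, 32)

    gp = netD.compute_gradient_penalty_loss(real, fake, gp_scale=10.0)

    # Closed-form value for a linear critic.
    w = netD.linear.weight.view(-1)
    expected = 10.0 * (w.norm(2) - 1) ** 2
    assert torch.allclose(gp, expected, rtol=1e-4)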