Source code for torch_mimicry.nets.gan.gan

"""
Implementation of Base GAN models.
"""
import torch

from torch_mimicry.nets.basemodel import basemodel
from torch_mimicry.modules import losses


class BaseGenerator(basemodel.BaseModel):
    r"""
    Base class for a generic unconditional generator model.

    Attributes:
        nz (int): Noise dimension for upsampling.
        ngf (int): Variable controlling generator feature map sizes.
        bottom_width (int): Starting width for upsampling generator output to an image.
        loss_type (str): Name of loss to use for GAN loss. One of
            ``"gan"``, ``"ns"``, ``"hinge"`` or ``"wasserstein"``.
    """
    def __init__(self, nz, ngf, bottom_width, loss_type, **kwargs):
        super().__init__(**kwargs)
        self.nz = nz
        self.ngf = ngf
        self.bottom_width = bottom_width
        self.loss_type = loss_type

    def generate_images(self, num_images, device=None):
        r"""
        Generates num_images randomly.

        Samples standard-normal noise of shape (num_images, nz) and passes it
        through ``self.forward``.

        Args:
            num_images (int): Number of images to generate.
            device (torch.device): Device to create the noise on. Falls back
                to ``self.device`` when None.

        Returns:
            Tensor: A batch of generated images.
        """
        if device is None:
            device = self.device

        noise = torch.randn((num_images, self.nz), device=device)
        fake_images = self.forward(noise)

        return fake_images

    def compute_gan_loss(self, output):
        r"""
        Computes GAN loss for generator.

        Args:
            output (Tensor): A batch of output logits from the discriminator
                of shape (N, 1).

        Returns:
            Tensor: The generator loss produced by the loss function selected
            through ``self.loss_type``.

        Raises:
            ValueError: If ``self.loss_type`` is not a supported loss name.
        """
        # Dispatch on the configured loss name.
        if self.loss_type == "gan":
            errG = losses.minimax_loss_gen(output)

        elif self.loss_type == "ns":
            errG = losses.ns_loss_gen(output)

        elif self.loss_type == "hinge":
            errG = losses.hinge_loss_gen(output)

        elif self.loss_type == "wasserstein":
            errG = losses.wasserstein_loss_gen(output)

        else:
            raise ValueError("Invalid loss_type {} selected.".format(
                self.loss_type))

        return errG

    def train_step(self,
                   real_batch,
                   netD,
                   optG,
                   log_data,
                   device=None,
                   global_step=None,
                   **kwargs):
        r"""
        Takes one training step for G.

        Args:
            real_batch (sequence): A batch from the data loader whose first
                element is the tensor of real images (batch size is read from
                its leading dimension). Used only for obtaining the current
                batch size.
            netD (nn.Module): Discriminator model for obtaining losses.
            optG (Optimizer): Optimizer for updating generator's parameters.
            log_data (MetricLog): Logging object; its ``add_metric`` method is
                called to record the generator loss.
            device (torch.device): Device to use for running the model.
            global_step (int): Variable to sync training, logging and
                checkpointing. Useful for dynamic changes to model amidst
                training. Not used by this base implementation.

        Returns:
            MetricLog: The same ``log_data`` object, updated with the logging
            variables after 1 training step.
        """
        self.zero_grad()

        # Get only batch size from real batch (first element holds the images)
        batch_size = real_batch[0].shape[0]

        # Produce fake images
        fake_images = self.generate_images(num_images=batch_size,
                                           device=device)

        # Compute output logit of D thinking image real
        output = netD(fake_images)

        # Compute loss
        errG = self.compute_gan_loss(output=output)

        # Backprop and update gradients
        errG.backward()
        optG.step()

        # Log statistics. Log the plain scalar value rather than the tensor so
        # the logger never holds on to the autograd graph; this mirrors how the
        # discriminator logs errD.
        log_data.add_metric('errG', errG.item(), group='loss')

        return log_data
class BaseDiscriminator(basemodel.BaseModel):
    r"""
    Base class for a generic unconditional discriminator model.

    Attributes:
        ndf (int): Variable controlling discriminator feature map sizes.
        loss_type (str): Name of loss to use for GAN loss. One of
            ``"gan"``, ``"ns"``, ``"hinge"`` or ``"wasserstein"``.
    """
    def __init__(self, ndf, loss_type, **kwargs):
        super().__init__(**kwargs)
        self.ndf = ndf
        self.loss_type = loss_type

    def compute_gan_loss(self, output_real, output_fake):
        r"""
        Computes GAN loss for discriminator.

        Args:
            output_real (Tensor): A batch of output logits of shape (N, 1)
                from real images.
            output_fake (Tensor): A batch of output logits of shape (N, 1)
                from fake images.

        Returns:
            Tensor: The discriminator loss produced by the loss function
            selected through ``self.loss_type``.

        Raises:
            ValueError: If ``self.loss_type`` is not a supported loss name.
        """
        # Compute loss for D. The "gan" and "ns" variants share the same
        # discriminator loss here; they only differ on the generator side.
        if self.loss_type == "gan" or self.loss_type == "ns":
            errD = losses.minimax_loss_dis(output_fake=output_fake,
                                           output_real=output_real)

        elif self.loss_type == "hinge":
            errD = losses.hinge_loss_dis(output_fake=output_fake,
                                         output_real=output_real)

        elif self.loss_type == "wasserstein":
            errD = losses.wasserstein_loss_dis(output_fake=output_fake,
                                               output_real=output_real)

        else:
            # Report the offending value, matching the generator's message.
            raise ValueError("Invalid loss_type {} selected.".format(
                self.loss_type))

        return errD

    def compute_probs(self, output_real, output_fake):
        r"""
        Computes probabilities from real/fake images logits.

        Applies a sigmoid to each batch of logits and averages over the batch.

        Args:
            output_real (Tensor): A batch of output logits of shape (N, 1)
                from real images.
            output_fake (Tensor): A batch of output logits of shape (N, 1)
                from fake images.

        Returns:
            tuple: Average probabilities (as Python floats) of real/fake
            images being considered as real for the batch, in the order
            (D(x), D(G(z))).
        """
        D_x = torch.sigmoid(output_real).mean().item()
        D_Gz = torch.sigmoid(output_fake).mean().item()

        return D_x, D_Gz

    def train_step(self,
                   real_batch,
                   netG,
                   optD,
                   log_data,
                   device=None,
                   global_step=None,
                   **kwargs):
        r"""
        Takes one training step for D.

        Args:
            real_batch (tuple): A pair whose first element is the batch of
                real images; the second element (labels) is ignored by this
                unconditional model.
            netG (nn.Module): Generator model for obtaining fake images.
            optD (Optimizer): Optimizer for updating discriminator's
                parameters.
            log_data (MetricLog): Logging object; its ``add_metric`` method is
                called to record the loss and probabilities.
            device (torch.device): Device to use for running the model.
            global_step (int): Variable to sync training, logging and
                checkpointing. Useful for dynamic changes to model amidst
                training. Not used by this base implementation.

        Returns:
            MetricLog: The same ``log_data`` object, updated with the logging
            variables after 1 training step.
        """
        self.zero_grad()

        # Labels are not needed for the unconditional discriminator.
        real_images, _ = real_batch
        batch_size = real_images.shape[0]  # Match batch sizes for last iter

        # Produce logits for real images
        output_real = self.forward(real_images)

        # Produce fake images; detach so no gradients flow back into G
        fake_images = netG.generate_images(num_images=batch_size,
                                           device=device).detach()

        # Produce logits for fake images
        output_fake = self.forward(fake_images)

        # Compute loss for D
        errD = self.compute_gan_loss(output_real=output_real,
                                     output_fake=output_fake)

        # Backprop and update gradients
        errD.backward()
        optD.step()

        # Compute probabilities from the logits obtained before the update
        D_x, D_Gz = self.compute_probs(output_real=output_real,
                                       output_fake=output_fake)

        # Log statistics for D
        log_data.add_metric('errD', errD.item(), group='loss')
        log_data.add_metric('D(x)', D_x, group='prob')
        log_data.add_metric('D(G(z))', D_Gz, group='prob')

        return log_data