# Source code for torch_mimicry.modules.losses

"""
Loss function definitions.
"""
import torch
import torch.nn.functional as F


def _bce_loss_with_logits(output, labels, **kwargs):
    r"""
    Thin wrapper around ``F.binary_cross_entropy_with_logits``.

    Args:
        output (Tensor): Raw (pre-sigmoid) logits.
        labels (Tensor): Target values; PyTorch requires these to match the
            size of ``output``.
        **kwargs: Forwarded unchanged to
            ``F.binary_cross_entropy_with_logits`` (e.g. ``reduction``).

    Returns:
        Tensor: The BCE-with-logits loss as computed by PyTorch.
    """
    bce_fn = F.binary_cross_entropy_with_logits
    return bce_fn(output, labels, **kwargs)


def minimax_loss_gen(output_fake, real_label_val=1.0, **kwargs):
    r"""
    Standard minimax loss for GANs through the BCE Loss with logits fn.

    The fake logits are scored against "real" targets, so G is rewarded
    when D is fooled.

    Args:
        output_fake (Tensor): Discriminator output logits for fake images.
        real_label_val (float): Target value used as the "real" label.
        **kwargs: Extra keyword arguments forwarded to
            ``F.binary_cross_entropy_with_logits`` (e.g. ``reduction``).

    Returns:
        Tensor: A scalar tensor loss output (with the default reduction).
    """
    # Build the targets with the same shape, dtype and device as the logits.
    # Previously they were hard-coded to (N, 1), which broke on any other
    # discriminator output shape such as (N,).
    real_labels = torch.full_like(output_fake, real_label_val)
    loss = _bce_loss_with_logits(output_fake, real_labels, **kwargs)
    return loss
def minimax_loss_dis(output_fake,
                     output_real,
                     real_label_val=1.0,
                     fake_label_val=0.0,
                     **kwargs):
    r"""
    Standard minimax loss for GANs through the BCE Loss with logits fn.

    Sums the BCE of fake logits against ``fake_label_val`` targets and the
    BCE of real logits against ``real_label_val`` targets.

    Args:
        output_fake (Tensor): Discriminator output logits for fake images.
        output_real (Tensor): Discriminator output logits for real images.
        real_label_val (float): Label for real images.
        fake_label_val (float): Label for fake images.
        **kwargs: Extra keyword arguments forwarded to
            ``F.binary_cross_entropy_with_logits`` (e.g. ``reduction``).

    Returns:
        Tensor: A scalar tensor loss output (with the default reduction).
    """
    # Produce real and fake labels matching each logit tensor's shape, dtype
    # and device. Previously they were hard-coded to (N, 1), which broke on
    # other output shapes.
    fake_labels = torch.full_like(output_fake, fake_label_val)
    real_labels = torch.full_like(output_real, real_label_val)

    errD_fake = _bce_loss_with_logits(output=output_fake,
                                      labels=fake_labels,
                                      **kwargs)
    errD_real = _bce_loss_with_logits(output=output_real,
                                      labels=real_labels,
                                      **kwargs)

    # Compute cumulative error.
    loss = errD_real + errD_fake
    return loss
def ns_loss_gen(output_fake):
    r"""
    Non-saturating loss for generator.

    Computes ``-mean(log(sigmoid(output_fake)))`` using ``F.logsigmoid``,
    which is numerically stable. The form ``log(sigmoid(x) + eps)`` is not:
    for very negative logits, ``sigmoid`` underflows and the eps-clamped
    log flattens out. That drives the generator gradient to zero, which is
    exactly the regime a non-saturating loss is meant to keep training in.

    Args:
        output_fake (Tensor): Discriminator output logits for fake images.

    Returns:
        Tensor: A scalar tensor loss output.
    """
    return -F.logsigmoid(output_fake).mean()
def wasserstein_loss_dis(output_real, output_fake):
    r"""
    Computes the wasserstein loss for the discriminator.

    The loss is the mean fake score minus the mean real score.

    Args:
        output_real (Tensor): Discriminator output logits for real images.
        output_fake (Tensor): Discriminator output logits for fake images.

    Returns:
        Tensor: A scalar tensor loss output.
    """
    real_score = output_real.mean()
    fake_score = output_fake.mean()
    return fake_score - real_score
def wasserstein_loss_gen(output_fake):
    r"""
    Computes the wasserstein loss for generator.

    The loss is the negated mean of the discriminator's fake scores.

    Args:
        output_fake (Tensor): Discriminator output logits for fake images.

    Returns:
        Tensor: A scalar tensor loss output.
    """
    mean_fake_score = torch.mean(output_fake)
    return torch.neg(mean_fake_score)
def hinge_loss_dis(output_fake, output_real):
    r"""
    Hinge loss for discriminator.

    Penalises real logits below +1 and fake logits above -1.

    Args:
        output_fake (Tensor): Discriminator output logits for fake images.
        output_real (Tensor): Discriminator output logits for real images.

    Returns:
        Tensor: A scalar tensor loss output.
    """
    real_term = F.relu(1.0 - output_real).mean()
    fake_term = F.relu(1.0 + output_fake).mean()
    return real_term + fake_term
def hinge_loss_gen(output_fake):
    r"""
    Hinge loss for generator.

    The loss is the negated mean of the discriminator's fake logits.

    Args:
        output_fake (Tensor): Discriminator output logits for fake images.

    Returns:
        Tensor: A scalar tensor loss output.
    """
    avg_fake_logit = output_fake.mean()
    return torch.neg(avg_fake_logit)