Source code for torch_mimicry.nets.sngan.sngan_base

"""
Base implementation of SNGAN with default variables.
"""
from torch_mimicry.nets.gan import gan


class SNGANBaseGenerator(gan.BaseGenerator):
    r"""
    Generator base class for SNGAN, built on a ResNet backbone.

    This class adds no behaviour of its own: it only sets a default for
    ``loss_type`` and hands every argument on to ``gan.BaseGenerator``.

    Attributes:
        nz (int): Dimension of the input noise that is upsampled.
        ngf (int): Controls the sizes of the generator feature maps.
        bottom_width (int): Initial width from which the generator output
            is upsampled to an image.
        loss_type (str): Name of the GAN loss to use. Defaults to 'hinge'.
    """

    def __init__(self, nz, ngf, bottom_width, loss_type='hinge', **kwargs):
        # Collect the named arguments first, then append any extras, so the
        # base class receives the keywords in the same order as before.
        # The named parameters can never also appear in ``kwargs``, so the
        # update below cannot overwrite them.
        forwarded = {
            'nz': nz,
            'ngf': ngf,
            'bottom_width': bottom_width,
            'loss_type': loss_type,
        }
        forwarded.update(kwargs)
        super().__init__(**forwarded)
class SNGANBaseDiscriminator(gan.BaseDiscriminator):
    r"""
    Discriminator base class for SNGAN, built on a ResNet backbone.

    This class adds no behaviour of its own: it only sets a default for
    ``loss_type`` and hands every argument on to ``gan.BaseDiscriminator``.

    Attributes:
        ndf (int): Controls the sizes of the discriminator feature maps.
        loss_type (str): Name of the GAN loss to use. Defaults to 'hinge'.
    """

    def __init__(self, ndf, loss_type='hinge', **kwargs):
        # Collect the named arguments first, then append any extras, so the
        # base class receives the keywords in the same order as before.
        # The named parameters can never also appear in ``kwargs``, so the
        # update below cannot overwrite them.
        forwarded = {
            'ndf': ndf,
            'loss_type': loss_type,
        }
        forwarded.update(kwargs)
        super().__init__(**forwarded)