# Source code for torch_mimicry.modules.resblocks

"""
Implementation of residual blocks for discriminator and generator.
We follow the official SNGAN Chainer implementation as closely as possible:
https://github.com/pfnet-research/sngan_projection
"""
import math

import torch.nn as nn
import torch.nn.functional as F

from torch_mimicry.modules import SNConv2d, ConditionalBatchNorm2d


class GBlock(nn.Module):
    r"""
    Residual block for the generator.

    Upsampling uses bilinear (rather than nearest) interpolation with
    align_corners set to False, matching how torchvision performs upsampling:
    https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/_utils.py

    Attributes:
        in_channels (int): The channel size of input feature map.
        out_channels (int): The channel size of output feature map.
        hidden_channels (int): The channel size of intermediate feature maps.
        upsample (bool): If True, upsamples the input feature map.
        num_classes (int): If more than 0, uses conditional batch norm instead.
        spectral_norm (bool): If True, uses spectral norm for convolutional layers.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=None,
                 upsample=False,
                 num_classes=0,
                 spectral_norm=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels if hidden_channels is not None else out_channels
        self.learnable_sc = in_channels != out_channels or upsample
        self.upsample = upsample
        self.num_classes = num_classes
        self.spectral_norm = spectral_norm

        # Main 3x3 convolutions. NOTE: kept as explicit branches rather than a
        # shared factory helper -- the reference implementation consistently
        # observed worse results with spectral norm when refactored that way.
        if self.spectral_norm:
            self.c1 = SNConv2d(self.in_channels, self.hidden_channels, 3, 1,
                               padding=1)
            self.c2 = SNConv2d(self.hidden_channels, self.out_channels, 3, 1,
                               padding=1)
        else:
            self.c1 = nn.Conv2d(self.in_channels, self.hidden_channels, 3, 1,
                                padding=1)
            self.c2 = nn.Conv2d(self.hidden_channels, self.out_channels, 3, 1,
                                padding=1)

        # Plain BN for the unconditional case; conditional BN when class
        # labels are available.
        if self.num_classes == 0:
            self.b1 = nn.BatchNorm2d(self.in_channels)
            self.b2 = nn.BatchNorm2d(self.hidden_channels)
        else:
            self.b1 = ConditionalBatchNorm2d(self.in_channels,
                                             self.num_classes)
            self.b2 = ConditionalBatchNorm2d(self.hidden_channels,
                                             self.num_classes)

        self.activation = nn.ReLU(True)

        nn.init.xavier_uniform_(self.c1.weight.data, math.sqrt(2.0))
        nn.init.xavier_uniform_(self.c2.weight.data, math.sqrt(2.0))

        # 1x1 shortcut convolution, needed whenever the channel count changes
        # or the spatial size is upsampled.
        if self.learnable_sc:
            if self.spectral_norm:
                self.c_sc = SNConv2d(in_channels, out_channels, 1, 1,
                                     padding=0)
            else:
                self.c_sc = nn.Conv2d(in_channels, out_channels, 1, 1,
                                      padding=0)
            nn.init.xavier_uniform_(self.c_sc.weight.data, 1.0)

    def _upsample_conv(self, x, conv):
        r"""
        Upsamples x by a factor of 2 (bilinear, align_corners=False), then
        applies the given convolution.
        """
        return conv(
            F.interpolate(x,
                          scale_factor=2,
                          mode='bilinear',
                          align_corners=False))

    def _residual(self, x):
        r"""
        Feedforward through the main branch (unconditional batch norm).
        """
        h = self.b1(x)
        h = self.activation(h)
        h = self._upsample_conv(h, self.c1) if self.upsample else self.c1(h)
        h = self.b2(h)
        h = self.activation(h)
        return self.c2(h)

    def _residual_conditional(self, x, y):
        r"""
        Feedforward through the main branch, conditioning the batch norm
        layers on the class labels y.
        """
        h = self.b1(x, y)
        h = self.activation(h)
        h = self._upsample_conv(h, self.c1) if self.upsample else self.c1(h)
        h = self.b2(h, y)
        h = self.activation(h)
        return self.c2(h)

    def _shortcut(self, x):
        r"""
        Feedforward through the shortcut branch (identity when no learnable
        shortcut is needed).
        """
        if self.learnable_sc:
            return self._upsample_conv(
                x, self.c_sc) if self.upsample else self.c_sc(x)
        return x

    def forward(self, x, y=None):
        r"""
        Residual block feedforward function. Uses conditional batch norm
        when class labels y are provided.
        """
        if y is None:
            return self._residual(x) + self._shortcut(x)
        return self._residual_conditional(x, y) + self._shortcut(x)
class DBlock(nn.Module):
    """
    Residual block for the discriminator.

    Attributes:
        in_channels (int): The channel size of input feature map.
        out_channels (int): The channel size of output feature map.
        hidden_channels (int): The channel size of intermediate feature maps.
        downsample (bool): If True, downsamples the input feature map.
        spectral_norm (bool): If True, uses spectral norm for convolutional layers.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden_channels=None,
                 downsample=False,
                 spectral_norm=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # NOTE: hidden defaults to in_channels here (generator blocks default
        # to out_channels), matching the Chainer reference.
        self.hidden_channels = hidden_channels if hidden_channels is not None else in_channels
        self.downsample = downsample
        self.learnable_sc = (in_channels != out_channels) or downsample
        self.spectral_norm = spectral_norm

        # Main 3x3 convolutions, with optional spectral norm.
        if self.spectral_norm:
            self.c1 = SNConv2d(self.in_channels, self.hidden_channels, 3, 1, 1)
            self.c2 = SNConv2d(self.hidden_channels, self.out_channels, 3, 1,
                               1)
        else:
            self.c1 = nn.Conv2d(self.in_channels, self.hidden_channels, 3, 1,
                                1)
            self.c2 = nn.Conv2d(self.hidden_channels, self.out_channels, 3, 1,
                                1)

        self.activation = nn.ReLU(True)

        nn.init.xavier_uniform_(self.c1.weight.data, math.sqrt(2.0))
        nn.init.xavier_uniform_(self.c2.weight.data, math.sqrt(2.0))

        # 1x1 shortcut convolution, needed whenever the channel count changes
        # or the spatial size is downsampled.
        if self.learnable_sc:
            if self.spectral_norm:
                self.c_sc = SNConv2d(in_channels, out_channels, 1, 1, 0)
            else:
                self.c_sc = nn.Conv2d(in_channels, out_channels, 1, 1, 0)
            nn.init.xavier_uniform_(self.c_sc.weight.data, 1.0)

    def _residual(self, x):
        """
        Feedforward through the main branch.
        """
        h = self.activation(x)
        h = self.c1(h)
        h = self.activation(h)
        h = self.c2(h)
        return F.avg_pool2d(h, 2) if self.downsample else h

    def _shortcut(self, x):
        """
        Feedforward through the shortcut branch (identity when no learnable
        shortcut is needed).
        """
        if not self.learnable_sc:
            return x
        h = self.c_sc(x)
        return F.avg_pool2d(h, 2) if self.downsample else h

    def forward(self, x):
        """
        Residual block feedforward function.
        """
        return self._residual(x) + self._shortcut(x)
class DBlockOptimized(nn.Module):
    """
    Optimized residual block for the discriminator. This is used as the
    first residual block, where there is a definite downsampling involved.
    Follows the official SNGAN reference implementation in chainer.

    Attributes:
        in_channels (int): The channel size of input feature map.
        out_channels (int): The channel size of output feature map.
        spectral_norm (bool): If True, uses spectral norm for convolutional layers.
    """
    def __init__(self, in_channels, out_channels, spectral_norm=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.spectral_norm = spectral_norm

        # Two 3x3 main convolutions plus a 1x1 shortcut convolution, with
        # optional spectral norm on all of them.
        if self.spectral_norm:
            self.c1 = SNConv2d(self.in_channels, self.out_channels, 3, 1, 1)
            self.c2 = SNConv2d(self.out_channels, self.out_channels, 3, 1, 1)
            self.c_sc = SNConv2d(self.in_channels, self.out_channels, 1, 1, 0)
        else:
            self.c1 = nn.Conv2d(self.in_channels, self.out_channels, 3, 1, 1)
            self.c2 = nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1)
            self.c_sc = nn.Conv2d(self.in_channels, self.out_channels, 1, 1,
                                  0)

        self.activation = nn.ReLU(True)

        nn.init.xavier_uniform_(self.c1.weight.data, math.sqrt(2.0))
        nn.init.xavier_uniform_(self.c2.weight.data, math.sqrt(2.0))
        nn.init.xavier_uniform_(self.c_sc.weight.data, 1.0)

    def _residual(self, x):
        """
        Feedforward through the main branch: conv -> relu -> conv -> pool.
        Unlike DBlock, no activation precedes the first convolution.
        """
        h = self.c1(x)
        h = self.activation(h)
        h = self.c2(h)
        return F.avg_pool2d(h, 2)

    def _shortcut(self, x):
        """
        Feedforward through the shortcut branch: pool first, then the 1x1
        convolution (the reverse order of the main branch).
        """
        return self.c_sc(F.avg_pool2d(x, 2))

    def forward(self, x):
        """
        Residual block feedforward function.
        """
        return self._residual(x) + self._shortcut(x)