"""
Implementation of BaseModel.
"""
import os
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
class BaseModel(nn.Module, ABC):
r"""
BaseModel with basic functionalities for checkpointing and restoration.
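
    A minimal usage sketch; the subclass name, layer sizes and checkpoint
    directory below are illustrative assumptions, not part of the library.

    Example::

        class MyModel(BaseModel):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(8, 1)

            def forward(self, x):
                return self.linear(x)

        model = MyModel()
        model.save_checkpoint(directory='./checkpoints/model', global_step=100)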
"""
def __init__(self):
super().__init__()
@abstractmethod
def forward(self, x):
pass
@property
def device(self):
return next(self.parameters()).device
    def restore_checkpoint(self, ckpt_file, optimizer=None):
        r"""
        Restores checkpoint from a pth file and restores optimizer state.

        Args:
            ckpt_file (str): A PyTorch pth file containing model weights.
            optimizer (Optimizer): A vanilla optimizer to have its state restored from.

        Returns:
            int: Global step variable where the model was last checkpointed.
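
        A minimal restore sketch, assuming ``model`` is an instance of a
        ``BaseModel`` subclass; the checkpoint path and learning rate are
        illustrative assumptions.

        Example::

            optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
            global_step = model.restore_checkpoint(
                ckpt_file='./checkpoints/netG/netG_1000_steps.pth',
                optimizer=optimizer)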
"""
if not ckpt_file:
raise ValueError("No checkpoint file to be restored.")
try:
ckpt_dict = torch.load(ckpt_file)
except RuntimeError:
ckpt_dict = torch.load(ckpt_file,
map_location=lambda storage, loc: storage)
# Restore model weights
self.load_state_dict(ckpt_dict['model_state_dict'])
# Restore optimizer status if existing. Evaluation doesn't need this
if optimizer:
optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])
# Return global step
return ckpt_dict['global_step']
    def save_checkpoint(self,
                        directory,
                        global_step,
                        optimizer=None,
                        name=None):
        r"""
        Saves checkpoint at a certain global step during training. Optimizer state
        is saved together with the model weights.

        Args:
            directory (str): Path to save checkpoint to.
            global_step (int): The global step variable during training.
            optimizer (Optimizer): Optimizer state to be saved concurrently.
            name (str): The name to save the checkpoint file as.

        Returns:
            None
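
        A minimal save sketch, assuming ``model`` is an instance of a
        ``BaseModel`` subclass; the directory name and step count are
        illustrative assumptions.

        Example::

            model.save_checkpoint(directory='./checkpoints/netG',
                                  global_step=1000,
                                  optimizer=optimizer)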
"""
# Create directory to save to
if not os.path.exists(directory):
os.makedirs(directory)
# Build checkpoint dict to save.
ckpt_dict = {
'model_state_dict':
self.state_dict(),
'optimizer_state_dict':
optimizer.state_dict() if optimizer is not None else None,
'global_step':
global_step
}
# Save the file with specific name
if name is None:
name = "{}_{}_steps.pth".format(
os.path.basename(directory), # netD or netG
global_step)
torch.save(ckpt_dict, os.path.join(directory, name))
    def count_params(self):
        r"""
        Computes the number of parameters in this model.

        Args: None

        Returns:
            int: Total number of weight parameters for this model.
            int: Total number of trainable parameters for this model.
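
        A minimal usage sketch, assuming ``model`` is an instance of a
        ``BaseModel`` subclass.

        Example::

            num_total, num_trainable = model.count_params()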
"""
num_total_params = sum(p.numel() for p in self.parameters())
num_trainable_params = sum(p.numel() for p in self.parameters()
if p.requires_grad)
return num_total_params, num_trainable_params