Source code for torch_mimicry.nets.basemodel.basemodel

"""
Implementation of BaseModel.
"""
import os
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class BaseModel(nn.Module, ABC):
    r"""
    BaseModel with basic functionalities for checkpointing and restoration.
    """
    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(self, x):
        pass

    @property
    def device(self):
        return next(self.parameters()).device
    def restore_checkpoint(self, ckpt_file, optimizer=None):
        r"""
        Restores checkpoint from a pth file and restores optimizer state.

        Args:
            ckpt_file (str): A PyTorch pth file containing model weights.
            optimizer (Optimizer): A vanilla optimizer to have its state
                restored from.

        Returns:
            int: Global step variable where the model was last checkpointed.
        """
        if not ckpt_file:
            raise ValueError("No checkpoint file to be restored.")

        try:
            ckpt_dict = torch.load(ckpt_file)
        except RuntimeError:
            ckpt_dict = torch.load(ckpt_file,
                                   map_location=lambda storage, loc: storage)

        # Restore model weights
        self.load_state_dict(ckpt_dict['model_state_dict'])

        # Restore optimizer status if existing. Evaluation doesn't need this
        if optimizer:
            optimizer.load_state_dict(ckpt_dict['optimizer_state_dict'])

        # Return global step
        return ckpt_dict['global_step']
    def save_checkpoint(self,
                        directory,
                        global_step,
                        optimizer=None,
                        name=None):
        r"""
        Saves checkpoint at a certain global step during training. Optimizer
        state is also saved together.

        Args:
            directory (str): Path to save checkpoint to.
            global_step (int): The global step variable during training.
            optimizer (Optimizer): Optimizer state to be saved concurrently.
            name (str): The name to save the checkpoint file as.

        Returns:
            None
        """
        # Create directory to save to
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Build checkpoint dict to save.
        ckpt_dict = {
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
            if optimizer is not None else None,
            'global_step': global_step
        }

        # Save the file with specific name
        if name is None:
            name = "{}_{}_steps.pth".format(
                os.path.basename(directory),  # netD or netG
                global_step)

        torch.save(ckpt_dict, os.path.join(directory, name))
    def count_params(self):
        r"""
        Computes the number of parameters in this model.

        Args:
            None

        Returns:
            int: Total number of weight parameters for this model.
            int: Total number of trainable parameters for this model.
        """
        num_total_params = sum(p.numel() for p in self.parameters())
        num_trainable_params = sum(p.numel() for p in self.parameters()
                                   if p.requires_grad)

        return num_total_params, num_trainable_params
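
A minimal sketch of how BaseModel might be subclassed: the abstract forward must be implemented, after which the device property reports where the parameters live and count_params returns the (total, trainable) parameter counts. ExampleNet, its layer size, and the prints below are illustrative assumptions, not part of torch_mimicry.

import torch.nn as nn

# Assumes BaseModel is importable (or that this runs in the same module as above).
class ExampleNet(BaseModel):
    """Toy subclass used only for illustration."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)


net = ExampleNet()
print(net.device)          # device of the first parameter, e.g. cpu
print(net.count_params())  # (num_total_params, num_trainable_params)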
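
Continuing the sketch above, a hedged example of the checkpoint round trip: save_checkpoint writes a file named from the directory's basename and the step count, and restore_checkpoint returns the stored global step. The ./checkpoints/netG path and the Adam optimizer are assumptions chosen for illustration.

import torch.optim as optim

opt = optim.Adam(net.parameters(), lr=2e-4)

# Save at step 1000; with directory './checkpoints/netG' the default file name
# is 'netG_1000_steps.pth' (basename of the directory + global step).
net.save_checkpoint(directory='./checkpoints/netG',
                    global_step=1000,
                    optimizer=opt)

# Restore weights and optimizer state, recovering the step to resume training from.
global_step = net.restore_checkpoint(
    ckpt_file='./checkpoints/netG/netG_1000_steps.pth',
    optimizer=opt)
assert global_step == 1000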