# Source code for torch_mimicry.metrics.compute_is

"""
PyTorch interface for computing Inception Score.
"""
import os
import random
import time

import numpy as np
import torch

from torch_mimicry.metrics.inception_model import inception_utils
from torch_mimicry.metrics.inception_score import inception_score_utils as tf_inception_score


def _normalize_images(images):
    """
    Given a tensor of images, uses the torchvision
    normalization method to convert floating point data to integers. See reference
    at: https://pytorch.org/docs/stable/_modules/torchvision/utils.html#save_image

    The function uses the normalization from make_grid and save_image functions.

    Args:
        images (Tensor): Batch of images of shape (N, 3, H, W).

    Returns:
        ndarray: Batch of normalized images of shape (N, H, W, 3).
    """
    # Shift the image from [-1, 1] range to [0, 1] range.
    min_val = float(images.min())
    max_val = float(images.max())
    images.clamp_(min=min_val, max=max_val)
    images.add_(-min_val).div_(max_val - min_val + 1e-5)

    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    images = images.mul_(255).add_(0.5).clamp_(0, 255).permute(0, 2, 3, 1).to(
        'cpu', torch.uint8).numpy()

    return images


def _generate_images(netG, num_samples, batch_size, device, seed, print_every):
    """
    Generates exactly num_samples images from netG in batches.

    Returns:
        ndarray: uint8 images of shape (num_samples, H, W, 3), as produced by
        _normalize_images.
    """
    # Full batches first, then one partial batch for any remainder. This
    # produces exactly num_samples images even when batch_size does not
    # divide num_samples evenly.
    num_full, remainder = divmod(num_samples, batch_size)
    batch_sizes = [batch_size] * num_full
    if remainder:
        batch_sizes.append(remainder)
    log_every = min(print_every, len(batch_sizes))

    images = []
    num_generated = 0
    window_images = 0  # images generated since the last progress message
    window_start = time.time()
    with torch.no_grad():
        for idx, size in enumerate(batch_sizes):
            fake_images = netG.generate_images(num_images=size,
                                               device=device).detach().cpu()
            images.append(_normalize_images(fake_images))
            num_generated += size
            window_images += size

            if (idx + 1) % log_every == 0:
                window_end = time.time()
                print(
                    "INFO: Generated image {}/{} [Random Seed {}] ({:.4f} sec/idx)"
                    .format(num_generated, num_samples, seed,
                            (window_end - window_start) / window_images))
                window_start = window_end
                window_images = 0

    return np.concatenate(images, axis=0)


def inception_score(num_samples,
                    netG,
                    device=None,
                    batch_size=50,
                    splits=10,
                    log_dir='./log',
                    seed=0,
                    print_every=20):
    """
    Computes the inception score of generated images.

    Args:
        num_samples (int): The number of samples to generate.
        netG (Module): The generator model to use for generating images.
        device (str/torch.device): Device identifier to use for computation.
            Defaults to 'cuda:0' when CUDA is available, otherwise 'cpu'.
        batch_size (int): Batch size per feedforward step for inception model.
        splits (int): The number of splits to use for computing IS.
        log_dir (str): Path to store metric computation objects.
        seed (int): Random seed for generation. Seeds torch, random and numpy
            globally.
        print_every (int): Print a progress message every this many batches.

    Returns:
        Mean and standard deviation of the inception score computed from using
        num_samples generated images.

    Raises:
        ValueError: If num_samples, batch_size or print_every is not positive.
    """
    # Validate before any side effects (seeding, model building).
    if num_samples <= 0:
        raise ValueError(
            "num_samples must be positive, got {}".format(num_samples))
    if batch_size <= 0:
        raise ValueError(
            "batch_size must be positive, got {}".format(batch_size))
    if print_every <= 0:
        raise ValueError(
            "print_every must be positive, got {}".format(print_every))

    # Wall-clock start for the total time reported at the end. It must not be
    # reused for the per-batch progress timing.
    start_time = time.time()

    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    # Make sure the random seeds are fixed
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # Build inception graph from the model path under log_dir.
    inception_path = os.path.join(log_dir, 'metrics/inception_model')
    inception_utils.create_inception_graph(inception_path)

    # Never use a batch larger than the total number of samples.
    batch_size = min(batch_size, num_samples)
    images = _generate_images(netG, num_samples, batch_size, device, seed,
                              print_every)

    is_mean, is_std = tf_inception_score.get_inception_score(images,
                                                              splits=splits,
                                                              device=device)

    print("INFO: Inception Score: {:.4f} ± {:.4f} [Time Taken: {:.4f} secs]".
          format(is_mean, is_std, time.time() - start_time))

    return is_mean, is_std