# Source code for torch_mimicry.datasets.image_loader

"""
Loads randomly sampled images from datasets for computing metrics.
"""
import os
import random

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from torch_mimicry.datasets import data_utils


def get_random_images(dataset, num_samples):
    """
    Randomly sample without replacement num_samples images.

    Args:
        dataset (Dataset): Torch Dataset object for indexing elements. Each
            element is indexed as ``dataset[i][0]`` to obtain the image.
        num_samples (int): The number of images to randomly sample.

    Returns:
        ndarray: Batch of num_samples images in np array form, stacked along
            a new leading axis.

    Raises:
        ValueError: If num_samples is smaller than 1 or larger than the
            number of elements in the dataset.
    """
    num_available = len(dataset)

    # Fail early with a readable message instead of an opaque NumPy error.
    if num_samples < 1 or num_samples > num_available:
        raise ValueError(
            "num_samples must be between 1 and the dataset size {}, but got {}."
            .format(num_available, num_samples))

    choices = np.random.choice(num_available, size=num_samples, replace=False)

    # Cast indices to plain ints so datasets with strict __getitem__ accept them.
    images = [np.asarray(dataset[int(choice)][0]) for choice in choices]

    return np.stack(images, axis=0)
def get_imagenet_images(num_samples, root='./datasets', size=32):
    """
    Directly reads the imagenet folder for obtaining random images sampled
    in equal proportion for each class.

    Every class directory under ``<root>/imagenet/train`` contributes
    ``num_samples // 1000`` images, so the returned batch holds that many
    images per class directory found (which equals num_samples only when
    there are 1000 class directories and num_samples is a multiple of 1000).

    Args:
        num_samples (int): The number of images to randomly sample.
        root (str): The root directory where all datasets are stored.
        size (int): Size of image to resize to.

    Returns:
        ndarray: Batch of sampled RGB images in np array form.

    Raises:
        ValueError: If num_samples is below 1000, or if no images could be
            sampled (raised by NumPy when stacking/sampling).
    """
    if num_samples < 1000:
        raise ValueError(
            "num_samples {} must be at least 1000 to ensure images are sampled from each class."
            .format(num_samples))

    data_dir = os.path.join(root, 'imagenet', 'train')
    per_class = num_samples // 1000

    # The transforms do not depend on the image, so build them once.
    center_crop = transforms.CenterCrop(224)
    resize = transforms.Resize(size)

    images = []

    # os.listdir returns entries in arbitrary order; sorting keeps the
    # sampling reproducible under a fixed NumPy seed.
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            # Skip stray files sitting next to the class folders.
            continue

        filenames = sorted(os.listdir(class_dir))

        # Randomly choose equal proportion from each class.
        choices = np.random.choice(len(filenames),
                                   size=per_class,
                                   replace=False)

        for choice in choices:
            path = os.path.join(class_dir, filenames[int(choice)])

            # The context manager releases the file handle promptly.
            with Image.open(path) as raw:
                # Normalise every mode (grayscale, palette, RGBA, CMYK, ...)
                # to 3-channel RGB before cropping and resizing.
                rgb = raw.convert('RGB')
                images.append(np.asarray(resize(center_crop(rgb))))

    return np.stack(images, axis=0)
def get_fake_data_images(num_samples, root='./datasets', size=32, **kwargs):
    """
    Loads fake images, especially for testing.

    The fake dataset is built with an image size of (3, size, size); any
    extra keyword arguments are forwarded to the dataset loader.

    Args:
        num_samples (int): The number of images to randomly sample.
        root (str): The root directory where all datasets are stored.
        size (int): Size of image to resize to.

    Returns:
        ndarray: Batch of num_samples images in np array form.
    """
    image_shape = (3, size, size)

    fake_dataset = data_utils.load_fake_dataset(
        root=root,
        image_size=image_shape,
        transform_data=True,
        convert_tensor=False,  # Prevents normalization.
        **kwargs)

    return get_random_images(fake_dataset, num_samples)
def get_lsun_bedroom_images(num_samples,
                            root='./datasets',
                            size=128,
                            **kwargs):
    """
    Loads randomly sampled LSUN-Bedroom images.

    Extra keyword arguments are forwarded to the LSUN-Bedroom dataset loader.

    Args:
        num_samples (int): The number of images to randomly sample.
        root (str): The root directory where all datasets are stored.
        size (int): Size of image to resize to.

    Returns:
        ndarray: Batch of num_samples images in np array form.
    """
    lsun_dataset = data_utils.load_lsun_bedroom_dataset(
        root=root,
        size=size,
        transform_data=True,
        convert_tensor=False,  # Prevents normalization.
        **kwargs)

    return get_random_images(lsun_dataset, num_samples)
def get_celeba_images(num_samples, root='./datasets', size=128, **kwargs):
    """
    Loads randomly sampled CelebA images.

    Extra keyword arguments are forwarded to the CelebA dataset loader.

    Args:
        num_samples (int): The number of images to randomly sample.
        root (str): The root directory where all datasets are stored.
        size (int): Size of image to resize to.

    Returns:
        ndarray: Batch of num_samples images in np array form.
    """
    celeba_dataset = data_utils.load_celeba_dataset(
        root=root,
        size=size,
        transform_data=True,
        convert_tensor=False,  # Prevents normalization.
        **kwargs)

    return get_random_images(celeba_dataset, num_samples)
def get_stl10_images(num_samples, root='./datasets', size=48, **kwargs):
    """
    Loads randomly sampled STL-10 images.

    Extra keyword arguments are forwarded to the STL-10 dataset loader.

    Args:
        num_samples (int): The number of images to randomly sample.
        root (str): The root directory where all datasets are stored.
        size (int): Size of image to resize to.

    Returns:
        ndarray: Batch of num_samples images in np array form.
    """
    stl10_dataset = data_utils.load_stl10_dataset(
        root=root,
        size=size,
        transform_data=True,
        convert_tensor=False,  # Prevents normalization.
        **kwargs)

    return get_random_images(stl10_dataset, num_samples)
def get_cifar10_images(num_samples, root="./datasets", **kwargs):
    """
    Loads randomly sampled CIFAR-10 images.

    The dataset is loaded with ``transform_data=False``; extra keyword
    arguments are forwarded to the CIFAR-10 dataset loader.

    Args:
        num_samples (int): The number of images to randomly sample.
        root (str): The root directory where all datasets are stored.

    Returns:
        ndarray: Batch of num_samples images in np array form.
    """
    cifar10_dataset = data_utils.load_cifar10_dataset(
        root=root,
        transform_data=False,
        **kwargs)

    return get_random_images(cifar10_dataset, num_samples)
def get_cifar100_images(num_samples, root="./datasets", **kwargs):
    """
    Loads randomly sampled CIFAR-100 training images.

    The dataset is requested with ``split='train'``, ``download=True`` and
    ``transform_data=False``; extra keyword arguments are forwarded to the
    CIFAR-100 dataset loader.

    Args:
        num_samples (int): The number of images to randomly sample.
        root (str): The root directory where all datasets are stored.

    Returns:
        ndarray: Batch of num_samples images in np array form.
    """
    cifar100_dataset = data_utils.load_cifar100_dataset(
        root=root,
        split='train',
        download=True,
        transform_data=False,
        **kwargs)

    return get_random_images(cifar100_dataset, num_samples)
def sample_dataset_images(dataset, num_samples):
    """
    Randomly samples the dataset for images.

    Elements may be torch tensors, numpy arrays or anything else that
    ``np.asarray`` can convert. When an element is a tuple or list, its
    first entry is taken to be the image.

    Args:
        dataset (Dataset): Torch dataset object to sample images from.
        num_samples (int): The number of images to randomly sample.

    Returns:
        ndarray: Numpy array of images with first dim as batch size.

    Raises:
        ValueError: If the dataset holds fewer than num_samples elements,
            or if no images were sampled (raised by NumPy when stacking).
    """
    num_available = len(dataset)

    # Check if sufficient images
    if num_available < num_samples:
        raise ValueError(
            "Given dataset has less than num_samples images: {} given but requires at least {}."
            .format(num_available, num_samples))

    choices = random.sample(range(num_available), num_samples)

    images = []
    for i in choices:
        data = dataset[i]

        # Case of iterable, assumes first arg is data.
        img = data[0] if isinstance(data, (tuple, list)) else data

        if torch.is_tensor(img):
            # detach/cpu make the conversion safe for tensors that track
            # gradients or live on an accelerator.
            img = img.detach().cpu().numpy()
        else:
            img = np.asarray(img)

        images.append(img)

    return np.stack(images, axis=0)
def get_dataset_images(dataset, num_samples=50000, **kwargs):
    """
    Randomly sample num_samples images based on input dataset name.

    Args:
        dataset (str/Dataset): Dataset to load images from. Either one of
            the supported dataset names or a torch Dataset object.
        num_samples (int): The number of images to randomly sample.
        **kwargs: Forwarded to the named dataset loader. Not used when a
            Dataset object is given.

    Returns:
        ndarray: Batch of num_samples images from a dataset in np array form.
            The final format is of (N, H, W, 3) shape for TF inference.

    Raises:
        ValueError: If dataset is an unknown name, or is neither a str nor
            a torch Dataset object.
    """
    if isinstance(dataset, str):
        # Lambdas defer each loader call until its name has been selected.
        loaders = {
            "imagenet_32":
            lambda: get_imagenet_images(num_samples, size=32, **kwargs),
            "imagenet_128":
            lambda: get_imagenet_images(num_samples, size=128, **kwargs),
            "celeba_64":
            lambda: get_celeba_images(num_samples, size=64, **kwargs),
            "celeba_128":
            lambda: get_celeba_images(num_samples, size=128, **kwargs),
            "stl10_48":
            lambda: get_stl10_images(num_samples, **kwargs),
            "cifar10":
            lambda: get_cifar10_images(num_samples, **kwargs),
            "cifar100":
            lambda: get_cifar100_images(num_samples, **kwargs),
            "lsun_bedroom_128":
            lambda: get_lsun_bedroom_images(num_samples, size=128, **kwargs),
            "fake_data":
            lambda: get_fake_data_images(num_samples, size=32, **kwargs),
        }

        if dataset not in loaders:
            raise ValueError("Invalid dataset name {}.".format(dataset))

        images = loaders[dataset]()

    elif isinstance(dataset, torch.utils.data.Dataset):
        images = sample_dataset_images(dataset, num_samples)

    else:
        raise ValueError("dataset must be of type str or a Dataset object.")

    # Check shape and permute if needed: a 4-D batch whose second axis has
    # length 3 is treated as channels-first and moved to channels-last.
    if images.ndim == 4 and images.shape[1] == 3:
        images = images.transpose((0, 2, 3, 1))

    # Ensure the values lie within the correct range, otherwise there might be some
    # preprocessing error from the library causing ill-valued scores.
    if np.min(images) < 0 or np.max(images) > 255:
        print(
            "INFO: Some pixel values lie outside of [0, 255]. Clipping values.."
        )
        images = np.clip(images, 0, 255)

    return images