import os

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Root directory under which every dataset below is stored.
_DATASETS_MAIN_PATH = '/home/Datasets'

# Storage location per dataset name. ImageNet is read with ImageFolder from
# pre-split directory trees, so it maps split name -> directory instead of
# a single root.
_dataset_path = {
    'cifar10': os.path.join(_DATASETS_MAIN_PATH, 'CIFAR10'),
    'cifar100': os.path.join(_DATASETS_MAIN_PATH, 'CIFAR100'),
    'stl10': os.path.join(_DATASETS_MAIN_PATH, 'STL10'),
    'mnist': os.path.join(_DATASETS_MAIN_PATH, 'MNIST'),
    'imagenet': {
        'train': os.path.join(_DATASETS_MAIN_PATH, 'ImageNet/train'),
        'val': os.path.join(_DATASETS_MAIN_PATH, 'ImageNet/val'),
    },
}


def get_dataset(name, split='train', transform=None, target_transform=None,
                download=True):
    """Build a torchvision dataset by name.

    Args:
        name: One of 'cifar10', 'cifar100', 'stl10', 'mnist', 'imagenet'.
        split: Which portion of the data to load. For CIFAR-10/100 and MNIST
            only ``split == 'train'`` selects the training set; any other
            value selects the test set. For STL10 the value is forwarded
            unchanged to torchvision's ``split`` argument. For ImageNet it
            must be a key of ``_dataset_path['imagenet']`` ('train' or 'val').
        transform: Optional transform applied to each sample.
        target_transform: Optional transform applied to each target.
        download: Forwarded to the torchvision constructors that accept it;
            ignored for ImageNet, which is read from local directories.

    Returns:
        The constructed torchvision dataset.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
        KeyError: If ``name`` is 'imagenet' and ``split`` has no configured
            directory.
    """
    # Datasets with a boolean train/test switch share this flag.
    train = (split == 'train')
    if name == 'cifar10':
        return datasets.CIFAR10(root=_dataset_path['cifar10'], train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    if name == 'cifar100':
        return datasets.CIFAR100(root=_dataset_path['cifar100'], train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    if name == 'stl10':
        # STL10 selects its subset with a string split rather than a bool.
        return datasets.STL10(root=_dataset_path['stl10'], split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    if name == 'mnist':
        return datasets.MNIST(root=_dataset_path['mnist'], train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    if name == 'imagenet':
        # An unknown split raises KeyError here, as before.
        path = _dataset_path[name][split]
        return datasets.ImageFolder(root=path, transform=transform,
                                    target_transform=target_transform)
    # Previously an unknown name silently returned None.
    raise ValueError('Unknown dataset name: %r (expected one of %s)'
                     % (name, ', '.join(sorted(_dataset_path))))