|
|
- import os
- import torchvision.datasets as datasets
- import torchvision.transforms as transforms
-
# Root directory under which every dataset is stored.
_DATASETS_MAIN_PATH = '/home/Datasets'

# Maps a dataset key to its on-disk location. Single-directory datasets live
# in a folder named after the upper-cased key; ImageNet instead maps each
# split name ('train' / 'val') to its own sub-directory.
_dataset_path = {
    key: os.path.join(_DATASETS_MAIN_PATH, key.upper())
    for key in ('cifar10', 'cifar100', 'stl10', 'mnist')
}
_dataset_path['imagenet'] = {
    part: os.path.join(_DATASETS_MAIN_PATH, 'ImageNet/' + part)
    for part in ('train', 'val')
}
-
-
def get_dataset(name, split='train', transform=None,
                target_transform=None, download=True):
    """Build the torchvision dataset registered under ``name``.

    Args:
        name: Dataset key. One of 'cifar10', 'cifar100', 'mnist', 'stl10'
            or 'imagenet'.
        split: Which partition to load. For CIFAR and MNIST only the value
            'train' selects the training set; any other value selects the
            test set. For STL10 the value is forwarded unchanged as the
            dataset's own ``split`` argument. For ImageNet it must be
            'train' or 'val'.
        transform: Forwarded to the dataset as its ``transform`` argument.
        target_transform: Forwarded to the dataset as its
            ``target_transform`` argument.
        download: Forwarded to the CIFAR, MNIST and STL10 constructors.
            Not used for ImageNet, which is read from an existing folder.

    Returns:
        The constructed torchvision dataset object.

    Raises:
        ValueError: If ``name`` is not one of the supported keys.
        KeyError: If ``name`` is 'imagenet' and ``split`` is neither
            'train' nor 'val'.
    """
    # CIFAR and MNIST take a boolean flag rather than a split name.
    train = (split == 'train')
    if name == 'cifar10':
        return datasets.CIFAR10(root=_dataset_path['cifar10'],
                                train=train,
                                transform=transform,
                                target_transform=target_transform,
                                download=download)
    elif name == 'cifar100':
        return datasets.CIFAR100(root=_dataset_path['cifar100'],
                                 train=train,
                                 transform=transform,
                                 target_transform=target_transform,
                                 download=download)
    elif name == 'mnist':
        return datasets.MNIST(root=_dataset_path['mnist'],
                              train=train,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'stl10':
        # STL10 has splits beyond train/test (e.g. 'unlabeled'), so the
        # caller's split string is passed through instead of the boolean.
        # NOTE(review): torchvision presumably rejects unknown split names
        # itself -- confirm against the installed version.
        return datasets.STL10(root=_dataset_path['stl10'],
                              split=split,
                              transform=transform,
                              target_transform=target_transform,
                              download=download)
    elif name == 'imagenet':
        path = _dataset_path[name][split]
        return datasets.ImageFolder(root=path,
                                    transform=transform,
                                    target_transform=target_transform)
    # Fail loudly instead of implicitly returning None for an unknown key.
    raise ValueError(
        "Unsupported dataset %r; expected one of: 'cifar10', 'cifar100', "
        "'mnist', 'stl10', 'imagenet'" % (name,))
|