You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

37 lines
1.5 KiB

  1. import os
  2. import torchvision.datasets as datasets
  3. import torchvision.transforms as transforms
  4. _DATASETS_MAIN_PATH = '/home/Datasets'
  5. _dataset_path = {
  6. 'cifar10': os.path.join(_DATASETS_MAIN_PATH, 'CIFAR10'),
  7. 'cifar100': os.path.join(_DATASETS_MAIN_PATH, 'CIFAR100'),
  8. 'stl10': os.path.join(_DATASETS_MAIN_PATH, 'STL10'),
  9. 'mnist': os.path.join(_DATASETS_MAIN_PATH, 'MNIST'),
  10. 'imagenet': {
  11. 'train': os.path.join(_DATASETS_MAIN_PATH, 'ImageNet/train'),
  12. 'val': os.path.join(_DATASETS_MAIN_PATH, 'ImageNet/val')
  13. }
  14. }
  15. def get_dataset(name, split='train', transform=None,
  16. target_transform=None, download=True):
  17. train = (split == 'train')
  18. if name == 'cifar10':
  19. return datasets.CIFAR10(root=_dataset_path['cifar10'],
  20. train=train,
  21. transform=transform,
  22. target_transform=target_transform,
  23. download=download)
  24. elif name == 'cifar100':
  25. return datasets.CIFAR100(root=_dataset_path['cifar100'],
  26. train=train,
  27. transform=transform,
  28. target_transform=target_transform,
  29. download=download)
  30. elif name == 'imagenet':
  31. path = _dataset_path[name][split]
  32. return datasets.ImageFolder(root=path,
  33. transform=transform,
  34. target_transform=target_transform)