You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

198 lines
5.8 KiB

  1. import torch
  2. import torchvision.transforms as transforms
  3. import random
# Per-channel RGB statistics used as the default normalization by every
# transform factory below; values are the standard ImageNet training-set stats.
__imagenet_stats = {'mean': [0.485, 0.456, 0.406],
                    'std': [0.229, 0.224, 0.225]}
# PCA decomposition of RGB pixel values (eigenvalues + eigenvectors),
# consumed by Lighting for AlexNet-style color-noise augmentation.
__imagenet_pca = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}
  14. def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
  15. t_list = [
  16. transforms.CenterCrop(input_size),
  17. transforms.ToTensor(),
  18. transforms.Normalize(**normalize),
  19. ]
  20. if scale_size != input_size:
  21. t_list = [transforms.Scale(scale_size)] + t_list
  22. return transforms.Compose(t_list)
  23. def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
  24. t_list = [
  25. transforms.RandomCrop(input_size),
  26. transforms.ToTensor(),
  27. transforms.Normalize(**normalize),
  28. ]
  29. if scale_size != input_size:
  30. t_list = [transforms.Scale(scale_size)] + t_list
  31. transforms.Compose(t_list)
  32. def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
  33. padding = int((scale_size - input_size) / 2)
  34. return transforms.Compose([
  35. transforms.RandomCrop(input_size, padding=padding),
  36. transforms.RandomHorizontalFlip(),
  37. transforms.ToTensor(),
  38. transforms.Normalize(**normalize),
  39. ])
  40. def inception_preproccess(input_size, normalize=__imagenet_stats):
  41. return transforms.Compose([
  42. transforms.RandomSizedCrop(input_size),
  43. transforms.RandomHorizontalFlip(),
  44. transforms.ToTensor(),
  45. transforms.Normalize(**normalize)
  46. ])
  47. def inception_color_preproccess(input_size, normalize=__imagenet_stats):
  48. return transforms.Compose([
  49. transforms.RandomSizedCrop(input_size),
  50. transforms.RandomHorizontalFlip(),
  51. transforms.ToTensor(),
  52. ColorJitter(
  53. brightness=0.4,
  54. contrast=0.4,
  55. saturation=0.4,
  56. ),
  57. Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
  58. transforms.Normalize(**normalize)
  59. ])
  60. def get_transform(name='imagenet', input_size=None,
  61. scale_size=None, normalize=None, augment=True):
  62. normalize = normalize or __imagenet_stats
  63. if name == 'imagenet':
  64. scale_size = scale_size or 256
  65. input_size = input_size or 224
  66. if augment:
  67. return inception_preproccess(input_size, normalize=normalize)
  68. else:
  69. return scale_crop(input_size=input_size,
  70. scale_size=scale_size, normalize=normalize)
  71. elif 'cifar' in name:
  72. input_size = input_size or 32
  73. if augment:
  74. scale_size = scale_size or 40
  75. return pad_random_crop(input_size, scale_size=scale_size,
  76. normalize=normalize)
  77. else:
  78. scale_size = scale_size or 32
  79. return scale_crop(input_size=input_size,
  80. scale_size=scale_size, normalize=normalize)
  81. elif name == 'mnist':
  82. normalize = {'mean': [0.5], 'std': [0.5]}
  83. input_size = input_size or 28
  84. if augment:
  85. scale_size = scale_size or 32
  86. return pad_random_crop(input_size, scale_size=scale_size,
  87. normalize=normalize)
  88. else:
  89. scale_size = scale_size or 32
  90. return scale_crop(input_size=input_size,
  91. scale_size=scale_size, normalize=normalize)
  92. class Lighting(object):
  93. """Lighting noise(AlexNet - style PCA - based noise)"""
  94. def __init__(self, alphastd, eigval, eigvec):
  95. self.alphastd = alphastd
  96. self.eigval = eigval
  97. self.eigvec = eigvec
  98. def __call__(self, img):
  99. if self.alphastd == 0:
  100. return img
  101. alpha = img.new().resize_(3).normal_(0, self.alphastd)
  102. rgb = self.eigvec.type_as(img).clone()\
  103. .mul(alpha.view(1, 3).expand(3, 3))\
  104. .mul(self.eigval.view(1, 3).expand(3, 3))\
  105. .sum(1).squeeze()
  106. return img.add(rgb.view(3, 1, 1).expand_as(img))
  107. class Grayscale(object):
  108. def __call__(self, img):
  109. gs = img.clone()
  110. gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2])
  111. gs[1].copy_(gs[0])
  112. gs[2].copy_(gs[0])
  113. return gs
  114. class Saturation(object):
  115. def __init__(self, var):
  116. self.var = var
  117. def __call__(self, img):
  118. gs = Grayscale()(img)
  119. alpha = random.uniform(0, self.var)
  120. return img.lerp(gs, alpha)
  121. class Brightness(object):
  122. def __init__(self, var):
  123. self.var = var
  124. def __call__(self, img):
  125. gs = img.new().resize_as_(img).zero_()
  126. alpha = random.uniform(0, self.var)
  127. return img.lerp(gs, alpha)
  128. class Contrast(object):
  129. def __init__(self, var):
  130. self.var = var
  131. def __call__(self, img):
  132. gs = Grayscale()(img)
  133. gs.fill_(gs.mean())
  134. alpha = random.uniform(0, self.var)
  135. return img.lerp(gs, alpha)
  136. class RandomOrder(object):
  137. """ Composes several transforms together in random order.
  138. """
  139. def __init__(self, transforms):
  140. self.transforms = transforms
  141. def __call__(self, img):
  142. if self.transforms is None:
  143. return img
  144. order = torch.randperm(len(self.transforms))
  145. for i in order:
  146. img = self.transforms[i](img)
  147. return img
  148. class ColorJitter(RandomOrder):
  149. def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
  150. self.transforms = []
  151. if brightness != 0:
  152. self.transforms.append(Brightness(brightness))
  153. if contrast != 0:
  154. self.transforms.append(Contrast(contrast))
  155. if saturation != 0:
  156. self.transforms.append(Saturation(saturation))