import torch
|
|
import torchvision.transforms as transforms
|
|
import random
|
|
|
|
# Per-channel mean/std of the ImageNet training set; used as the default
# normalization statistics throughout this module.
__imagenet_stats = {'mean': [0.485, 0.456, 0.406],
                    'std': [0.229, 0.224, 0.225]}

# PCA eigenvalues and eigenvectors of ImageNet RGB pixel values, consumed by
# the Lighting transform below for AlexNet-style PCA color noise.
__imagenet_pca = {
    'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]),
    'eigvec': torch.Tensor([
        [-0.5675, 0.7192, 0.4009],
        [-0.5808, -0.0045, -0.8140],
        [-0.5836, -0.6948, 0.4203],
    ])
}
|
|
|
|
|
|
def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats):
    """Deterministic eval-time pipeline: optional resize, center crop,
    tensor conversion, channel-wise normalization.

    Args:
        input_size: target size of the final center crop.
        scale_size: size to resize to before cropping; skipped when None
            or equal to input_size.
        normalize: dict with 'mean' and 'std' keys for transforms.Normalize.

    Returns:
        A torchvision Compose pipeline.
    """
    t_list = [
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ]
    # Guard against scale_size=None: the previous unconditional comparison
    # prepended transforms.Scale(None), which fails later at call time.
    if scale_size is not None and scale_size != input_size:
        t_list = [transforms.Scale(scale_size)] + t_list

    return transforms.Compose(t_list)
|
|
|
|
|
|
def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
    """Training-time pipeline: optional resize, random crop, tensor
    conversion, channel-wise normalization.

    Args:
        input_size: target size of the random crop.
        scale_size: size to resize to before cropping; skipped when None
            or equal to input_size.
        normalize: dict with 'mean' and 'std' keys for transforms.Normalize.

    Returns:
        A torchvision Compose pipeline.
    """
    t_list = [
        transforms.RandomCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ]
    # Guard against scale_size=None (see scale_crop for the same fix).
    if scale_size is not None and scale_size != input_size:
        t_list = [transforms.Scale(scale_size)] + t_list

    # BUG FIX: the Compose was constructed but never returned, so this
    # function silently returned None.
    return transforms.Compose(t_list)
|
|
|
|
|
|
def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats):
    """Pad-then-random-crop training pipeline with horizontal flipping.

    The image is padded so a random input_size crop effectively samples
    from a scale_size canvas (standard CIFAR-style augmentation).

    Args:
        input_size: target size of the random crop.
        scale_size: virtual canvas size; determines the padding amount.
            NOTE(review): must not be None — the subtraction below assumes
            an int; callers in this file always supply it.
        normalize: dict with 'mean' and 'std' keys for transforms.Normalize.

    Returns:
        A torchvision Compose pipeline.
    """
    pad = int((scale_size - input_size) / 2)
    steps = [
        transforms.RandomCrop(input_size, padding=pad),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(**normalize),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
|
def inception_preproccess(input_size, normalize=__imagenet_stats):
    """Inception-style training augmentation: random sized crop plus
    horizontal flip, then tensor conversion and normalization.

    Args:
        input_size: output size of the random sized crop.
        normalize: dict with 'mean' and 'std' keys for transforms.Normalize.

    Returns:
        A torchvision Compose pipeline.
    """
    steps = [transforms.RandomSizedCrop(input_size)]
    steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(**normalize))
    return transforms.Compose(steps)
|
|
def inception_color_preproccess(input_size, normalize=__imagenet_stats):
    """Inception-style augmentation plus color jitter and PCA lighting noise.

    Same as inception_preproccess, with random brightness/contrast/
    saturation jitter and AlexNet-style PCA lighting noise applied to the
    tensor before normalization.

    Args:
        input_size: output size of the random sized crop.
        normalize: dict with 'mean' and 'std' keys for transforms.Normalize.

    Returns:
        A torchvision Compose pipeline.
    """
    jitter = ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
    )
    lighting = Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec'])
    return transforms.Compose([
        transforms.RandomSizedCrop(input_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        jitter,
        lighting,
        transforms.Normalize(**normalize)
    ])
|
|
|
|
|
|
def get_transform(name='imagenet', input_size=None,
                  scale_size=None, normalize=None, augment=True):
    """Build a preprocessing pipeline for a named dataset.

    Args:
        name: 'imagenet', any name containing 'cifar', or 'mnist'.
        input_size: final crop size; dataset-specific default when None.
        scale_size: pre-crop resize/pad size; dataset-specific default
            when None.
        normalize: dict with 'mean' and 'std' keys; defaults to ImageNet
            statistics. MNIST always overrides this with its own 0.5/0.5
            single-channel stats.
        augment: True for training-time random augmentation, False for
            deterministic eval-time preprocessing.

    Returns:
        A torchvision Compose pipeline.

    Raises:
        ValueError: if name matches no known dataset.
    """
    normalize = normalize or __imagenet_stats
    if name == 'imagenet':
        scale_size = scale_size or 256
        input_size = input_size or 224
        if augment:
            return inception_preproccess(input_size, normalize=normalize)
        return scale_crop(input_size=input_size,
                          scale_size=scale_size, normalize=normalize)
    elif 'cifar' in name:
        input_size = input_size or 32
        if augment:
            scale_size = scale_size or 40
            return pad_random_crop(input_size, scale_size=scale_size,
                                   normalize=normalize)
        scale_size = scale_size or 32
        return scale_crop(input_size=input_size,
                          scale_size=scale_size, normalize=normalize)
    elif name == 'mnist':
        # Single-channel dataset: force its own stats regardless of caller.
        normalize = {'mean': [0.5], 'std': [0.5]}
        input_size = input_size or 28
        if augment:
            scale_size = scale_size or 32
            return pad_random_crop(input_size, scale_size=scale_size,
                                   normalize=normalize)
        scale_size = scale_size or 32
        return scale_crop(input_size=input_size,
                          scale_size=scale_size, normalize=normalize)
    # Previously an unknown name fell through and silently returned None;
    # fail loudly instead so misconfiguration is caught at setup time.
    raise ValueError('unknown transform name: {}'.format(name))
|
|
|
|
|
|
class Lighting(object):
    """AlexNet-style PCA lighting noise.

    Adds a random linear combination of the RGB principal components to
    every pixel: each eigenvector is weighted by a gaussian coefficient
    and its eigenvalue, and the summed 3-vector offset is broadcast over
    the whole image tensor.
    """

    def __init__(self, alphastd, eigval, eigvec):
        # Standard deviation of the per-component gaussian coefficients;
        # zero disables the augmentation entirely.
        self.alphastd = alphastd
        self.eigval = eigval
        self.eigvec = eigvec

    def __call__(self, img):
        # Disabled: hand the image back untouched.
        if self.alphastd == 0:
            return img

        # One gaussian coefficient per principal component.
        alpha = img.new().resize_(3).normal_(0, self.alphastd)

        # Weight each eigenvector column by its coefficient and its
        # eigenvalue, then collapse the columns into one RGB offset.
        rgb = self.eigvec.type_as(img).clone()
        rgb = rgb.mul(alpha.view(1, 3).expand(3, 3))
        rgb = rgb.mul(self.eigval.view(1, 3).expand(3, 3))
        rgb = rgb.sum(1).squeeze()

        return img.add(rgb.view(3, 1, 1).expand_as(img))
|
|
|
|
|
|
class Grayscale(object):
    """Convert an RGB tensor image to 3-channel grayscale (ITU-R 601 luma)."""

    def __call__(self, img):
        """Return a new tensor where every channel holds the luma of img."""
        gs = img.clone()
        # Luma = 0.299 R + 0.587 G + 0.114 B, accumulated in channel 0.
        # FIX: the positional-scalar form add_(0.587, gs[1]) is the
        # deprecated (and since removed) Tensor.add_ overload; pass the
        # scaling factor via the alpha keyword instead — identical math.
        gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)
        gs[1].copy_(gs[0])
        gs[2].copy_(gs[0])
        return gs
|
|
|
|
|
|
class Saturation(object):
    """Randomly desaturate an image by blending it toward its grayscale."""

    def __init__(self, var):
        # Upper bound of the uniform random blend factor.
        self.var = var

    def __call__(self, img):
        desaturated = Grayscale()(img)
        amount = random.uniform(0, self.var)
        return img.lerp(desaturated, amount)
|
|
|
|
|
|
class Brightness(object):
    """Randomly darken an image by blending it toward all-black."""

    def __init__(self, var):
        # Upper bound of the uniform random blend factor.
        self.var = var

    def __call__(self, img):
        black = img.new().resize_as_(img).zero_()
        amount = random.uniform(0, self.var)
        return img.lerp(black, amount)
|
|
|
|
|
|
class Contrast(object):
    """Randomly reduce contrast by blending toward the mean gray level."""

    def __init__(self, var):
        # Upper bound of the uniform random blend factor.
        self.var = var

    def __call__(self, img):
        mean_gray = Grayscale()(img)
        mean_gray.fill_(mean_gray.mean())
        amount = random.uniform(0, self.var)
        return img.lerp(mean_gray, amount)
|
|
|
|
|
|
class RandomOrder(object):
    """Apply a list of transforms in a freshly shuffled order on each call."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # A missing transform list means "do nothing".
        if self.transforms is None:
            return img
        for idx in torch.randperm(len(self.transforms)):
            img = self.transforms[idx](img)
        return img
|
|
|
|
|
|
class ColorJitter(RandomOrder):
    """Brightness/contrast/saturation jitter applied in random order.

    A strength of zero disables that particular jitter entirely; the
    remaining transforms are shuffled per call by RandomOrder.
    """

    def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
        ops = []
        if brightness != 0:
            ops.append(Brightness(brightness))
        if contrast != 0:
            ops.append(Contrast(contrast))
        if saturation != 0:
            ops.append(Saturation(saturation))
        self.transforms = ops
|