import torch import torchvision.transforms as transforms import random __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]} __imagenet_pca = { 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 'eigvec': torch.Tensor([ [-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203], ]) } def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats): t_list = [ transforms.CenterCrop(input_size), transforms.ToTensor(), transforms.Normalize(**normalize), ] if scale_size != input_size: t_list = [transforms.Scale(scale_size)] + t_list return transforms.Compose(t_list) def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): t_list = [ transforms.RandomCrop(input_size), transforms.ToTensor(), transforms.Normalize(**normalize), ] if scale_size != input_size: t_list = [transforms.Scale(scale_size)] + t_list transforms.Compose(t_list) def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): padding = int((scale_size - input_size) / 2) return transforms.Compose([ transforms.RandomCrop(input_size, padding=padding), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(**normalize), ]) def inception_preproccess(input_size, normalize=__imagenet_stats): return transforms.Compose([ transforms.RandomSizedCrop(input_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(**normalize) ]) def inception_color_preproccess(input_size, normalize=__imagenet_stats): return transforms.Compose([ transforms.RandomSizedCrop(input_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ColorJitter( brightness=0.4, contrast=0.4, saturation=0.4, ), Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']), transforms.Normalize(**normalize) ]) def get_transform(name='imagenet', input_size=None, scale_size=None, normalize=None, augment=True): normalize = normalize or __imagenet_stats if name == 'imagenet': scale_size = scale_size or 256 input_size = input_size or 224 if augment: return inception_preproccess(input_size, normalize=normalize) else: return scale_crop(input_size=input_size, scale_size=scale_size, normalize=normalize) elif 'cifar' in name: input_size = input_size or 32 if augment: scale_size = scale_size or 40 return pad_random_crop(input_size, scale_size=scale_size, normalize=normalize) else: scale_size = scale_size or 32 return scale_crop(input_size=input_size, scale_size=scale_size, normalize=normalize) elif name == 'mnist': normalize = {'mean': [0.5], 'std': [0.5]} input_size = input_size or 28 if augment: scale_size = scale_size or 32 return pad_random_crop(input_size, scale_size=scale_size, normalize=normalize) else: scale_size = scale_size or 32 return scale_crop(input_size=input_size, scale_size=scale_size, normalize=normalize) class Lighting(object): """Lighting noise(AlexNet - style PCA - based noise)""" def __init__(self, alphastd, eigval, eigvec): self.alphastd = alphastd self.eigval = eigval self.eigvec = eigvec def __call__(self, img): if self.alphastd == 0: return img alpha = img.new().resize_(3).normal_(0, self.alphastd) rgb = self.eigvec.type_as(img).clone()\ .mul(alpha.view(1, 3).expand(3, 3))\ .mul(self.eigval.view(1, 3).expand(3, 3))\ .sum(1).squeeze() return img.add(rgb.view(3, 1, 1).expand_as(img)) class Grayscale(object): def __call__(self, img): gs = img.clone() gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) gs[1].copy_(gs[0]) gs[2].copy_(gs[0]) return gs class Saturation(object): def __init__(self, var): self.var = var def __call__(self, img): gs = Grayscale()(img) alpha = random.uniform(0, self.var) return img.lerp(gs, alpha) class Brightness(object): def __init__(self, var): self.var = var def __call__(self, img): gs = img.new().resize_as_(img).zero_() alpha = random.uniform(0, self.var) return img.lerp(gs, alpha) class Contrast(object): def __init__(self, var): self.var = var def __call__(self, img): gs = Grayscale()(img) gs.fill_(gs.mean()) alpha = random.uniform(0, self.var) return img.lerp(gs, alpha) class RandomOrder(object): """ Composes several transforms together in random order. """ def __init__(self, transforms): self.transforms = transforms def __call__(self, img): if self.transforms is None: return img order = torch.randperm(len(self.transforms)) for i in order: img = self.transforms[i](img) return img class ColorJitter(RandomOrder): def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): self.transforms = [] if brightness != 0: self.transforms.append(Brightness(brightness)) if contrast != 0: self.transforms.append(Contrast(contrast)) if saturation != 0: self.transforms.append(Saturation(saturation))