-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathaugment.py
executable file
·123 lines (99 loc) · 3.36 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
"""
3Augment implementation
Data augmentation (DA) based on the DINO DA (https://github.com/facebookresearch/dino)
and the timm DA (https://github.com/rwightman/pytorch-image-models)
"""
import torch
from torchvision import transforms
from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor
import numpy as np
from torchvision import datasets, transforms
import random
from PIL import ImageFilter, ImageOps
import torchvision.transforms.functional as TF
class GaussianBlur(object):
    """Randomly apply a Gaussian blur to a PIL image.

    Args:
        p: Threshold compared against a ``random.random()`` draw; the blur
            is applied when the draw is less than or equal to ``p``.
        radius_min: Lower bound of the blur radius.
        radius_max: Upper bound of the blur radius.

    Calling the instance returns either the input image unchanged or the
    result of ``img.filter(ImageFilter.GaussianBlur(radius=...))`` with the
    radius drawn by ``random.uniform(radius_min, radius_max)``.
    """

    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.radius_min, self.radius_max = radius_min, radius_max
        self.prob = p

    def __call__(self, img):
        # Inclusive comparison: a draw exactly equal to ``prob`` still blurs.
        if random.random() <= self.prob:
            blur_radius = random.uniform(self.radius_min, self.radius_max)
            return img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
        return img
class Solarization(object):
    """Randomly solarize a PIL image.

    Args:
        p: Threshold compared against a ``random.random()`` draw; the image
            is passed through ``ImageOps.solarize`` when the draw is
            strictly less than ``p``, otherwise it is returned unchanged.
    """

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        solarize_now = random.random() < self.p
        return ImageOps.solarize(img) if solarize_now else img
class gray_scale(object):
    """Randomly convert an image to grayscale.

    Args:
        p: Threshold compared against a ``random.random()`` draw; the stored
            ``transforms.Grayscale(3)`` transform is applied when the draw
            is strictly less than ``p``, otherwise the image is returned
            unchanged.
    """

    def __init__(self, p=0.2):
        # The transform is built once and reused for every call.
        self.transf = transforms.Grayscale(3)
        self.p = p

    def __call__(self, img):
        convert_now = random.random() < self.p
        return self.transf(img) if convert_now else img
class horizontal_flip(object):
    """Randomly flip an image horizontally.

    Args:
        p: Threshold compared against a ``random.random()`` draw; the image
            is flipped when the draw is strictly less than ``p``, otherwise
            it is returned unchanged.
        activate_pred: Accepted but not used anywhere in this class.
    """

    def __init__(self, p=0.2,activate_pred=False):
        # The wrapped transform always flips (p=1.0); the random decision is
        # made in __call__ using this object's own ``p``.
        self.transf = transforms.RandomHorizontalFlip(p=1.0)
        self.p = p

    def __call__(self, img):
        flip_now = random.random() < self.p
        return self.transf(img) if flip_now else img
def new_data_aug_generator(args = None):
    """Build the augmentation pipeline described by ``args``.

    Reads ``args.input_size``, ``args.src`` and ``args.color_jitter``.
    Although ``args`` defaults to ``None``, a real object is required:
    the first attribute access fails with ``AttributeError`` otherwise.

    The returned ``transforms.Compose`` applies, in order:

    1. Cropping and a random horizontal flip. When ``args.src`` is truthy
       this is ``Resize`` followed by ``RandomCrop`` with 4 pixels of
       reflect padding; otherwise it is
       ``RandomResizedCropAndInterpolation`` with scale ``(0.08, 1.0)`` and
       bicubic interpolation.
    2. Exactly one of ``gray_scale``, ``Solarization`` or ``GaussianBlur``
       (each built with ``p=1.0``), selected by ``RandomChoice``.
    3. ``ColorJitter`` using ``args.color_jitter`` for its first three
       arguments, only when that value is neither ``None`` nor ``0``.
    4. ``ToTensor`` followed by ``Normalize`` with fixed per-channel
       mean/std constants.

    Returns:
        A ``transforms.Compose`` over the stages above.
    """
    img_size = args.input_size
    use_simple_crop = args.src
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    if use_simple_crop:
        pipeline = [
            # Integer 3 is PIL's code for bicubic resampling.
            transforms.Resize(img_size, interpolation=3),
            transforms.RandomCrop(img_size, padding=4, padding_mode='reflect'),
        ]
    else:
        pipeline = [
            RandomResizedCropAndInterpolation(
                img_size, scale=(0.08, 1.0), interpolation='bicubic'),
        ]
    pipeline.append(transforms.RandomHorizontalFlip())

    # Each candidate uses p=1.0 so the chosen one is always applied.
    pipeline.append(
        transforms.RandomChoice([
            gray_scale(p=1.0),
            Solarization(p=1.0),
            GaussianBlur(p=1.0),
        ])
    )

    jitter = args.color_jitter
    if not (jitter is None or jitter == 0):
        pipeline.append(transforms.ColorJitter(jitter, jitter, jitter))

    pipeline.append(transforms.ToTensor())
    pipeline.append(
        transforms.Normalize(
            mean=torch.tensor(channel_mean),
            std=torch.tensor(channel_std),
        )
    )
    return transforms.Compose(pipeline)