-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcommon.py
55 lines (47 loc) · 1.75 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from torchvision import transforms
def get_data_transformation(keep_image_ratio, downsample_image):
    """Build the train and test torchvision transform pipelines.

    The train pipeline always applies random rotation (±45°) and, at the end,
    random horizontal/vertical flips; both pipelines finish with ToTensor and
    ImageNet-statistics normalization.

    Args:
        keep_image_ratio: If True, preserve the original aspect ratio and use
            a fixed-size center crop (optionally after 2x downsampling).
            If False, resize the short side to 256 and crop to 224x224
            (random crop for train, center crop for test).
        downsample_image: Only used when ``keep_image_ratio`` is True; halves
            both image dimensions before cropping.

    Returns:
        Tuple ``(train_transform, test_transform)`` of ``transforms.Compose``
        pipelines.
    """
    train_transform = transforms.Compose([
        transforms.RandomRotation(45)
    ])
    test_transform = transforms.Compose([])
    if keep_image_ratio:
        if downsample_image:
            import PIL

            def downsampling(img):
                # PIL's resize requires integer dimensions; use floor
                # division (the original float division `/ 2` produced a
                # float tuple, which raises TypeError in Pillow).
                new_size = (img.size[0] // 2, img.size[1] // 2)
                return img.resize(size=new_size, resample=PIL.Image.BILINEAR)
            # Crop size matches the downsampled (half-resolution) images.
            train_transform.transforms.extend([
                transforms.Lambda(downsampling),
                transforms.CenterCrop(size=(400, 250)),
            ])
            test_transform.transforms.extend([
                transforms.Lambda(downsampling),
                transforms.CenterCrop(size=(400, 250))
            ])
        else:
            # Full-resolution crop sizes (roughly 2x the downsampled ones).
            train_transform.transforms.extend([
                transforms.CenterCrop(size=(850, 550)),
            ])
            test_transform.transforms.extend([
                transforms.CenterCrop(size=(850, 550))
            ])
    else:
        # Standard ImageNet-style preprocessing: short side to 256, then a
        # 224x224 crop — random for train (augmentation), centered for test.
        train_transform.transforms.extend([
            transforms.Resize(size=256),
            transforms.RandomCrop(size=224)
        ])
        test_transform.transforms.extend([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224)
        ])
    # Flip augmentation is train-only; applied after the geometric crops.
    train_transform.transforms.extend([
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip()
    ])
    # Both pipelines end with tensor conversion and ImageNet normalization.
    for transform in [train_transform, test_transform]:
        transform.transforms.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    return train_transform, test_transform