-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata_loader.py
164 lines (134 loc) · 5.43 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import pdb
import random
import torch
from torchvision import transforms, datasets
from PIL import ImageFilter
from util import subset_classes
# Extended version of ImageFolder to return index of image too.
class ImageFolderEx(datasets.ImageFolder):
    """``datasets.ImageFolder`` whose items are prefixed with their index.

    ``__getitem__`` returns ``(index, sample, target)``, or
    ``(index, sample, target, is_unsup)`` when ``sup_split_file`` is given.

    Args:
        root: image root directory, forwarded to ``datasets.ImageFolder``.
        transforms: transform applied to each loaded image (forwarded as
            ImageFolder's positional ``transform`` argument).
        sup_split_file: optional text file with one image file name per line
            marking the labelled images of a semi-supervised split. When
            given, ``self.is_unsup`` is an int tensor holding 0 for listed
            images and 1 for all others, aligned with ``self.samples``.
        only_sup: when True (and ``sup_split_file`` is given), keep only the
            listed images.
        corrupt_split_file: optional text file of ``<relative path> <class>``
            lines that replaces the scanned sample list entirely; paths are
            joined onto ``root``.
    """

    def __init__(self, root, transforms, sup_split_file=None, only_sup=False,
                 corrupt_split_file=None):
        super(ImageFolderEx, self).__init__(root, transforms)
        # Per-sample flags (0 = listed in sup_split_file, 1 = not listed);
        # stays None when no supervised split is configured.
        self.is_unsup = None

        # Replace the sample list first so the supervised/unsupervised mask
        # below is built over the samples __getitem__ actually serves.
        if corrupt_split_file is not None:
            self._set_samples(self._read_corrupt_split(root, corrupt_split_file))

        if sup_split_file is not None:
            with open(sup_split_file, 'r') as f:
                sup_set = {line.strip() for line in f}
            # Match on file name only; os.path.basename respects the
            # platform's path separator.
            flags = [0 if os.path.basename(path) in sup_set else 1
                     for path, _ in self.samples]
            if only_sup:
                # Filter samples and mask together so is_unsup[index] stays
                # aligned with self.samples[index].
                self._set_samples([sample for sample, flag in zip(self.samples, flags)
                                   if flag == 0])
                flags = [0] * len(self.samples)
            self.is_unsup = torch.tensor(flags, dtype=torch.int)

    @staticmethod
    def _read_corrupt_split(root, split_file):
        """Parse ``<relative path> <class index>`` lines into (path, class) pairs."""
        samples = []
        with open(split_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                # Split on the last space so paths containing spaces survive.
                rel_path, cls = line.rsplit(' ', 1)
                samples.append((os.path.join(root, rel_path), int(cls)))
        return samples

    def _set_samples(self, samples):
        """Replace ``samples``, keeping ImageFolder's ``imgs``/``targets`` in sync."""
        self.samples = samples
        self.imgs = samples
        self.targets = [cls for _, cls in samples]

    def __getitem__(self, index):
        """Return ``(index, sample, target)``, plus ``is_unsup`` if configured."""
        sample, target = super(ImageFolderEx, self).__getitem__(index)
        if self.is_unsup is None:
            return index, sample, target
        return index, sample, target, self.is_unsup[index]
class TwoCropsTransform:
    """Produce two differently augmented views of a single image.

    Calling an instance returns ``[query_view, target_view]``: the query
    comes from ``strong_transform`` and the target from ``weak_transform``.
    Both pipelines are printed once when the object is constructed.
    """

    def __init__(self, weak_transform, strong_transform):
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform
        for pipeline in (self.weak_transform, self.strong_transform):
            print(pipeline)

    def __call__(self, x):
        # Strong view first, so random augmentations consume the RNG in the
        # same order as before.
        query_view = self.strong_transform(x)
        target_view = self.weak_transform(x)
        return [query_view, target_view]
class GaussianBlur:
    """Randomly blur an image, as in SimCLR (https://arxiv.org/abs/2002.05709).

    ``sigma`` is indexed as ``(low, high)``. Each call draws a blur radius
    with ``random.uniform(low, high)`` and applies
    ``ImageFilter.GaussianBlur`` through the image's ``filter`` method.
    """

    def __init__(self, sigma):
        self.sigma = sigma

    def __call__(self, x):
        low, high = self.sigma[0], self.sigma[1]
        radius = random.uniform(low, high)
        blur = ImageFilter.GaussianBlur(radius=radius)
        return x.filter(blur)
# Create train loader
def get_train_loader(opt):
    """Build the training DataLoader and, in semi-supervised runs, a loader
    over the labelled subset.

    Args:
        opt: options object read via attributes and ``vars()`` (e.g. an
            argparse ``Namespace``). Uses ``data``, ``weak_strong``,
            ``dataset``, ``batch_size`` and ``num_workers``; optionally
            ``sup_split_file`` and ``corrupt_split``.

    Returns:
        ``train_loader`` over ``<opt.data>/train`` (shuffled,
        ``drop_last=True``), whose transform yields two views per image via
        ``TwoCropsTransform``. If ``opt`` has a ``sup_split_file`` attribute,
        returns ``(train_loader, train_val_loader)`` instead, where
        ``train_val_loader`` is unshuffled and uses the weak augmentation.
    """
    traindir = os.path.join(opt.data, 'train')
    opt_vars = vars(opt)

    # Standard ImageNet per-channel mean/std (as used by torchvision models).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    augmentation_strong = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    augmentation_weak = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]

    # Presence of the attribute (not its value) marks a semi-supervised run
    # and selects the return shape below.
    has_sup_split = 'sup_split_file' in opt_vars
    sup_split_file = opt_vars.get('sup_split_file')

    # corrupt_split=None means "not set"; otherwise it would build the
    # nonexistent path 'subsets/corrupt_None.txt'. The 'subsets' directory is
    # resolved relative to the current working directory.
    corrupt_split = opt_vars.get('corrupt_split')
    if corrupt_split is not None:
        corrupt_split_file = os.path.join('subsets', 'corrupt_{}.txt'.format(corrupt_split))
    else:
        corrupt_split_file = None

    # TwoCropsTransform(weak, strong) yields [strong(x), weak(x)]: the first
    # view always uses the strong pipeline; with weak_strong the second view
    # uses the weak pipeline, otherwise the strong one as well.
    second_view = augmentation_weak if opt.weak_strong else augmentation_strong
    train_dataset = ImageFolderEx(
        root=traindir,
        transforms=TwoCropsTransform(transforms.Compose(second_view),
                                     transforms.Compose(augmentation_strong)),
        sup_split_file=sup_split_file,
        only_sup=False,
        corrupt_split_file=corrupt_split_file,
    )
    if opt.dataset == 'imagenet100':
        subset_classes(train_dataset, num_classes=100)
    print('==> train dataset')
    print(train_dataset)

    # NOTE(review): an earlier comment here read "remove drop_last", yet
    # drop_last=True is set (incomplete final batch is dropped); confirm intent.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.num_workers, pin_memory=True, drop_last=True)

    if not has_sup_split:
        return train_loader

    # Semi-supervised setup only: loader over the labelled images, used for
    # pseudo-labelling.
    # NOTE(review): if sup_split_file is present but None, ImageFolderEx
    # ignores only_sup and this loader covers the whole training set.
    sup_val_dataset = ImageFolderEx(
        root=traindir,
        sup_split_file=sup_split_file,
        only_sup=True,
        transforms=transforms.Compose(augmentation_weak),
    )
    if opt.dataset == 'imagenet100':
        subset_classes(sup_val_dataset, num_classes=100)
    train_val_loader = torch.utils.data.DataLoader(
        sup_val_dataset,
        batch_size=opt.batch_size, shuffle=False,
        num_workers=opt.num_workers, pin_memory=True,
    )
    return train_loader, train_val_loader