1
1
import os
2
2
import sys
3
3
import six
4
- import PIL
5
4
import lmdb
5
+
6
+ import PIL
6
7
from PIL import Image
7
8
8
9
import torch
15
16
_STD_IMAGENET = torch .tensor ([0.229 , 0.224 , 0.225 ])
16
17
17
18
18
- def get_dataloader (opt , dataset , batch_size , shuffle = False , mode = "label" ):
19
+ def get_dataloader (args , dataset , batch_size , shuffle = False , mode = "label" ):
19
20
"""
20
21
Get dataloader for each dataset
21
22
22
23
Parameters
23
24
----------
24
- opt : argparse.ArgumentParser().parse_args()
25
+ args : argparse.ArgumentParser().parse_args()
25
26
dataset: torch.utils.data.Dataset
26
27
batch_size: int
27
28
shuffle: boolean
@@ -32,23 +33,23 @@ def get_dataloader(opt, dataset, batch_size, shuffle = False, mode = "label"):
32
33
"""
33
34
34
35
if mode == "raw" :
35
- myAlignCollate = AlignCollateRaw (opt )
36
+ myAlignCollate = AlignCollateRaw (args )
36
37
else :
37
- myAlignCollate = AlignCollate (opt , mode )
38
+ myAlignCollate = AlignCollate (args , mode )
38
39
39
40
data_loader = DataLoader (
40
41
dataset ,
41
42
batch_size = batch_size ,
42
43
shuffle = shuffle ,
43
- num_workers = opt .workers ,
44
+ num_workers = args .workers ,
44
45
collate_fn = myAlignCollate ,
45
46
pin_memory = False ,
46
47
drop_last = False ,
47
48
)
48
49
return data_loader
49
50
50
51
51
- def hierarchical_dataset (root , opt , mode = "label" , drop_data = []):
52
+ def hierarchical_dataset (root , args , mode = "label" , drop_data = []):
52
53
""" select_data='/' contains all sub-directory of root directory """
53
54
dataset_list = []
54
55
dataset_log = f"dataset_root: { root } \t dataset:"
@@ -72,10 +73,10 @@ def hierarchical_dataset(root, opt, mode="label", drop_data=[]):
72
73
for dirpath in listdir :
73
74
if mode == "raw" :
74
75
# load data without label
75
- dataset = LmdbDataset_raw (dirpath , opt )
76
+ dataset = LmdbDataset_raw (dirpath , args )
76
77
else :
77
78
# load data with label
78
- dataset = LmdbDataset (dirpath , opt )
79
+ dataset = LmdbDataset (dirpath , args )
79
80
sub_dataset_log = f"sub-directory:\t /{ os .path .relpath (dirpath , root )} \t num samples: { len (dataset )} "
80
81
print (sub_dataset_log )
81
82
dataset_log += f"{ sub_dataset_log } \n "
@@ -113,15 +114,15 @@ def __getitem__(self, index):
113
114
114
115
class AlignCollate (object ):
115
116
""" Transform data to the same format """
116
- def __init__ (self , opt , mode = "label" ):
117
- self .opt = opt
117
+ def __init__ (self , args , mode = "label" ):
118
+ self .args = args
118
119
# resize image
119
120
if (mode == "adapt" or mode == "supervised" ):
120
121
self .transform = Rand_augment ()
121
122
else :
122
123
self .transform = torchvision .transforms .Compose ([])
123
124
124
- self .resize = ResizeNormalize (opt )
125
+ self .resize = ResizeNormalize (args )
125
126
print ("Use Text_augment" , self .transform )
126
127
127
128
def __call__ (self , batch ):
@@ -135,10 +136,10 @@ def __call__(self, batch):
135
136
136
137
class AlignCollateRaw (object ):
137
138
""" Transform data to the same format """
138
- def __init__ (self , opt ):
139
- self .opt = opt
139
+ def __init__ (self , args ):
140
+ self .args = args
140
141
# resize image
141
- self .transform = ResizeNormalize (opt )
142
+ self .transform = ResizeNormalize (args )
142
143
143
144
def __call__ (self , batch ):
144
145
images = batch
@@ -151,20 +152,20 @@ def __call__(self, batch):
151
152
152
153
class AlignCollateHDGE (object ):
153
154
""" Transform data to the same format """
154
- def __init__ (self , opt , infer = False ):
155
- self .opt = opt
155
+ def __init__ (self , args , infer = False ):
156
+ self .args = args
156
157
157
158
# for transforming the input image
158
159
if infer == False :
159
160
transform = torchvision .transforms .Compose (
160
161
[torchvision .transforms .RandomHorizontalFlip (),
161
- torchvision .transforms .Resize ((opt .load_height ,opt .load_width )),
162
- torchvision .transforms .RandomCrop ((opt .crop_height ,opt .crop_width )),
162
+ torchvision .transforms .Resize ((args .load_height ,args .load_width )),
163
+ torchvision .transforms .RandomCrop ((args .crop_height ,args .crop_width )),
163
164
torchvision .transforms .ToTensor (),
164
165
torchvision .transforms .Normalize (mean = [0.5 , 0.5 , 0.5 ], std = [0.5 , 0.5 , 0.5 ])])
165
166
else :
166
167
transform = torchvision .transforms .Compose (
167
- [torchvision .transforms .Resize ((opt .crop_height ,opt .crop_width )),
168
+ [torchvision .transforms .Resize ((args .crop_height ,args .crop_width )),
168
169
torchvision .transforms .ToTensor (),
169
170
torchvision .transforms .Normalize (mean = [0.5 , 0.5 , 0.5 ], std = [0.5 , 0.5 , 0.5 ])])
170
171
@@ -181,10 +182,10 @@ def __call__(self, batch):
181
182
182
183
class LmdbDataset (Dataset ):
183
184
""" Load data from Lmdb file with label """
184
- def __init__ (self , root , opt ):
185
+ def __init__ (self , root , args ):
185
186
186
187
self .root = root
187
- self .opt = opt
188
+ self .args = args
188
189
self .env = lmdb .open (
189
190
root ,
190
191
max_readers = 32 ,
@@ -207,7 +208,7 @@ def __init__(self, root, opt):
207
208
208
209
# length filtering
209
210
length_of_label = len (label )
210
- if length_of_label > opt .batch_max_length :
211
+ if length_of_label > args .batch_max_length :
211
212
continue
212
213
213
214
self .filtered_index_list .append (index )
@@ -236,18 +237,18 @@ def __getitem__(self, index):
236
237
except IOError :
237
238
print (f"Corrupted image for { index } " )
238
239
# make dummy image and dummy label for corrupted image.
239
- img = PIL .Image .new ("RGB" , (self .opt .imgW , self .opt .imgH ))
240
+ img = PIL .Image .new ("RGB" , (self .args .imgW , self .args .imgH ))
240
241
label = "[dummy_label]"
241
242
242
243
return (img , label )
243
244
244
245
245
246
class LmdbDataset_raw (Dataset ):
246
247
""" Load data from Lmdb file without label """
247
- def __init__ (self , root , opt ):
248
+ def __init__ (self , root , args ):
248
249
249
250
self .root = root
250
- self .opt = opt
251
+ self .args = args
251
252
self .env = lmdb .open (
252
253
root ,
253
254
max_readers = 32 ,
@@ -284,27 +285,21 @@ def __getitem__(self, index):
284
285
except IOError :
285
286
print (f"Corrupted image for { img_key } " )
286
287
# make dummy image for corrupted image.
287
- img = PIL .Image .new ("RGB" , (self .opt .imgW , self .opt .imgH ))
288
+ img = PIL .Image .new ("RGB" , (self .args .imgW , self .args .imgH ))
288
289
289
290
return img
290
291
291
292
292
293
class ResizeNormalize (object ):
293
294
294
- def __init__ (self , opt ):
295
- self .opt = opt
295
+ def __init__ (self , args ):
296
+ self .args = args
296
297
_transforms = []
297
298
298
- _transforms .append (
299
- torchvision .transforms .Resize ((self .opt .imgH , self .opt .imgW ),
300
- interpolation = torchvision .transforms .InterpolationMode .BICUBIC ))
299
+ _transforms .append (torchvision .transforms .Resize ((self .args .imgH , self .args .imgW ),
300
+ interpolation = torchvision .transforms .InterpolationMode .BICUBIC ))
301
301
_transforms .append (torchvision .transforms .ToTensor ())
302
- if self .opt .use_IMAGENET_norm :
303
- _transforms .append (torchvision .transforms .Normalize (mean = _MEAN_IMAGENET ,
304
- std = _STD_IMAGENET ))
305
- else :
306
- _transforms .append (torchvision .transforms .Normalize (mean = [0.5 , 0.5 , 0.5 ],
307
- std = [0.5 , 0.5 , 0.5 ]))
302
+ _transforms .append (torchvision .transforms .Normalize (mean = _MEAN_IMAGENET , std = _STD_IMAGENET ))
308
303
self ._transforms = torchvision .transforms .Compose (_transforms )
309
304
310
305
def __call__ (self , image ):
0 commit comments