import copy
import platform
from functools import partial
-
+ import torch
from torch.utils.data import DataLoader

from mmcv.utils import Registry
from mmcv.utils import build_from_cfg
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
+ from mmcv.parallel import DataContainer as DC

from mmdet.datasets import DATASETS
from mmdet.models.builder import build
from mmdet.datasets.builder import worker_init_fn
from mmdet.datasets.samplers import DistributedGroupSampler, GroupSampler, DistributedSampler
+ from mmdet.datasets.pipelines.formating import to_tensor

from .davar_dataset_wrappers import DavarConcatDataset
from .davar_multi_dataset import DavarMultiDataset

-
if platform.system() != 'Windows':
    # https://github.com/pytorch/pytorch/issues/973
    import resource
@@ -86,6 +87,8 @@ def davar_build_dataloader(dataset,
    else:
        sampler = kwargs.pop('sampler', None)

+    cfg_collate = kwargs.pop('cfg_collate', None)
+
    # if choose distributed sampler
    if dist:
        # whether to shuffle data
@@ -134,7 +137,8 @@ def davar_build_dataloader(dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
-        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+        collate_fn=multi_frame_collate if cfg_collate == 'multi_frame_collate'
+        else partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs)
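For reference, the new `cfg_collate` switch only needs to be forwarded through the dataloader kwargs, since it is popped before the remaining kwargs reach DataLoader. A minimal sketch of a call site, assuming the other keyword arguments mirror mmdet's build_dataloader (the sample counts and dist flag below are illustrative, not part of this change):

    # Hypothetical call; only 'dataset' and 'cfg_collate' are grounded in this diff.
    data_loader = davar_build_dataloader(
        dataset,
        samples_per_gpu=2,    # assumed signature, as in mmdet's build_dataloader
        workers_per_gpu=2,    # assumed signature, as in mmdet's build_dataloader
        dist=False,           # assumed signature, as in mmdet's build_dataloader
        cfg_collate='multi_frame_collate')  # any other value keeps partial(collate, ...)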
@@ -283,3 +287,58 @@ def parameter_align(cfg):
        align_para.append(temp_dict)

    return align_para
+
+
+ def multi_frame_collate(batch):
+     """
+     Args:
+         batch (list): one batch of data
+
+     Returns:
+         dict: collated batch data
+     """
+     data = dict()
+     # this collate function only supports the case where each batch element
+     # (batch[0], batch[1], ...) is itself a list of multiple instances/frames
+     if isinstance(batch[0], list):
+         img_meta = []
+         img = []
+         gt_mask = []
+         max_w, max_h = 0, 0
+         max_mask_w, max_mask_h = 0, 0
+
+         # find the max width and height so every image/mask can be padded to a common shape
+         for i in range(len(batch)):
+             for j in range(len(batch[i])):
+                 size = batch[i][j]['img'].data.size()
+                 size_mask = batch[i][j]['gt_masks'].data.shape
+                 if max_w < size[1]:
+                     max_w = size[1]
+                 if max_h < size[2]:
+                     max_h = size[2]
+                 if max_mask_w < size_mask[1]:
+                     max_mask_w = size_mask[1]
+                 if max_mask_h < size_mask[2]:
+                     max_mask_h = size_mask[2]
+
+         # pad each img and gt_mask to the max width and height
+         for i in range(len(batch)):
+             for j in range(len(batch[i])):
+                 img_meta.append(batch[i][j]['img_metas'].data)
+                 c, w, h = batch[i][j]['img'].data.size()
+                 tmp_img = torch.zeros((c, max_w, max_h), dtype=torch.float)
+                 tmp_img[:, 0:w, 0:h] = batch[i][j]['img'].data
+                 img.append(tmp_img)
+                 c_mask, w_mask, h_mask = batch[i][j]['gt_masks'].data.shape
+                 tmp_mask = torch.zeros((c_mask, max_mask_w, max_mask_h), dtype=torch.float)
+                 mask = to_tensor(batch[i][j]['gt_masks'].data)
+                 tmp_mask[:, :w_mask, :h_mask] = mask
+                 gt_mask.append(tmp_mask)
+
+         img = DC([torch.stack(img, dim=0)])
+         gt_mask = DC([torch.stack(gt_mask, dim=0)])
+         data['img_metas'] = DC([img_meta], cpu_only=True)
+         data['img'] = img
+         data['gt_masks'] = gt_mask
+
+     else:
+         raise TypeError('batch type {} is not supported'.format(type(batch[0])))
+     return data
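As a small, self-contained sanity check of what this collate expects, each sample would be a dict of mmcv DataContainer fields as produced by the formatting pipeline; the tensor/array shapes and the meta contents below are illustrative assumptions, not taken from this diff:

    import numpy as np
    import torch
    from mmcv.parallel import DataContainer as DC

    # One batch element is a list of per-frame samples (batch[0] is a list).
    frame = dict(
        img=DC(torch.rand(3, 320, 480)),                 # (C, dim1, dim2), assumed shape
        gt_masks=DC(np.zeros((2, 320, 480), np.uint8)),  # (num_masks, dim1, dim2), assumed shape
        img_metas=DC(dict(filename='frame_000.jpg'), cpu_only=True))  # illustrative meta
    batch = [[frame, frame]]

    out = multi_frame_collate(batch)
    print(out['img'].data[0].shape)       # torch.Size([2, 3, 320, 480])
    print(out['gt_masks'].data[0].shape)  # torch.Size([2, 2, 320, 480])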