 import copy
 import platform
 from functools import partial
-
+import torch
 from torch.utils.data import DataLoader
 
 from mmcv.utils import Registry
 from mmcv.utils import build_from_cfg
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
+from mmcv.parallel import DataContainer as DC
 
 from mmdet.datasets import DATASETS
 from mmdet.models.builder import build
 from mmdet.datasets.builder import worker_init_fn
 from mmdet.datasets.samplers import DistributedGroupSampler, GroupSampler, DistributedSampler
+from mmdet.datasets.pipelines.formating import to_tensor
 
 from .davar_dataset_wrappers import DavarConcatDataset
 from .davar_multi_dataset import DavarMultiDataset
 
-
 if platform.system() != 'Windows':
     # https://github.com/pytorch/pytorch/issues/973
     import resource
@@ -86,6 +87,8 @@ def davar_build_dataloader(dataset,
     else:
         sampler = kwargs.pop('sampler', None)
 
+    cfg_collate = kwargs.pop('cfg_collate', None)
+
     # if choose distributed sampler
     if dist:
         # whether to shuffle data
@@ -134,7 +137,8 @@ def davar_build_dataloader(dataset,
         batch_size=batch_size,
         sampler=sampler,
         num_workers=num_workers,
-        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+        collate_fn=(multi_frame_collate if cfg_collate == 'multi_frame_collate'
+                    else partial(collate, samples_per_gpu=samples_per_gpu)),
         pin_memory=False,
         worker_init_fn=init_fn,
         **kwargs)
@@ -283,3 +287,58 @@ def parameter_align(cfg):
         align_para.append(temp_dict)
 
     return align_para
+
+
+def multi_frame_collate(batch):
+    """Collate multi-frame samples into a single padded batch.
+    Args:
+        batch (list): one batch of data, where each element is a list of frame instances
+    Returns:
+        dict: collated batch data
+    """
+    data = dict()
+    # this collate func only supports the case where each element of batch contains multiple instances
+    if isinstance(batch[0], list):
+        img_meta = []
+        img = []
+        gt_mask = []
+        max_w, max_h = 0, 0
+        max_mask_w, max_mask_h = 0, 0
+
+        # calculate the max width and max height to pad
+        for i in range(len(batch)):
+            for j in range(len(batch[i])):
+                size = batch[i][j]['img'].data.size()
+                size_mask = batch[i][j]['gt_masks'].data.shape
+                if max_w < size[1]:
+                    max_w = size[1]
+                if max_h < size[2]:
+                    max_h = size[2]
+                if max_mask_w < size_mask[1]:
+                    max_mask_w = size_mask[1]
+                if max_mask_h < size_mask[2]:
+                    max_mask_h = size_mask[2]
+
+        # pad each img and gt mask to the max width and height
+        for i in range(len(batch)):
+            for j in range(len(batch[i])):
+                img_meta.append(batch[i][j]['img_metas'].data)
+                c, w, h = batch[i][j]['img'].data.size()
+                tmp_img = torch.zeros((c, max_w, max_h), dtype=torch.float)
+                tmp_img[:, 0:w, 0:h] = batch[i][j]['img'].data
+                img.append(tmp_img)
+                c_mask, w_mask, h_mask = batch[i][j]['gt_masks'].data.shape
+                tmp_mask = torch.zeros((c_mask, max_mask_w, max_mask_h), dtype=torch.float)
+                mask = to_tensor(batch[i][j]['gt_masks'].data)
+                tmp_mask[:, :w_mask, :h_mask] = mask
+                gt_mask.append(tmp_mask)
+
+        img = DC([torch.stack(img, dim=0)])
+        gt_mask = DC([torch.stack(gt_mask, dim=0)])
+        data['img_metas'] = DC([img_meta], cpu_only=True)
+        data['img'] = img
+        data['gt_masks'] = gt_mask
+
+    else:
+        raise TypeError("not supported type {} of batch".format(type(batch[0])))
+    return data
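
A minimal usage sketch of the new option (not part of the commit): passing cfg_collate='multi_frame_collate' makes the loader use multi_frame_collate instead of mmcv's default collate, as wired up in the collate_fn change above. The dataset variable and the samples_per_gpu/workers_per_gpu arguments below are assumptions based on mmdet's build_dataloader signature, not confirmed by this diff; only dist, shuffle, and the cfg_collate kwarg are visible in the patch.

    # hypothetical call; 'video_dataset' yields one list of frame dicts per sample,
    # each frame holding DataContainer-wrapped 'img', 'img_metas' and 'gt_masks'
    data_loader = davar_build_dataloader(
        video_dataset,
        samples_per_gpu=2,      # assumed mmdet-style argument
        workers_per_gpu=2,      # assumed mmdet-style argument
        dist=False,
        shuffle=True,
        cfg_collate='multi_frame_collate')  # popped from **kwargs by this patch
    # omitting cfg_collate (or passing any other value) falls back to
    # partial(collate, samples_per_gpu=samples_per_gpu) from mmcv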