Skip to content

Commit 002611e

Browse files
committed
support for video task
1 parent afc7899 commit 002611e

File tree

2 files changed

+76
-6
lines changed

2 files changed

+76
-6
lines changed

davarocr/davarocr/davar_common/apis/inference.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def init_model(config, checkpoint=None, device='cuda:0', cfg_options=None):
5757
elif cfg_types == "SPOTTER":
5858
from davarocr.davar_spotting.models.builder import build_spotter
5959
model = build_spotter(config.model, test_cfg=config.get('test_cfg'))
60+
elif cfg_types == "NER":
61+
from davarocr.davar_ner.models.builder import build_ner
62+
model = build_ner(config.model, test_cfg=config.get('test_cfg'))
6063
else:
6164
raise NotImplementedError
6265

@@ -97,7 +100,12 @@ def inference_model(model, imgs):
97100
test_pipeline = Compose(cfg.data.test.pipeline)
98101

99102
# Prepare data
100-
if isinstance(imgs, (str, np.ndarray)):
103+
if isinstance(imgs, dict):
104+
data = imgs
105+
data = test_pipeline(data)
106+
device = int(str(device).split(":")[-1])
107+
data = scatter(collate([data], samples_per_gpu=1), [device])[0]
108+
elif isinstance(imgs, (str, np.ndarray)):
101109
# If the input is single image
102110
data = dict(img=imgs)
103111
data = test_pipeline(data)
@@ -107,11 +115,14 @@ def inference_model(model, imgs):
107115
# If the input are batch of images
108116
batch_data = []
109117
for img in imgs:
110-
data = dict(img=img)
118+
if isinstance(img, dict):
119+
data = dict(img_info=img)
120+
else:
121+
data = dict(img=img)
111122
data = test_pipeline(data)
112123
batch_data.append(data)
113124
data_collate = collate(batch_data, samples_per_gpu=len(batch_data))
114-
device = int(str(device).split(":")[-1])
125+
device = int(str(device).rsplit(':', maxsplit=1)[-1])
115126
data = scatter(data_collate, [device])[0]
116127

117128
# Forward inference

davarocr/davarocr/davar_common/datasets/builder.py

+62-3
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,24 @@
1111
import copy
1212
import platform
1313
from functools import partial
14-
14+
import torch
1515
from torch.utils.data import DataLoader
1616

1717
from mmcv.utils import Registry
1818
from mmcv.utils import build_from_cfg
1919
from mmcv.parallel import collate
2020
from mmcv.runner import get_dist_info
21+
from mmcv.parallel import DataContainer as DC
2122

2223
from mmdet.datasets import DATASETS
2324
from mmdet.models.builder import build
2425
from mmdet.datasets.builder import worker_init_fn
2526
from mmdet.datasets.samplers import DistributedGroupSampler, GroupSampler, DistributedSampler
27+
from mmdet.datasets.pipelines.formating import to_tensor
2628

2729
from .davar_dataset_wrappers import DavarConcatDataset
2830
from .davar_multi_dataset import DavarMultiDataset
2931

30-
3132
if platform.system() != 'Windows':
3233
# https://github.com/pytorch/pytorch/issues/973
3334
import resource
@@ -86,6 +87,8 @@ def davar_build_dataloader(dataset,
8687
else:
8788
sampler = kwargs.pop('sampler', None)
8889

90+
cfg_collate = kwargs.pop('cfg_collate', None)
91+
8992
# if choose distributed sampler
9093
if dist:
9194
# whether to shuffle data
@@ -134,7 +137,8 @@ def davar_build_dataloader(dataset,
134137
batch_size=batch_size,
135138
sampler=sampler,
136139
num_workers=num_workers,
137-
collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
140+
collate_fn=multi_frame_collate if cfg_collate == 'multi_frame_collate' else partial(collate, samples_per_gpu=
141+
samples_per_gpu),
138142
pin_memory=False,
139143
worker_init_fn=init_fn,
140144
**kwargs)
@@ -283,3 +287,58 @@ def parameter_align(cfg):
283287
align_para.append(temp_dict)
284288

285289
return align_para
290+
291+
292+
def multi_frame_collate(batch):
    """Collate consecutive video frames into one zero-padded batch.

    Each element of ``batch`` is expected to be a *list* of per-frame sample
    dicts whose 'img', 'gt_masks' and 'img_metas' entries are DataContainers
    (``.data`` holds the underlying tensor / array).  Every image and mask is
    zero-padded up to the largest spatial size found anywhere in the batch,
    then stacked along a new leading batch dimension.

    Args:
        batch (list): one batch of data; ``batch[i]`` is a list of frame
            dicts.  NOTE(review): assumes 'img' tensors are (C, W, H) and
            'gt_masks' arrays are (N, W, H) — confirm against the pipeline.

    Returns:
        dict: collated batch with keys 'img', 'gt_masks' (stacked, padded
            tensors wrapped in DataContainers) and 'img_metas' (cpu-only
            DataContainer of per-frame meta dicts).

    Raises:
        TypeError: if ``batch[0]`` is not a list — only the multi-instance
            (one list of frames per sample) layout is supported.
    """
    # Guard clause: the original code raised a plain string here, which is
    # itself a TypeError in Python 3 ("exceptions must derive from
    # BaseException") and hid the intended diagnostic.
    if not isinstance(batch[0], list):
        raise TypeError("not support type {} of batch".format(type(batch[0])))

    img_meta = []
    img = []
    gt_mask = []
    max_w, max_h = 0, 0
    max_mask_w, max_mask_h = 0, 0

    # First pass: find the maximum width/height over all frames so every
    # tensor can be padded to a common size before torch.stack.
    for sample in batch:
        for frame in sample:
            size = frame['img'].data.size()
            size_mask = frame['gt_masks'].data.shape
            max_w = max(max_w, size[1])
            max_h = max(max_h, size[2])
            max_mask_w = max(max_mask_w, size_mask[1])
            max_mask_h = max(max_mask_h, size_mask[2])

    # Second pass: zero-pad each image and mask into the common size and
    # collect the per-frame metadata in the same (flattened) order.
    for sample in batch:
        for frame in sample:
            img_meta.append(frame['img_metas'].data)

            channels, width, height = frame['img'].data.size()
            padded_img = torch.zeros((channels, max_w, max_h), dtype=torch.float)
            padded_img[:, 0:width, 0:height] = frame['img'].data
            img.append(padded_img)

            c_mask, w_mask, h_mask = frame['gt_masks'].data.shape
            padded_mask = torch.zeros((c_mask, max_mask_w, max_mask_h), dtype=torch.float)
            mask = to_tensor(frame['gt_masks'].data)
            padded_mask[:, :w_mask, :h_mask] = mask
            gt_mask.append(padded_mask)

    data = dict()
    data['img'] = DC([torch.stack(img, dim=0)])
    data['gt_masks'] = DC([torch.stack(gt_mask, dim=0)])
    data['img_metas'] = DC([img_meta], cpu_only=True)
    return data

0 commit comments

Comments
 (0)