
Commit 0ac88ed: release VSR
1 parent 892a352
65 files changed (+4553, -2 lines)
davarocr/__init__.py (+1 line)
@@ -13,6 +13,7 @@
 from .davar_rcg import *
 from .davar_spotting import *
 from .davar_ie import *
+from .davar_layout import *
 from .davar_videotext import *
 from .davar_table import *
 from .mmcv import *

davarocr/davar_layout/__init__.py (+12 lines, new file)

@@ -0,0 +1,12 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : __init__.py
# Abstract :

# Current Version: 1.0.0
# Date : 2021-12-06
##################################################################################################
"""
from .datasets import *
from .models import *
davarocr/davar_layout/datasets/__init__.py (+16 lines, new file)

@@ -0,0 +1,16 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : __init__.py
# Abstract :

# Current Version: 1.0.0
# Date : 2021-12-06
##################################################################################################
"""
from .docbank_dataset import DocBankDataset
from .publaynet_dataset import PublaynetDataset
from .pipelines import MMLALoadAnnotations, MMLAFormatBundle, CharTokenize

__all__ = ['DocBankDataset', 'PublaynetDataset', 'MMLALoadAnnotations', 'MMLAFormatBundle',
           'CharTokenize']
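
Since DocBankDataset is registered via @DATASETS.register_module() (see docbank_dataset.py below), it can be built from an mmdet-style config. The following is a minimal sketch, assuming mmdet 2.x registry conventions; every path is a hypothetical placeholder, and the pipeline ordering is an illustrative guess using the transforms exported above, not something this commit prescribes.

from mmdet.datasets import build_dataset

import davarocr  # importing davarocr registers DocBankDataset and the MMLA transforms

# All paths below are hypothetical placeholders.
train_dataset_cfg = dict(
    type='DocBankDataset',
    ann_file='data/docbank/datalist_train.json',
    img_prefix='data/docbank/images/',
    ann_prefix='data/docbank/annotations/',
    classes_config='configs/classes_config.json',  # maps 'str' labels to 'int'
    max_num=1024,  # per-image token cap enforced in pre_prepare
    pipeline=[
        dict(type='MMLALoadAnnotations'),  # ordering is an assumption
        dict(type='CharTokenize'),
        dict(type='MMLAFormatBundle'),
    ])

dataset = build_dataset(train_dataset_cfg)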
davarocr/davar_layout/datasets/docbank_dataset.py (+248 lines, new file)

@@ -0,0 +1,248 @@
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : docbank_dataset.py
# Abstract : Dataset definition for the DocBank dataset.

# Current Version: 1.0.0
# Date : 2021-12-06
##################################################################################################
"""
import json
import os
import copy
import random
import torch
import numpy as np

from mmdet.models.losses import accuracy
from mmdet.datasets.builder import DATASETS
from .mm_layout_dataset import MMLayoutDataset


@DATASETS.register_module()
class DocBankDataset(MMLayoutDataset):
    """
    Dataset definition for the DocBank dataset.

    Ref: [1] DocBank: A Benchmark Dataset for Document Layout Analysis, COLING 2020.
    """

    CLASSES = None

    def __init__(self,
                 ann_file,
                 pipeline,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True,
                 classes_config=None,
                 classes=None,
                 ann_prefix='',
                 eval_level=0,
                 max_num=1024):
        """
        Args:
            ann_file(str): the path to the datalist.
            pipeline(list(dict)): the data-flow handling pipeline.
            data_root(str): the root path of the dataset.
            img_prefix(str): the image prefix.
            seg_prefix(str): the segmentation map prefix.
            proposal_file(str): the path to the preset proposal file.
            test_mode(boolean): whether in test mode.
            filter_empty_gt(boolean): whether to filter out images without ground truths.
            classes_config(str): the path to the classes config file, used to transfer 'str' labels into 'int'.
            classes(str): dataset classes. Default: None.
            ann_prefix(str): annotation prefix path for each annotation file.
            eval_level(int): the annotation level to evaluate on. 1 for the highest level, 0 for the lowest.
            max_num(int): the maximum number of tokens to load per image.
        """
        self.max_num = max_num
        super().__init__(
            ann_file=ann_file,
            pipeline=pipeline,
            data_root=data_root,
            img_prefix=img_prefix,
            seg_prefix=seg_prefix,
            proposal_file=proposal_file,
            test_mode=test_mode,
            filter_empty_gt=filter_empty_gt,
            classes_config=classes_config,
            ann_prefix=ann_prefix,
            classes=classes,
            eval_level=eval_level
        )

    def pre_prepare(self, img_info):
        """Load the per-image annotation file and reset the img_info 'ann' and 'ann2' fields, where 'ann' holds
        the token-level annotations and 'ann2' the layout-level annotations.

        Args:
            img_info(dict): img_info dict.

        Returns:
            dict: updated img_info.
        """
        if img_info['url'] is not None:
            tmp_img_info = copy.deepcopy(img_info)
            with open(os.path.join(self.ann_prefix, tmp_img_info['url']), 'r', encoding='utf8') as ann_fp:
                ann = json.load(ann_fp)

            if "content_ann" in ann.keys():
                tmp_img_info["ann"] = ann["content_ann"]
                cares = ann["content_ann"]["cares"]
                bboxes = ann["content_ann"]["bboxes"]
                cnt_bboxes = 0
                areas = []
                for idx, per_bbox in enumerate(bboxes):
                    w_s, h_s, w_e, h_e = per_bbox
                    area = (w_e - w_s) * (h_e - h_s)
                    areas.append(area)
                    if w_e > w_s and h_e > h_s:
                        cnt_bboxes += 1
                    else:
                        # filter out bboxes whose area equals 0
                        cares[idx] = 0

                # divide all tokens into three groups according to their areas and
                # sub-sample each group, due to the memory limit
                if cnt_bboxes > self.max_num:
                    area1 = []
                    area10 = []
                    area10_up = []
                    for idx, area in enumerate(areas):
                        if area > 10:
                            area10_up.append(idx)
                        elif area > 1:
                            area10.append(idx)
                        elif area == 1:
                            area1.append(idx)

                    if len(area1) > self.max_num // 16:
                        index = random.sample(area1, len(area1) - self.max_num // 16)
                        for i in index:
                            cares[i] = 0

                    if len(area10) > self.max_num // 16:
                        index10 = random.sample(area10, len(area10) - self.max_num // 16)
                        for i in index10:
                            cares[i] = 0

                    # the largest tokens fill whatever budget remains
                    num_res = self.max_num - min(self.max_num // 16, len(area1)) - min(self.max_num // 16, len(area10))
                    if len(area10_up) > num_res:
                        index10_up = random.sample(area10_up, len(area10_up) - num_res)
                        for i in index10_up:
                            cares[i] = 0

                tmp_img_info["ann"]["cares"] = cares
            else:
                tmp_img_info["ann"] = None

            if "content_ann2" in ann.keys():
                tmp_img_info["ann2"] = ann["content_ann2"]

                # set wrongly labeled (degenerate) bboxes to not-care
                cares = ann["content_ann2"]["cares"]
                bboxes = ann["content_ann2"]["bboxes"]
                for idx, per_bbox in enumerate(bboxes):
                    w_s, h_s, w_e, h_e = per_bbox
                    if not (w_e > w_s and h_e > h_s):
                        cares[idx] = 0
                tmp_img_info["ann2"]["cares"] = cares
            else:
                tmp_img_info["ann2"] = None

            return tmp_img_info

        return img_info

    def evaluate(self,
                 results,
                 logger=None,
                 metric='F1-score'):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict: evaluation results.
        """
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['acc', 'F1-score']
        if metric not in allowed_metrics:
            raise KeyError('metric {} is not supported'.format(metric))

        annotations = [self.process_anns(i) for i in range(len(self))]
        bboxes = [annotations[i]["bboxes"] for i in range(len(annotations))]
        labels = [annotations[i]["labels"] for i in range(len(annotations))]
        cares = [annotations[i]["cares"] for i in range(len(annotations))]
        labels_care = []
        bboxes_care = []
        classes = self.classes_config["classes_0"]

        # drop not-care tokens and map 'str' labels to class indices
        for i in range(len(labels)):
            labels_tmp = [classes.index(labels[i][j]) for j in range(len(labels[i])) if cares[i][j] != 0]
            bboxes_tmp = [bboxes[i][j] for j in range(len(bboxes[i])) if cares[i][j] != 0]
            for j in range(len(labels_tmp)):
                labels_care.append(labels_tmp[j])
                bboxes_care.append(bboxes_tmp[j])

        eval_results = {}
        # flatten the per-image results into one per-token array of class scores
        results = np.array([per_result[j] for per_result in results for j in range(len(per_result))])

        # accuracy for each category
        if metric == 'acc':
            results_new = [[] for i in range(len(classes))]
            labels_new = [[] for i in range(len(classes))]
            for i in range(len(labels_care)):
                labels_new[labels_care[i]].append(labels_care[i])
                results_new[labels_care[i]].append(results[i])
            for i in range(len(labels_new)):
                results_per = torch.Tensor(np.array(results_new[i]))
                labels_per = torch.Tensor(np.array(labels_new[i]))
                acc_per = accuracy(results_per, labels_per)
                eval_results['acc@{}'.format(classes[i])] = float(acc_per)

        # F1-score for each category: precision, recall and F1 are computed over
        # token bbox areas, following [1]
        if metric == 'F1-score':
            gt_area, cor_area, pre_area, precision, recall, f1_score = [[0 for i in range(len(classes))] for j in range(6)]
            for i in range(len(labels_care)):
                area = (bboxes_care[i][2] - bboxes_care[i][0]) * (bboxes_care[i][3] - bboxes_care[i][1])
                gt_area[labels_care[i]] += area
                label_pre = np.argmax(results[i])
                pre_area[label_pre] += area
                if label_pre == labels_care[i]:
                    cor_area[labels_care[i]] += area
            f1_list = []
            for i in range(len(gt_area)):
                if gt_area[i] == 0:
                    # skip categories that never occur in the ground truth
                    continue
                precision[i] = cor_area[i] / (pre_area[i] + 0.01)
                recall[i] = cor_area[i] / gt_area[i]
                f1_score[i] = 2 * precision[i] * recall[i] / (precision[i] + recall[i] + 0.01)
                eval_results['precision@{}'.format(classes[i])] = float(precision[i])
                eval_results['recall@{}'.format(classes[i])] = float(recall[i])
                eval_results['F1 score@{}'.format(classes[i])] = float(f1_score[i])
                f1_list.append(float(f1_score[i]))

            avg_f1 = sum(f1_list) / (len(f1_list) + 1e-3)
            eval_results['avg_f1'] = avg_f1

        return eval_results
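
For reference, pre_prepare implies per-image annotation files of roughly the following shape. This is inferred from the fields the method reads; all concrete values are hypothetical placeholders.

# Hypothetical per-image annotation JSON, shaped after the fields pre_prepare reads:
# 'content_ann' carries token-level annotations, 'content_ann2' layout-level ones.
example_ann = {
    "content_ann": {
        "bboxes": [[10, 20, 58, 34], [60, 20, 60, 34]],  # [x_start, y_start, x_end, y_end]
        "cares": [1, 1],  # the second bbox has zero width, so pre_prepare resets its flag to 0
        # ... other token-level fields (e.g. the 'labels' consumed via process_anns in evaluate)
    },
    "content_ann2": {
        "bboxes": [[8, 16, 300, 40]],
        "cares": [1],
        # ... other layout-level fields
    },
}

Note that the F1-score branch of evaluate is area-weighted, following [1]: per category, precision is the correctly classified token area over the predicted token area, recall is the correctly classified area over the ground-truth area (each with a small smoothing constant in the denominator), and avg_f1 averages the per-category F1 over the categories that occur in the ground truth.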
