Skip to content

Commit 74ac5e2

Browse files
committed Sep 29, 2023
kerascv
1 parent 89c3865 commit 74ac5e2

File tree

3 files changed

+170
-0
lines changed

3 files changed

+170
-0
lines changed
 

‎kerascv/extra_reading.txt

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
KerasCV List of Models
2+
https://keras.io/api/keras_cv/models/
3+
4+
Fast R-CNN (Ross Girshick)
5+
https://arxiv.org/pdf/1504.08083.pdf
6+
7+
Focal Loss for Dense Object Detection (Lin et al.)
8+
https://arxiv.org/abs/1708.02002

‎kerascv/makelist.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
'''
Use this script to generate a list of all XML files in a folder.

Run it from inside the folder that holds the XML files: it matches ``*.xml``
in the current working directory and writes the matching names, one per
line, to ``xml_list.txt`` in that same directory.
'''

from glob import glob

# glob() returns matches in arbitrary order; sort them so that the generated
# list is identical from one run (and one machine) to the next.
files = sorted(glob('*.xml'))
with open('xml_list.txt', 'w') as f:
    # One batched write instead of one write() call per file name.
    f.writelines("%s\n" % fn for fn in files)

‎kerascv/pascal2coco.py

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# adapted from https://blog.roboflow.com/how-to-convert-annotations-from-voc-xml-to-coco-json/
2+
3+
import os
4+
import argparse
5+
import json
6+
import xml.etree.ElementTree as ET
7+
from typing import Dict, List
8+
from tqdm import tqdm
9+
import re
10+
11+
12+
def get_label2id(labels_path: str) -> Dict[str, int]:
    """Build the label-name -> category-id mapping from a label list file.

    The file content is split on any whitespace, so labels may be separated
    by spaces or newlines, and a single label cannot contain whitespace.
    Ids are assigned in file order and start at 0.

    Args:
        labels_path: Path to the text file listing the label names.

    Returns:
        Dict mapping each label name to its integer id (0, 1, 2, ...). If a
        label is listed more than once, its last position wins.
    """
    with open(labels_path, 'r') as f:
        labels_str = f.read().split()
    return {label: label_id for label_id, label in enumerate(labels_str)}
18+
19+
20+
def get_annpaths(ann_dir_path: str = None,
                 ann_ids_path: str = None,
                 ext: str = '',
                 annpaths_list_path: str = None) -> List[str]:
    """Resolve the list of annotation file paths to convert.

    Two input modes are supported:

    * ``annpaths_list_path`` is given: that file is read, split on
      whitespace, and every token is returned as a path unchanged. The other
      arguments are ignored in this mode.
    * otherwise: ``ann_ids_path`` is read and split the same way, and each id
      is joined onto ``ann_dir_path``, with ``'.' + ext`` appended when
      ``ext`` is not empty.

    Args:
        ann_dir_path: Directory that contains the annotation files.
        ann_ids_path: Path to a text file listing annotation ids.
        ext: Extension (without the leading dot) appended to every id.
        annpaths_list_path: Path to a text file listing annotation paths.

    Returns:
        The annotation paths, in the order they appear in the list file.

    Raises:
        ValueError: If ``annpaths_list_path`` is not given and either
            ``ann_dir_path`` or ``ann_ids_path`` is missing.
    """
    # Mode 1: an explicit list of annotation paths.
    if annpaths_list_path is not None:
        with open(annpaths_list_path, 'r') as f:
            return f.read().split()

    # Mode 2: a directory plus a list of annotation ids. Fail early with a
    # readable message instead of a TypeError from open()/os.path.join().
    if ann_dir_path is None or ann_ids_path is None:
        raise ValueError(
            "Either annpaths_list_path, or both ann_dir_path and "
            "ann_ids_path, must be provided")

    ext_with_dot = '.' + ext if ext != '' else ''
    with open(ann_ids_path, 'r') as f:
        ann_ids = f.read().split()
    return [os.path.join(ann_dir_path, aid + ext_with_dot) for aid in ann_ids]
36+
37+
38+
def get_image_info(annotation_root, extract_num_from_imgid=True):
    """Build the COCO ``images`` entry for one VOC annotation.

    The file name comes from the base name of the ``<path>`` element when it
    is present, otherwise from the ``<filename>`` element as written. The
    image id is that name without directory and extension.

    Args:
        annotation_root: Root XML element of the annotation.
        extract_num_from_imgid: When true, the id is converted to the integer
            formed by the FIRST run of digits in the name. Otherwise the id
            stays a string.

    Returns:
        Dict with the keys ``file_name``, ``height``, ``width`` and ``id``.

    Raises:
        ValueError: If the annotation has no file name, has no ``<size>``
            element, or a numeric id is requested but the name has no digits.
    """
    path = annotation_root.findtext('path')
    if path is None:
        filename = annotation_root.findtext('filename')
    else:
        filename = os.path.basename(path)
    if filename is None:
        raise ValueError("Annotation has neither a <path> nor a <filename> element")

    img_name = os.path.basename(filename)
    img_id = os.path.splitext(img_name)[0]
    if extract_num_from_imgid:
        digit_runs = re.findall(r'\d+', img_id)
        if not digit_runs:
            raise ValueError(
                f"Cannot extract a numeric image id from file name {img_name!r}")
        # NOTE(review): only the first run of digits is used, so names that
        # share it (e.g. "2007_000027" and "2007_000032") get the same id.
        img_id = int(digit_runs[0])

    size = annotation_root.find('size')
    if size is None:
        raise ValueError(f"Annotation for {img_name!r} has no <size> element")
    width = int(size.findtext('width'))
    height = int(size.findtext('height'))

    image_info = {
        'file_name': filename,
        'height': height,
        'width': width,
        'id': img_id
    }
    return image_info
60+
61+
62+
def get_coco_annotation_from_obj(obj, label2id):
    """Convert one VOC ``<object>`` element into a COCO annotation dict.

    Args:
        obj: The ``<object>`` XML element; its ``<name>`` is the label and
            its ``<bndbox>`` holds ``xmin``, ``ymin``, ``xmax``, ``ymax``.
        label2id: Mapping from label name to category id.

    Returns:
        Dict with the keys ``area``, ``iscrowd``, ``bbox`` (as
        ``[xmin, ymin, width, height]``), ``category_id``, ``ignore`` and
        ``segmentation``. The caller is expected to add ``image_id`` and
        ``id``.

    Raises:
        ValueError: If the label is not in ``label2id`` or the box has a
            non-positive width or height.
    """
    label = obj.findtext('name')
    # Explicit raise rather than assert: asserts are stripped under `python -O`.
    if label not in label2id:
        raise ValueError(f"Error: {label} is not in label2id !")
    category_id = label2id[label]

    bndbox = obj.find('bndbox')
    # Parse through float() so coordinate text such as "48.0" is accepted as
    # well as "48"; any fractional part is truncated.
    # NOTE(review): the "- 1" presumably converts 1-based VOC pixel
    # coordinates to 0-based ones - confirm against the dataset.
    xmin = int(float(bndbox.findtext('xmin'))) - 1
    ymin = int(float(bndbox.findtext('ymin'))) - 1
    xmax = int(float(bndbox.findtext('xmax')))
    ymax = int(float(bndbox.findtext('ymax')))
    if not (xmax > xmin and ymax > ymin):
        raise ValueError(
            f"Box size error !: (xmin, ymin, xmax, ymax): {xmin, ymin, xmax, ymax}")

    o_width = xmax - xmin
    o_height = ymax - ymin
    ann = {
        'area': o_width * o_height,
        'iscrowd': 0,
        'bbox': [xmin, ymin, o_width, o_height],
        'category_id': category_id,
        'ignore': 0,
        'segmentation': []  # This script is not for segmentation
    }
    return ann
83+
84+
85+
def convert_xmls_to_cocojson(annotation_paths: List[str],
                             label2id: Dict[str, int],
                             output_jsonpath: str,
                             extract_num_from_imgid: bool = True):
    """Convert VOC XML annotation files into a single COCO-style JSON file.

    Every annotation file contributes one ``images`` entry and one
    ``annotations`` entry per ``<object>`` element. Annotation ids are
    numbered from 1 and keep increasing across files. One ``categories``
    entry is written per item of ``label2id``.

    Args:
        annotation_paths: Paths of the XML files to convert, in order.
        label2id: Mapping from label name to category id.
        output_jsonpath: Path of the JSON file to write.
        extract_num_from_imgid: Forwarded to ``get_image_info``.
    """
    images = []
    annotations = []
    next_box_id = 1  # START_BOUNDING_BOX_ID, TODO input as args ?

    print('Start converting !')
    for xml_path in tqdm(annotation_paths):
        # Read annotation xml
        root = ET.parse(xml_path).getroot()

        image_entry = get_image_info(annotation_root=root,
                                     extract_num_from_imgid=extract_num_from_imgid)
        images.append(image_entry)

        for obj_node in root.findall('object'):
            entry = get_coco_annotation_from_obj(obj=obj_node, label2id=label2id)
            # Link the box to its image and give it a unique id.
            entry['image_id'] = image_entry['id']
            entry['id'] = next_box_id
            annotations.append(entry)
            next_box_id += 1

    categories = [
        {'supercategory': 'none', 'id': label_id, 'name': label}
        for label, label_id in label2id.items()
    ]

    # Key order is kept as images / type / annotations / categories so the
    # serialized file has the same layout.
    coco_dict = {
        "images": images,
        "type": "instances",
        "annotations": annotations,
        "categories": categories
    }
    with open(output_jsonpath, 'w') as f:
        f.write(json.dumps(coco_dict))
120+
121+
122+
def main():
    """Command-line entry point: convert VOC XML annotations to COCO JSON.

    Parses the arguments, builds the label mapping and the list of
    annotation paths, and runs the conversion with numeric image ids.

    ``--labels`` is required. The annotation files are given either with
    ``--ann_paths_list`` or with both ``--ann_dir`` and ``--ann_ids``; any
    other combination is rejected as a usage error (exit status 2) before
    any file is read.
    """
    parser = argparse.ArgumentParser(
        description='This script support converting voc format xmls to coco format json')
    parser.add_argument('--ann_dir', type=str, default=None,
                        help='path to annotation files directory. It is not need when use --ann_paths_list')
    parser.add_argument('--ann_ids', type=str, default=None,
                        help='path to annotation files ids list. It is not need when use --ann_paths_list')
    parser.add_argument('--ann_paths_list', type=str, default=None,
                        help='path of annotation paths list. It is not need when use --ann_dir and --ann_ids')
    # Required: without it the label file cannot be opened at all.
    parser.add_argument('--labels', type=str, default=None, required=True,
                        help='path to label list.')
    parser.add_argument('--output', type=str, default='output.json', help='path to output json file')
    parser.add_argument('--ext', type=str, default='', help='additional extension of annotation file')
    args = parser.parse_args()

    # Report a missing annotation source as a usage error instead of letting
    # it surface later as a TypeError on a None path.
    if args.ann_paths_list is None and (args.ann_dir is None or args.ann_ids is None):
        parser.error('either --ann_paths_list or both --ann_dir and --ann_ids must be given')

    label2id = get_label2id(labels_path=args.labels)
    ann_paths = get_annpaths(
        ann_dir_path=args.ann_dir,
        ann_ids_path=args.ann_ids,
        ext=args.ext,
        annpaths_list_path=args.ann_paths_list
    )
    convert_xmls_to_cocojson(
        annotation_paths=ann_paths,
        label2id=label2id,
        output_jsonpath=args.output,
        extract_num_from_imgid=True
    )


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)
Please sign in to comment.