Skip to content

Commit 997f8ce

Browse files
ziyeqinghancopybara-github
authored andcommitted
1. add missing part in ann_json_dict in object_detector_dataloader.
2. Move create_pascal_voc function to test_util so that other unittest (object_detector task unittest) can call it as well. PiperOrigin-RevId: 351260755
1 parent acbd328 commit 997f8ce

File tree

4 files changed

+82
-24
lines changed

4 files changed

+82
-24
lines changed

tensorflow_examples/lite/model_maker/core/data_util/object_detector_dataloader.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _get_object_detector_cache_filenames(cache_dir, image_dir, annotations_dir,
5353
"""Gets cache filenames for obejct detector."""
5454
if cache_dir is None:
5555
cache_dir = tempfile.mkdtemp()
56-
print('Create the cache directory: %s.', cache_dir)
56+
print('Create the cache directory: %s.' % cache_dir)
5757
cache_prefix = _get_cache_prefix(image_dir, annotations_dir, annotations_list)
5858
cache_prefix = os.path.join(cache_dir, cache_prefix)
5959

@@ -211,7 +211,12 @@ def _write_pascal_tfrecord(cls, images_dir, annotations_dir, label_map_dict,
211211
for idx, name in label_map_dict.items():
212212
label_name2id_dict[name] = idx
213213
writers = [tf.io.TFRecordWriter(path) for path in tfrecord_files]
214+
214215
ann_json_dict = {'images': [], 'annotations': [], 'categories': []}
216+
for class_id, class_name in label_map_dict.items():
217+
c = {'supercategory': 'none', 'id': class_id, 'name': class_name}
218+
ann_json_dict['categories'].append(c)
219+
215220
# Gets the paths to annotations.
216221
if annotations_list:
217222
ann_path_list = [
@@ -246,7 +251,7 @@ def _write_pascal_tfrecord(cls, images_dir, annotations_dir, label_map_dict,
246251
writer.close()
247252

248253
with tf.io.gfile.GFile(annotations_json_file, 'w') as f:
249-
json.dump(ann_json_dict, f)
254+
json.dump(ann_json_dict, f, indent=2)
250255

251256
def gen_dataset(self,
252257
model_spec,

tensorflow_examples/lite/model_maker/core/data_util/object_detector_dataloader_test.py

Lines changed: 9 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
from __future__ import division
1717
from __future__ import print_function
1818

19+
import filecmp
1920
import os
2021

2122
import numpy as np
22-
import PIL.Image
23+
2324
import tensorflow as tf
2425
from tensorflow_examples.lite.model_maker.core import test_util
2526

@@ -40,28 +41,9 @@ def __init__(self, model_name):
4041

4142
class ObjectDectectorDataLoaderTest(tf.test.TestCase):
4243

43-
def _create_pascal_voc(self):
44-
# Saves the image into images_dir.
45-
image_file_name = '2012_12.jpg'
46-
image_data = np.random.rand(256, 256, 3)
47-
images_dir = os.path.join(self.get_temp_dir(), 'images')
48-
os.mkdir(images_dir)
49-
save_path = os.path.join(images_dir, image_file_name)
50-
image = PIL.Image.fromarray(image_data, 'RGB')
51-
image.save(save_path)
52-
53-
# Gets the annonation path.
54-
annotations_path = test_util.get_test_data_path('2012_12.xml')
55-
annotations_dir = os.path.dirname(annotations_path)
56-
57-
label_map = {
58-
1: 'person',
59-
2: 'notperson',
60-
}
61-
return images_dir, annotations_dir, label_map
62-
6344
def test_from_pascal_voc(self):
64-
images_dir, annotations_dir, label_map = self._create_pascal_voc()
45+
images_dir, annotations_dir, label_map = test_util.create_pascal_voc(
46+
self.get_temp_dir())
6547
model_spec = MockDetectorModelSpec('efficientdet-lite0')
6648

6749
data = object_detector_dataloader.DataLoader.from_pascal_voc(
@@ -71,6 +53,11 @@ def test_from_pascal_voc(self):
7153
self.assertLen(data, 1)
7254
self.assertEqual(data.label_map, label_map)
7355

56+
self.assertTrue(os.path.isfile(data.annotations_json_file))
57+
self.assertGreater(os.path.getsize(data.annotations_json_file), 0)
58+
expected_json_file = test_util.get_test_data_path('annotations.json')
59+
self.assertTrue(filecmp.cmp(data.annotations_json_file, expected_json_file))
60+
7461
ds = data.gen_dataset(model_spec, batch_size=1, is_training=False)
7562
for i, (images, labels) in enumerate(ds):
7663
self.assertEqual(i, 0)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
{
2+
"images": [
3+
{
4+
"file_name": "2012_12.jpg",
5+
"height": 256,
6+
"width": 256,
7+
"id": 1
8+
}
9+
],
10+
"annotations": [
11+
{
12+
"area": 16384,
13+
"iscrowd": 0,
14+
"image_id": 1,
15+
"bbox": [
16+
64,
17+
64,
18+
128,
19+
128
20+
],
21+
"category_id": 1,
22+
"id": 1,
23+
"ignore": 0,
24+
"segmentation": []
25+
}
26+
],
27+
"categories": [
28+
{
29+
"supercategory": "none",
30+
"id": 1,
31+
"name": "person"
32+
},
33+
{
34+
"supercategory": "none",
35+
"id": 2,
36+
"name": "notperson"
37+
}
38+
]
39+
}

tensorflow_examples/lite/model_maker/core/test_util.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
import functools
2020
import os
2121
import shutil
22+
import tempfile
2223

2324
from absl import flags
2425
import numpy as np
26+
import PIL.Image
2527

2628
import tensorflow.compat.v2 as tf
2729
from tensorflow_examples.lite.model_maker.core import compat
@@ -138,6 +140,31 @@ def get_dataloader(data_size, input_shape, num_classes, max_input_value=1000):
138140
return data
139141

140142

143+
def create_pascal_voc(temp_dir=None):
144+
"""Creates test data with PASCAL VOC format."""
145+
if temp_dir is None or not tf.io.gfile.exists(temp_dir):
146+
temp_dir = tempfile.mkdtemp()
147+
148+
# Saves the image into images_dir.
149+
image_file_name = "2012_12.jpg"
150+
image_data = np.random.rand(256, 256, 3)
151+
images_dir = os.path.join(temp_dir, "images")
152+
os.mkdir(images_dir)
153+
save_path = os.path.join(images_dir, image_file_name)
154+
image = PIL.Image.fromarray(image_data, "RGB")
155+
image.save(save_path)
156+
157+
# Gets the annonation path.
158+
annotations_path = get_test_data_path("2012_12.xml")
159+
annotations_dir = os.path.dirname(annotations_path)
160+
161+
label_map = {
162+
1: "person",
163+
2: "notperson",
164+
}
165+
return images_dir, annotations_dir, label_map
166+
167+
141168
def is_same_output(tflite_file,
142169
keras_model,
143170
input_tensors,

0 commit comments

Comments
 (0)