Skip to content

Commit c00643c

Browse files
matt-deboerjsbroks
authored andcommitted
Copy annotations (#55)
* ignore local/generated files * remove speculative code bug * working propagate next/previous * use a separate route for copying annotations * use single file; using single initializer method * upsert function * upsert category test * validate from image ids * copy endpoint * simplified category create * all images endpoint * copy annotations * copy annotation bug fixes * copy selected categories * test case fix
1 parent bb05021 commit c00643c

15 files changed

+399
-117
lines changed

app/api/annotator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from flask_restplus import Namespace, Api, Resource
1+
import copy
2+
from flask_restplus import Namespace, Api, Resource, reqparse
23
from flask import request
34

45
from ..util import query_util
@@ -95,7 +96,6 @@ class AnnotatorId(Resource):
9596

9697
def get(self, image_id):
9798
""" Called when loading from the annotator client """
98-
9999
image = ImageModel.objects(id=image_id).first()
100100

101101
if image is None:

app/api/categories.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@ def post(self):
3535
color = args.get('color')
3636

3737
try:
38-
category = CategoryModel.create_category(
38+
category = CategoryModel(
3939
name=name,
4040
supercategory=supercategory,
4141
color=color,
4242
metadata=metadata
4343
)
44+
category.save()
4445
except (ValueError, TypeError) as e:
4546
return {'message': str(e)}, 400
4647

app/api/datasets.py

+2-27
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,7 @@ def post(self):
6060
name = args['name']
6161
categories = args.get('categories', [])
6262

63-
category_ids = []
64-
65-
for category in categories:
66-
if isinstance(category, int):
67-
category_ids.append(category)
68-
else:
69-
category_model = CategoryModel.objects(name=category).first()
70-
71-
if category_model is None:
72-
new_category = CategoryModel.create_category(name=category)
73-
category_ids.append(new_category.id)
74-
else:
75-
category_ids.append(category_model.id)
63+
category_ids = CategoryModel.bulk_create(categories)
7664

7765
try:
7866
dataset = DatasetModel(name=name, categories=category_ids)
@@ -139,21 +127,8 @@ def post(self, dataset_id):
139127
default_annotation_metadata = args.get('default_annotation_metadata')
140128

141129
if categories is not None:
142-
category_ids = []
143-
144-
for category in categories:
145-
if isinstance(category, int):
146-
category_ids.append(category)
147-
else:
148-
category_model = CategoryModel.objects(name=category).first()
149-
150-
if category_model is None:
151-
new_category = CategoryModel.create_category(name=category)
152-
category_ids.append(new_category.id)
153-
else:
154-
category_ids.append(category_model.id)
130+
dataset.categories = CategoryModel.bulk_create(categories)
155131

156-
dataset.categories = category_ids
157132

158133
if default_annotation_metadata is not None:
159134
dataset.default_annotation_metadata = default_annotation_metadata

app/api/images.py

+64-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from flask_restplus import Namespace, Resource, reqparse
22
from werkzeug.datastructures import FileStorage
3-
43
from flask import send_file
54

65
from ..util import query_util, coco_util, thumbnail_util
@@ -14,6 +13,12 @@
1413
api = Namespace('image', description='Image related operations')
1514

1615

16+
image_all = reqparse.RequestParser()
17+
image_all.add_argument('fields', required=False, type=str)
18+
image_all.add_argument('page', default=1, type=int)
19+
image_all.add_argument('perPage', default=50, type=int, required=False)
20+
21+
1722
image_upload = reqparse.RequestParser()
1823
image_upload.add_argument('image', location='files',
1924
type=FileStorage, required=True,
@@ -26,12 +31,37 @@
2631
image_download.add_argument('width', type=int, required=False, default=0)
2732
image_download.add_argument('height', type=int, required=False, default=0)
2833

34+
copy_annotations = reqparse.RequestParser()
35+
copy_annotations.add_argument('category_ids', location='json', type=list,
36+
required=False, default=None, help='Categories to copy')
37+
2938

3039
@api.route('/')
3140
class Images(Resource):
41+
@api.expect(image_all)
3242
def get(self):
3343
""" Returns all images """
34-
return query_util.fix_ids(ImageModel.objects(deteled=False).all())
44+
args = image_all.parse_args()
45+
per_page = args['perPage']
46+
page = args['page']-1
47+
fields = args.get('fields', "")
48+
49+
images = ImageModel.objects(deleted=False)
50+
total = images.count()
51+
pages = int(total/per_page) + 1
52+
53+
images = images.skip(page*per_page).limit(per_page)
54+
if fields:
55+
images = images.only(*fields.split(','))
56+
57+
return {
58+
"total": total,
59+
"pages": pages,
60+
"page": page,
61+
"fields": fields,
62+
"per_page": per_page,
63+
"images": query_util.fix_ids(images.all())
64+
}
3565

3666
@api.expect(image_upload)
3767
def post(self):
@@ -114,6 +144,38 @@ def delete(self, image_id):
114144
return {"success": True}
115145

116146

147+
@api.route('/copy/<int:from_id>/<int:to_id>/annotations')
148+
class ImageCopyAnnotations(Resource):
149+
150+
@api.expect(copy_annotations)
151+
def post(self, from_id, to_id):
152+
args = copy_annotations.parse_args()
153+
category_ids = args.get('category_ids')
154+
155+
image_from = ImageModel.objects(id=from_id).first()
156+
image_to = ImageModel.objects(id=to_id).first()
157+
158+
if image_from is None or image_to is None:
159+
return {'success': False, 'message': 'Invalid image ids'}, 400
160+
161+
if image_from == image_to:
162+
return {'success': False, 'message': 'Cannot copy self'}, 400
163+
164+
if image_from.width != image_to.width or image_from.height != image_to.height:
165+
return {'success': False, 'message': 'Image sizes do not match'}, 400
166+
167+
if category_ids is None:
168+
category_ids = DatasetModel.objects(id=image_from.dataset_id).first().categories
169+
170+
query = AnnotationModel.objects(
171+
image_id=image_from.id,
172+
category_id__in=category_ids,
173+
deleted=False
174+
)
175+
176+
return {'annotations_created': image_to.copy_annotations(query)}
177+
178+
117179
@api.route('/<int:image_id>/thumbnail')
118180
class ImageCoco(Resource):
119181

app/models.py

+69-16
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import os
22
import sys
33
import json
4+
import copy
5+
import numpy as np
6+
47
from flask_mongoengine import MongoEngine
8+
from .util.coco_util import decode_seg
59
from .util import color_util
610
from .config import Config
711
from PIL import Image
812

9-
1013
db = MongoEngine()
1114

1215

@@ -115,6 +118,24 @@ def thumbnail_path(self):
115118

116119
return '/'.join(folders)
117120

121+
def copy_annotations(self, annotations):
122+
"""
123+
Creates a copy of the annotations for this image
124+
:param annotations: QuerySet of annotation models
125+
:return: number of annotations
126+
"""
127+
annotations = annotations.filter(width=self.width, height=self.height, area__gt=0)
128+
129+
for annotation in annotations:
130+
clone = annotation.clone()
131+
132+
clone.dataset_id = self.dataset_id
133+
clone.image_id = self.id
134+
135+
clone.save(copy=True)
136+
137+
return annotations.count()
138+
118139

119140
class AnnotationModel(db.DynamicDocument):
120141

@@ -131,7 +152,7 @@ class AnnotationModel(db.DynamicDocument):
131152
width = db.IntField()
132153
height = db.IntField()
133154

134-
color = db.StringField(default=color_util.random_color_hex())
155+
color = db.StringField()
135156

136157
metadata = db.DictField(default={})
137158
paper_object = db.ListField(default=[])
@@ -153,21 +174,34 @@ def __init__(self, image_id=None, **data):
153174

154175
super(AnnotationModel, self).__init__(**data)
155176

156-
def save(self, *args, **kwargs):
177+
def save(self, copy=False, *args, **kwargs):
157178

158-
if self.dataset_id is not None:
179+
if not self.dataset_id and not copy:
159180
dataset = DatasetModel.objects(id=self.dataset_id).first()
160181

161182
if dataset is not None:
162-
metadata = dataset.default_annotation_metadata.copy()
163-
metadata.update(self.metadata)
164-
self.metadata = metadata
183+
self.metadata = dataset.default_annotation_metadata.copy()
184+
185+
if self.color is None:
186+
self.color = color_util.random_color_hex()
165187

166188
return super(AnnotationModel, self).save(*args, **kwargs)
167189

168190
def is_empty(self):
169191
return len(self.segmentation) == 0 or self.area == 0
170192

193+
def mask(self):
194+
""" Returns binary mask of annotation """
195+
mask = np.zeros((self.height, self.width))
196+
return decode_seg(mask, self.segmentation)
197+
198+
def clone(self):
199+
""" Creates a clone """
200+
create = json.loads(self.to_json())
201+
del create['_id']
202+
203+
return AnnotationModel(**create)
204+
171205

172206
class CategoryModel(db.DynamicDocument):
173207
id = db.SequenceField(primary_key=True)
@@ -180,12 +214,30 @@ class CategoryModel(db.DynamicDocument):
180214
deleted_date = db.DateTimeField()
181215

182216
@classmethod
183-
def create_category(cls, name, color=None, metadata=None, supercategory=None):
184-
category = CategoryModel(name=name, supercategory=supercategory)
185-
category.metadata = metadata if metadata is not None else {}
186-
category.color = color_util.random_color_hex() if color is None else color
187-
category.save()
188-
return category
217+
def bulk_create(cls, categories):
218+
219+
if not categories:
220+
return []
221+
222+
category_ids = []
223+
for category in categories:
224+
category_model = CategoryModel.objects(name=category).first()
225+
226+
if category_model is None:
227+
new_category = CategoryModel(name=category)
228+
new_category.save()
229+
category_ids.append(new_category.id)
230+
else:
231+
category_ids.append(category_model.id)
232+
233+
return category_ids
234+
235+
def save(self, *args, **kwargs):
236+
237+
if not self.color:
238+
self.color = color_util.random_color_hex()
239+
240+
return super(CategoryModel, self).save(*args, **kwargs)
189241

190242
def save(self, *args, **kwargs):
191243

@@ -230,19 +282,20 @@ def create_from_json(json_file):
230282
for category in data_json.get('categories', []):
231283
name = category.get('name')
232284
if name is not None:
233-
upsert(CategoryModel, query={ "name": name }, update=category)
285+
upsert(CategoryModel, query={"name": name}, update=category)
234286

235287
for dataset_json in data_json.get('datasets', []):
236288
name = dataset_json.get('name')
237289
if name:
238290
# map category names to ids; create as needed
239291
category_ids = []
240292
for category in dataset_json.get('categories', []):
241-
category_obj = { "name": category }
293+
category_obj = {"name": category}
294+
242295
category_model = upsert(CategoryModel, query=category_obj)
243296
category_ids.append(category_model.id)
244297

245298
dataset_json['categories'] = category_ids
246299
upsert(DatasetModel, query={ "name": name}, update=dataset_json)
247300

248-
sys.stdout.flush()
301+

app/util/coco_util.py

+32
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import pycocotools.mask as mask
2+
import numpy as np
3+
import cv2
24

35
from .query_util import fix_ids
46
from ..models import *
@@ -64,6 +66,23 @@ def paperjs_to_coco(image_width, image_height, paperjs):
6466
return segments, mask.area(rle), mask.toBbox(rle)
6567

6668

69+
def get_annotations_iou(annotation_a, annotation_b):
70+
"""
71+
Computes the IOU between two annotation objects
72+
"""
73+
seg_a = list([list(part) for part in annotation_a.segmentation])
74+
seg_b = list([list(part) for part in annotation_b.segmentation])
75+
76+
rles_a = mask.frPyObjects(
77+
seg_a, annotation_a.height, annotation_a.width)
78+
79+
rles_b = mask.frPyObjects(
80+
seg_b, annotation_b.height, annotation_b.width)
81+
82+
ious = mask.iou(rles_a, rles_b, [0])
83+
return ious[0][0]
84+
85+
6786
def get_image_coco(image):
6887
"""
6988
Generates coco for an image
@@ -157,6 +176,19 @@ def get_dataset_coco(dataset):
157176
return coco
158177

159178

179+
def decode_seg(mask, segmentation):
180+
"""
181+
Create binary mask from segmentation
182+
"""
183+
pts = [
184+
np.array(anno).reshape(-1, 2).round().astype(int)
185+
for anno in segmentation
186+
]
187+
mask = cv2.fillPoly(mask, pts, 1)
188+
189+
return mask
190+
191+
160192
def _fit(value, max_value, min_value):
161193

162194
if value > max_value:

0 commit comments

Comments
 (0)