Skip to content

Commit afc7899

Browse files
committed
support multi-classes tasks
1 parent ad5bfc2 commit afc7899

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

davarocr/davarocr/davar_det/core/post_processing/post_mask_rcnn.py

+23-24
Original file line numberDiff line numberDiff line change
@@ -80,33 +80,32 @@ def post_processing(self, batch_result, **kwargs):
8080
det_results = []
8181

8282
for result in batch_result:
83-
box_result, seg_result = result
84-
boxes_pred = box_result[0]
85-
seg_pred = seg_result[0]
8683
det_result = dict()
84+
box_result, seg_result = result
8785
det_result['points'] = []
8886
det_result['confidence'] = []
89-
90-
assert boxes_pred.shape[0] == len(seg_pred)
91-
92-
for box_id in range(boxes_pred.shape[0]):
93-
prob = boxes_pred[box_id, 4]
94-
seg = seg_pred[box_id]
95-
seg = np.array(seg[:,:, np.newaxis], dtype='uint8')
96-
curve_poly = self.approx_poly(seg)
97-
98-
if len(curve_poly) == 0:
99-
continue
100-
101-
curve_poly = curve_poly[0].squeeze()
102-
103-
# Filter out curve poly with less than 2 points.
104-
curve_poly = curve_poly.astype(np.int)
105-
if len(curve_poly.shape) < 2:
106-
continue
107-
curve_poly = curve_poly.reshape(-1).tolist()
108-
det_result['points'].append(curve_poly)
109-
det_result['confidence'].append(prob)
87+
det_result['labels'] = []
88+
for i in range(len(box_result)):
89+
boxes_pred = box_result[i]
90+
seg_pred = seg_result[i]
91+
assert boxes_pred.shape[0] == len(seg_pred)
92+
for box_id in range(boxes_pred.shape[0]):
93+
prob = boxes_pred[box_id, 4]
94+
seg = seg_pred[box_id]
95+
seg = np.array(seg[:,:, np.newaxis], dtype='uint8')
96+
97+
curve_poly = self.approx_poly(seg)
98+
if len(curve_poly) == 0:
99+
continue
100+
curve_poly = curve_poly[0].squeeze()
101+
if len(curve_poly.shape) < 2:
102+
continue
103+
104+
curve_poly = curve_poly.astype(np.int)
105+
curve_poly = curve_poly.reshape(-1).tolist()
106+
det_result['points'].append(curve_poly)
107+
det_result['confidence'].append(prob)
108+
det_result['labels'].append([i])
110109
det_results.append(det_result)
111110

112111
return det_results

0 commit comments

Comments
 (0)