@@ -80,33 +80,32 @@ def post_processing(self, batch_result, **kwargs):
80
80
det_results = []
81
81
82
82
for result in batch_result :
83
- box_result , seg_result = result
84
- boxes_pred = box_result [0 ]
85
- seg_pred = seg_result [0 ]
86
83
det_result = dict ()
84
+ box_result , seg_result = result
87
85
det_result ['points' ] = []
88
86
det_result ['confidence' ] = []
89
-
90
- assert boxes_pred .shape [0 ] == len (seg_pred )
91
-
92
- for box_id in range (boxes_pred .shape [0 ]):
93
- prob = boxes_pred [box_id , 4 ]
94
- seg = seg_pred [box_id ]
95
- seg = np .array (seg [:,:, np .newaxis ], dtype = 'uint8' )
96
- curve_poly = self .approx_poly (seg )
97
-
98
- if len (curve_poly ) == 0 :
99
- continue
100
-
101
- curve_poly = curve_poly [0 ].squeeze ()
102
-
103
- # Filter out curve poly with less than 2 points.
104
- curve_poly = curve_poly .astype (np .int )
105
- if len (curve_poly .shape ) < 2 :
106
- continue
107
- curve_poly = curve_poly .reshape (- 1 ).tolist ()
108
- det_result ['points' ].append (curve_poly )
109
- det_result ['confidence' ].append (prob )
87
+ det_result ['labels' ] = []
88
+ for i in range (len (box_result )):
89
+ boxes_pred = box_result [i ]
90
+ seg_pred = seg_result [i ]
91
+ assert boxes_pred .shape [0 ] == len (seg_pred )
92
+ for box_id in range (boxes_pred .shape [0 ]):
93
+ prob = boxes_pred [box_id , 4 ]
94
+ seg = seg_pred [box_id ]
95
+ seg = np .array (seg [:,:, np .newaxis ], dtype = 'uint8' )
96
+
97
+ curve_poly = self .approx_poly (seg )
98
+ if len (curve_poly ) == 0 :
99
+ continue
100
+ curve_poly = curve_poly [0 ].squeeze ()
101
+ if len (curve_poly .shape ) < 2 :
102
+ continue
103
+
104
+ curve_poly = curve_poly .astype (np .int )
105
+ curve_poly = curve_poly .reshape (- 1 ).tolist ()
106
+ det_result ['points' ].append (curve_poly )
107
+ det_result ['confidence' ].append (prob )
108
+ det_result ['labels' ].append ([i ])
110
109
det_results .append (det_result )
111
110
112
111
return det_results
0 commit comments