@@ -31,9 +31,15 @@ def test_pr_curve_simple():
3131 score_thresholds = score_thresholds ,
3232 )
3333
34- assert pr_curve .shape == (2 , 1 , 101 )
35- assert np .isclose (pr_curve [0 ][0 ], 1.0 ).all ()
36- assert np .isclose (pr_curve [1 ][0 ], 1 / 3 ).all ()
34+ assert pr_curve .shape == (2 , 1 , 101 , 2 )
35+
36+ # test precision values
37+ assert np .isclose (pr_curve [0 , 0 , :, 0 ], 1.0 ).all ()
38+ assert np .isclose (pr_curve [1 , 0 , :, 0 ], 1 / 3 ).all ()
39+
40+ # test score values
41+ assert np .isclose (pr_curve [0 , 0 , :, 1 ], 0.95 ).all ()
42+ assert np .isclose (pr_curve [1 , 0 , :, 1 ], 0.65 ).all ()
3743
3844
3945def test_pr_curve_using_torch_metrics_example (
@@ -59,111 +65,162 @@ def test_pr_curve_using_torch_metrics_example(
5965 as_dict = True ,
6066 )
6167
62- # AP = 1.0
63- a = [1.0 for _ in range (101 )]
64-
65- # AP = 0.505
66- b = [1.0 for _ in range (51 )] + [0.0 for _ in range (50 )]
67-
68- # AP = 0.791
69- c = (
70- [1.0 for _ in range (71 )]
71- + [8 / 9 for _ in range (10 )]
72- + [0.0 for _ in range (20 )]
73- )
74-
75- # AP = 0.722
76- d = (
77- [1.0 for _ in range (41 )]
78- + [0.8 for _ in range (40 )]
79- + [0.0 for _ in range (20 )]
80- )
81-
82- # AP = 0.576
83- e = (
84- [1.0 for _ in range (41 )]
85- + [0.8571428571428571 for _ in range (20 )]
86- + [0.0 for _ in range (40 )]
87- )
88-
8968 # test PrecisionRecallCurve
9069 actual_metrics = [m for m in metrics [MetricType .PrecisionRecallCurve ]]
9170 expected_metrics = [
9271 {
9372 "type" : "PrecisionRecallCurve" ,
94- "value" : a ,
73+ "value" : {
74+ "precisions" : [1.0 for _ in range (101 )],
75+ "scores" : (
76+ [0.953 for _ in range (21 )]
77+ + [0.805 for _ in range (20 )]
78+ + [0.611 for _ in range (20 )]
79+ + [0.407 for _ in range (20 )]
80+ + [0.335 for _ in range (20 )]
81+ ),
82+ },
9583 "parameters" : {
9684 "iou_threshold" : 0.5 ,
9785 "label" : "0" ,
9886 },
9987 },
10088 {
10189 "type" : "PrecisionRecallCurve" ,
102- "value" : d ,
90+ "value" : {
91+ "precisions" : (
92+ [1.0 for _ in range (41 )]
93+ + [0.8 for _ in range (40 )]
94+ + [0.0 for _ in range (20 )]
95+ ),
96+ "scores" : (
97+ [0.953 for _ in range (21 )]
98+ + [0.805 for _ in range (20 )]
99+ + [0.407 for _ in range (20 )]
100+ + [0.335 for _ in range (20 )]
101+ + [0.0 for _ in range (20 )]
102+ ),
103+ },
103104 "parameters" : {
104105 "iou_threshold" : 0.75 ,
105106 "label" : "0" ,
106107 },
107108 },
108109 {
109110 "type" : "PrecisionRecallCurve" ,
110- "value" : a ,
111+ "value" : {
112+ "precisions" : [1.0 for _ in range (101 )],
113+ "scores" : [0.3 for _ in range (101 )],
114+ },
111115 "parameters" : {
112116 "iou_threshold" : 0.5 ,
113117 "label" : "1" ,
114118 },
115119 },
116120 {
117121 "type" : "PrecisionRecallCurve" ,
118- "value" : a ,
122+ "value" : {
123+ "precisions" : [1.0 for _ in range (101 )],
124+ "scores" : [0.3 for _ in range (101 )],
125+ },
119126 "parameters" : {
120127 "iou_threshold" : 0.75 ,
121128 "label" : "1" ,
122129 },
123130 },
124131 {
125132 "type" : "PrecisionRecallCurve" ,
126- "value" : b ,
133+ "value" : {
134+ "precisions" : [1.0 for _ in range (51 )]
135+ + [0.0 for _ in range (50 )],
136+ "scores" : [0.726 for _ in range (51 )]
137+ + [0.0 for _ in range (50 )],
138+ },
127139 "parameters" : {
128140 "iou_threshold" : 0.5 ,
129141 "label" : "2" ,
130142 },
131143 },
132144 {
133145 "type" : "PrecisionRecallCurve" ,
134- "value" : b ,
146+ "value" : {
147+ "precisions" : [1.0 for _ in range (51 )]
148+ + [0.0 for _ in range (50 )],
149+ "scores" : [0.726 for _ in range (51 )]
150+ + [0.0 for _ in range (50 )],
151+ },
135152 "parameters" : {
136153 "iou_threshold" : 0.75 ,
137154 "label" : "2" ,
138155 },
139156 },
140157 {
141158 "type" : "PrecisionRecallCurve" ,
142- "value" : a ,
159+ "value" : {
160+ "precisions" : [1.0 for _ in range (101 )],
161+ "scores" : [0.546 for _ in range (51 )]
162+ + [0.236 for _ in range (50 )],
163+ },
143164 "parameters" : {
144165 "iou_threshold" : 0.5 ,
145166 "label" : "4" ,
146167 },
147168 },
148169 {
149170 "type" : "PrecisionRecallCurve" ,
150- "value" : a ,
171+ "value" : {
172+ "precisions" : [1.0 for _ in range (101 )],
173+ "scores" : [0.546 for _ in range (51 )]
174+ + [0.236 for _ in range (50 )],
175+ },
151176 "parameters" : {
152177 "iou_threshold" : 0.75 ,
153178 "label" : "4" ,
154179 },
155180 },
156181 {
157182 "type" : "PrecisionRecallCurve" ,
158- "value" : c ,
183+ "value" : {
184+ "precisions" : (
185+ [1.0 for _ in range (71 )]
186+ + [8 / 9 for _ in range (10 )]
187+ + [0.0 for _ in range (20 )]
188+ ),
189+ "scores" : (
190+ [0.883 for _ in range (11 )]
191+ + [0.782 for _ in range (10 )]
192+ + [0.561 for _ in range (10 )]
193+ + [0.532 for _ in range (10 )]
194+ + [0.349 for _ in range (10 )]
195+ + [0.271 for _ in range (10 )]
196+ + [0.204 for _ in range (10 )]
197+ + [0.202 for _ in range (10 )]
198+ + [0.0 for _ in range (20 )]
199+ ),
200+ },
159201 "parameters" : {
160202 "iou_threshold" : 0.5 ,
161203 "label" : "49" ,
162204 },
163205 },
164206 {
165207 "type" : "PrecisionRecallCurve" ,
166- "value" : e ,
208+ "value" : {
209+ "precisions" : (
210+ [1.0 for _ in range (41 )]
211+ + [0.8571428571428571 for _ in range (20 )]
212+ + [0.0 for _ in range (40 )]
213+ ),
214+ "scores" : (
215+ [0.883 for _ in range (11 )]
216+ + [0.782 for _ in range (10 )]
217+ + [0.561 for _ in range (10 )]
218+ + [0.532 for _ in range (10 )]
219+ + [0.271 for _ in range (10 )]
220+ + [0.204 for _ in range (10 )]
221+ + [0.0 for _ in range (40 )]
222+ ),
223+ },
167224 "parameters" : {
168225 "iou_threshold" : 0.75 ,
169226 "label" : "49" ,
0 commit comments