@@ -2145,6 +2145,75 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
return %0 : !torch.vtensor<[1,3],si64>
}
+ // CHECK-LABEL: func.func @test_nonmaxsuppression_center_point_box(
+ // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4],f32>,
+ // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,1],f32>,
+ // CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>,
+ // CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
+ // CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+ func.func @test_nonmaxsuppression_center_point_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+ // CHECK: %[[VAL_5:.*]] = torch.constant.int 0
+ // CHECK: %[[VAL_6:.*]] = torch.constant.int 1
+ // CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int
+ // CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
+ // CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
+ // CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
+ // CHECK: %[[VAL_10:.*]] = torch.constant.int 0
+ // CHECK: %[[VAL_11:.*]] = torch.constant.int 1
+ // CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int
+ // CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
+ // CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
+ // CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
+ // CHECK: %[[VAL_15:.*]] = torch.constant.int 0
+ // CHECK: %[[VAL_16:.*]] = torch.constant.int 1
+ // CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int
+ // CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
+ // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
+ // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
+ // CHECK: %[[VAL_20:.*]] = torch.constant.int 0
+ // CHECK: %[[VAL_21:.*]] = torch.constant.int 1
+ // CHECK: %[[VAL_22:.*]] = torch.constant.int 2
+ // CHECK: %[[VAL_23:.*]] = torch.constant.int 4
+ // CHECK: %[[VAL_24:.*]] = torch.constant.float 2.000000e+00
+ // CHECK: %[[VAL_25:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_20]], %[[VAL_22]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
+ // CHECK: %[[VAL_26:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
+ // CHECK: %[[VAL_27:.*]] = torch.aten.div.Scalar %[[VAL_26]], %[[VAL_24]] : !torch.vtensor<[?,2],f32>, !torch.float -> !torch.vtensor<[?,2],f32>
+ // CHECK: %[[VAL_28:.*]] = torch.aten.sub.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
+ // CHECK: %[[VAL_29:.*]] = torch.aten.add.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
+ // CHECK: %[[VAL_30:.*]] = torch.prim.ListConstruct %[[VAL_28]], %[[VAL_29]] : (!torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>) -> !torch.list<vtensor>
+ // CHECK: %[[VAL_31:.*]] = torch.aten.cat %[[VAL_30]], %[[VAL_21]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,4],f32>
+ // CHECK: %[[VAL_32:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
+ // CHECK: %[[VAL_33:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32>
+ // CHECK: %[[VAL_34:.*]] = torch.aten.item %[[VAL_33]] : !torch.vtensor<[],f32> -> !torch.float
+ // CHECK: %[[VAL_35:.*]] = torch.aten.ge.float %[[VAL_34]], %[[VAL_32]] : !torch.float, !torch.float -> !torch.bool
+ // CHECK: torch.runtime.assert %[[VAL_35]], "unimplemented: score_threshold should be <= min(scores)"
+ // CHECK: %[[VAL_36:.*]] = torch.constant.int 0
+ // CHECK: %[[VAL_37:.*]] = torch.constant.int 1
+ // CHECK: %[[VAL_38:.*]] = torch.constant.float 0.000000e+00
+ // CHECK: %[[VAL_39:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
+ // CHECK: %[[VAL_40:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
+ // CHECK: %[[VAL_41:.*]] = torch.torchvision.nms %[[VAL_31]], %[[VAL_19]], %[[VAL_39]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64>
+ // CHECK: %[[VAL_42:.*]] = torch.aten.size.int %[[VAL_41]], %[[VAL_36]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
+ // CHECK: %[[VAL_43:.*]] = torch.aten.gt.int %[[VAL_42]], %[[VAL_40]] : !torch.int, !torch.int -> !torch.bool
+ // CHECK: %[[VAL_44:.*]] = torch.prim.If %[[VAL_43]] -> (!torch.vtensor<[1],si64>) {
+ // CHECK: %[[VAL_45:.*]] = torch.aten.slice.Tensor %[[VAL_41]], %[[VAL_36]], %[[VAL_36]], %[[VAL_40]], %[[VAL_37]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+ // CHECK: torch.prim.If.yield %[[VAL_45]] : !torch.vtensor<[1],si64>
+ // CHECK: } else {
+ // CHECK: %[[VAL_46:.*]] = torch.tensor_static_info_cast %[[VAL_41]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
+ // CHECK: torch.prim.If.yield %[[VAL_46]] : !torch.vtensor<[1],si64>
+ // CHECK: }
+ // CHECK: %[[VAL_47:.*]] = torch.aten.unsqueeze %[[VAL_44]], %[[VAL_37]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
+ // CHECK: %[[VAL_48:.*]] = torch.aten.size.int %[[VAL_47]], %[[VAL_36]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
+ // CHECK: %[[VAL_49:.*]] = torch.constant.int 2
+ // CHECK: %[[VAL_50:.*]] = torch.prim.ListConstruct %[[VAL_48]], %[[VAL_49]] : (!torch.int, !torch.int) -> !torch.list<int>
+ // CHECK: %[[VAL_51:.*]] = torch.constant.none
+ // CHECK: %[[VAL_52:.*]] = torch.aten.zeros %[[VAL_50]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
+ // CHECK: %[[VAL_53:.*]] = torch.prim.ListConstruct %[[VAL_52]], %[[VAL_47]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
+ // CHECK: %[[VAL_54:.*]] = torch.aten.cat %[[VAL_53]], %[[VAL_37]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
+ // CHECK: return %[[VAL_54]] : !torch.vtensor<[1,3],si64>
+ %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.center_point_box = 1 : si64} : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
+ return %0 : !torch.vtensor<[1,3],si64>
+ }
// -----
// CHECK-LABEL: func.func @test_mwm