Skip to content

Commit 2564d7a

Browse files
authored
Add center_point_box=1 support in NonMaxSuppression. (llvm#3976)
When center_point_box=1, the supplied boxes use the [x_center, y_center, width, height] format; this patch converts them into the [x1, y1, x2, y2] format before they are consumed. The e2e test is added in nod-ai/SHARK-TestSuite#436
1 parent 481da8d commit 2564d7a

File tree

2 files changed

+114
-4
lines changed

2 files changed

+114
-4
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3697,11 +3697,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
36973697
binder.tensorResultType(resultType))
36983698
return failure();
36993699

3700-
// TODO: Add support for non-zero center_point_box value.
3701-
if (centerPointBox != 0)
3700+
if (centerPointBox != 0 && centerPointBox != 1)
37023701
return rewriter.notifyMatchFailure(
3703-
binder.op, "unimplemented: expected center_point_box "
3704-
"attribute value to be 0");
3702+
binder.op, "expected center_point_box attribute to be 0 or 1");
37053703

37063704
// TODO: Support multiple batches and classes
37073705
// Squeeze the boxes and scores tensor.
@@ -3727,6 +3725,49 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
37273725
"failed to squeeze scores tensor");
37283726
boxes = squeezedBoxes.value();
37293727
scores = squeezedScores.value();
3728+
if (centerPointBox == 1) {
3729+
// When center_point_box is 1, the box data is supplied as
3730+
// [[x_center, y_center, width, height], ...]. Slice it to
3731+
// [[x_center, y_center], ...] and [[width, height], ...],
3732+
// calculate the [[x1, y1], ...] and [[x2, y2], ...], and concatenate
3733+
// to [[x1, y1, x2, y2], ...]
3734+
auto boxesTensorType =
3735+
dyn_cast<Torch::ValueTensorType>(boxes.getType());
3736+
Value const0 = rewriter.create<Torch::ConstantIntOp>(
3737+
loc, rewriter.getI64IntegerAttr(0));
3738+
Value const1 = rewriter.create<Torch::ConstantIntOp>(
3739+
loc, rewriter.getI64IntegerAttr(1));
3740+
Value const2 = rewriter.create<Torch::ConstantIntOp>(
3741+
loc, rewriter.getI64IntegerAttr(2));
3742+
Value const4 = rewriter.create<Torch::ConstantIntOp>(
3743+
loc, rewriter.getI64IntegerAttr(4));
3744+
Value const2F = rewriter.create<Torch::ConstantFloatOp>(
3745+
loc, rewriter.getF64FloatAttr(2.0));
3746+
3747+
// extract scaled ranges for regions of interest
3748+
auto sliceShape = SmallVector<int64_t>{Torch::kUnknownSize, 2};
3749+
auto sliceTensorType = rewriter.getType<Torch::ValueTensorType>(
3750+
sliceShape, boxesTensorType.getDtype());
3751+
Value centers = rewriter.create<Torch::AtenSliceTensorOp>(
3752+
loc, sliceTensorType, boxes, const1, const0, const2, const1);
3753+
Value sizes = rewriter.create<Torch::AtenSliceTensorOp>(
3754+
loc, sliceTensorType, boxes, const1, const2, const4, const1);
3755+
Value halfSizes = rewriter.create<Torch::AtenDivScalarOp>(
3756+
loc, sizes.getType(), sizes, const2F);
3757+
Value x1y1s = rewriter.create<Torch::AtenSubTensorOp>(
3758+
loc, centers.getType(), centers, halfSizes, const1);
3759+
Value x2y2s = rewriter.create<Torch::AtenAddTensorOp>(
3760+
loc, centers.getType(), centers, halfSizes, const1);
3761+
3762+
Type listElemType = boxesTensorType.getWithSizesAndDtype(
3763+
/*optionalSizes=*/std::nullopt,
3764+
/*optionalDtype=*/nullptr);
3765+
Type listType = Torch::ListType::get(listElemType);
3766+
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
3767+
loc, listType, SmallVector<Value>{x1y1s, x2y2s});
3768+
boxes = rewriter.create<Torch::AtenCatOp>(loc, boxesTensorType,
3769+
tensorList, const1);
3770+
}
37303771

37313772
// TODO: Support score_threshold input
37323773
// Filter out the boxes if the score < score_threshold

test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,6 +2145,75 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>,
21452145
return %0 : !torch.vtensor<[1,3],si64>
21462146
}
21472147

2148+
// CHECK-LABEL: func.func @test_nonmaxsuppression_center_point_box(
2149+
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4],f32>,
2150+
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1,1,1],f32>,
2151+
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1],si64>,
2152+
// CHECK-SAME: %[[VAL_3:.*]]: !torch.vtensor<[1],f32>,
2153+
// CHECK-SAME: %[[VAL_4:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
2154+
func.func @test_nonmaxsuppression_center_point_box(%arg0: !torch.vtensor<[1,1,4],f32>, %arg1: !torch.vtensor<[1,1,1],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
2155+
// CHECK: %[[VAL_5:.*]] = torch.constant.int 0
2156+
// CHECK: %[[VAL_6:.*]] = torch.constant.int 1
2157+
// CHECK: %[[VAL_7:.*]] = torch.aten.size.int %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.int
2158+
// CHECK: %[[VAL_8:.*]] = torch.aten.eq.int %[[VAL_7]], %[[VAL_6]] : !torch.int, !torch.int -> !torch.bool
2159+
// CHECK: torch.runtime.assert %[[VAL_8]], "squeeze operation possible for dim only when input_shape[dim] == 1."
2160+
// CHECK: %[[VAL_9:.*]] = torch.aten.squeeze.dim %[[VAL_0]], %[[VAL_5]] : !torch.vtensor<[1,1,4],f32>, !torch.int -> !torch.vtensor<[1,4],f32>
2161+
// CHECK: %[[VAL_10:.*]] = torch.constant.int 0
2162+
// CHECK: %[[VAL_11:.*]] = torch.constant.int 1
2163+
// CHECK: %[[VAL_12:.*]] = torch.aten.size.int %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.int
2164+
// CHECK: %[[VAL_13:.*]] = torch.aten.eq.int %[[VAL_12]], %[[VAL_11]] : !torch.int, !torch.int -> !torch.bool
2165+
// CHECK: torch.runtime.assert %[[VAL_13]], "squeeze operation possible for dim only when input_shape[dim] == 1."
2166+
// CHECK: %[[VAL_14:.*]] = torch.aten.squeeze.dim %[[VAL_1]], %[[VAL_10]] : !torch.vtensor<[1,1,1],f32>, !torch.int -> !torch.vtensor<[1,1],f32>
2167+
// CHECK: %[[VAL_15:.*]] = torch.constant.int 0
2168+
// CHECK: %[[VAL_16:.*]] = torch.constant.int 1
2169+
// CHECK: %[[VAL_17:.*]] = torch.aten.size.int %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.int
2170+
// CHECK: %[[VAL_18:.*]] = torch.aten.eq.int %[[VAL_17]], %[[VAL_16]] : !torch.int, !torch.int -> !torch.bool
2171+
// CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1."
2172+
// CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32>
2173+
// CHECK: %[[VAL_20:.*]] = torch.constant.int 0
2174+
// CHECK: %[[VAL_21:.*]] = torch.constant.int 1
2175+
// CHECK: %[[VAL_22:.*]] = torch.constant.int 2
2176+
// CHECK: %[[VAL_23:.*]] = torch.constant.int 4
2177+
// CHECK: %[[VAL_24:.*]] = torch.constant.float 2.000000e+00
2178+
// CHECK: %[[VAL_25:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_20]], %[[VAL_22]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
2179+
// CHECK: %[[VAL_26:.*]] = torch.aten.slice.Tensor %[[VAL_9]], %[[VAL_21]], %[[VAL_22]], %[[VAL_23]], %[[VAL_21]] : !torch.vtensor<[1,4],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,2],f32>
2180+
// CHECK: %[[VAL_27:.*]] = torch.aten.div.Scalar %[[VAL_26]], %[[VAL_24]] : !torch.vtensor<[?,2],f32>, !torch.float -> !torch.vtensor<[?,2],f32>
2181+
// CHECK: %[[VAL_28:.*]] = torch.aten.sub.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
2182+
// CHECK: %[[VAL_29:.*]] = torch.aten.add.Tensor %[[VAL_25]], %[[VAL_27]], %[[VAL_21]] : !torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>, !torch.int -> !torch.vtensor<[?,2],f32>
2183+
// CHECK: %[[VAL_30:.*]] = torch.prim.ListConstruct %[[VAL_28]], %[[VAL_29]] : (!torch.vtensor<[?,2],f32>, !torch.vtensor<[?,2],f32>) -> !torch.list<vtensor>
2184+
// CHECK: %[[VAL_31:.*]] = torch.aten.cat %[[VAL_30]], %[[VAL_21]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,4],f32>
2185+
// CHECK: %[[VAL_32:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float
2186+
// CHECK: %[[VAL_33:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32>
2187+
// CHECK: %[[VAL_34:.*]] = torch.aten.item %[[VAL_33]] : !torch.vtensor<[],f32> -> !torch.float
2188+
// CHECK: %[[VAL_35:.*]] = torch.aten.ge.float %[[VAL_34]], %[[VAL_32]] : !torch.float, !torch.float -> !torch.bool
2189+
// CHECK: torch.runtime.assert %[[VAL_35]], "unimplemented: score_threshold should be <= min(scores)"
2190+
// CHECK: %[[VAL_36:.*]] = torch.constant.int 0
2191+
// CHECK: %[[VAL_37:.*]] = torch.constant.int 1
2192+
// CHECK: %[[VAL_38:.*]] = torch.constant.float 0.000000e+00
2193+
// CHECK: %[[VAL_39:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float
2194+
// CHECK: %[[VAL_40:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int
2195+
// CHECK: %[[VAL_41:.*]] = torch.torchvision.nms %[[VAL_31]], %[[VAL_19]], %[[VAL_39]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64>
2196+
// CHECK: %[[VAL_42:.*]] = torch.aten.size.int %[[VAL_41]], %[[VAL_36]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
2197+
// CHECK: %[[VAL_43:.*]] = torch.aten.gt.int %[[VAL_42]], %[[VAL_40]] : !torch.int, !torch.int -> !torch.bool
2198+
// CHECK: %[[VAL_44:.*]] = torch.prim.If %[[VAL_43]] -> (!torch.vtensor<[1],si64>) {
2199+
// CHECK: %[[VAL_45:.*]] = torch.aten.slice.Tensor %[[VAL_41]], %[[VAL_36]], %[[VAL_36]], %[[VAL_40]], %[[VAL_37]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
2200+
// CHECK: torch.prim.If.yield %[[VAL_45]] : !torch.vtensor<[1],si64>
2201+
// CHECK: } else {
2202+
// CHECK: %[[VAL_46:.*]] = torch.tensor_static_info_cast %[[VAL_41]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64>
2203+
// CHECK: torch.prim.If.yield %[[VAL_46]] : !torch.vtensor<[1],si64>
2204+
// CHECK: }
2205+
// CHECK: %[[VAL_47:.*]] = torch.aten.unsqueeze %[[VAL_44]], %[[VAL_37]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64>
2206+
// CHECK: %[[VAL_48:.*]] = torch.aten.size.int %[[VAL_47]], %[[VAL_36]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int
2207+
// CHECK: %[[VAL_49:.*]] = torch.constant.int 2
2208+
// CHECK: %[[VAL_50:.*]] = torch.prim.ListConstruct %[[VAL_48]], %[[VAL_49]] : (!torch.int, !torch.int) -> !torch.list<int>
2209+
// CHECK: %[[VAL_51:.*]] = torch.constant.none
2210+
// CHECK: %[[VAL_52:.*]] = torch.aten.zeros %[[VAL_50]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]], %[[VAL_51]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64>
2211+
// CHECK: %[[VAL_53:.*]] = torch.prim.ListConstruct %[[VAL_52]], %[[VAL_47]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list<vtensor>
2212+
// CHECK: %[[VAL_54:.*]] = torch.aten.cat %[[VAL_53]], %[[VAL_37]] : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,3],si64>
2213+
// CHECK: return %[[VAL_54]] : !torch.vtensor<[1,3],si64>
2214+
%0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.center_point_box = 1 : si64} : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64>
2215+
return %0 : !torch.vtensor<[1,3],si64>
2216+
}
21482217
// -----
21492218

21502219
// CHECK-LABEL: func.func @test_mwm

0 commit comments

Comments
 (0)