onnx.Resize and aten._interpolate : allow n spatial dims. (llvm#3368)
The old lowering only had logic for 2d (i.e. images). This patch allows
interpolation for n spatial dims, which is required for some 3d vision
models such as

- onnx/models/pytorch-3dunet_vaiq_int8

which successfully compiles and runs with this patch.
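For context, a minimal PyTorch sketch of the case this patch enables: nearest interpolation over three spatial dims. The shapes below are illustrative, not taken from the patch or the model above.

    import torch
    import torch.nn.functional as F

    # A 5d input laid out as (N, C, D, H, W), i.e. three spatial dims.
    x = torch.rand(1, 1, 4, 8, 8)

    # Nearest-neighbor resize over all three spatial dims; lowering ops like
    # this one (via onnx.Resize / aten.__interpolate) previously assumed
    # exactly two spatial dims.
    y = F.interpolate(x, size=(8, 16, 16), mode="nearest")
    print(y.shape)  # torch.Size([1, 1, 8, 16, 16])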
zjgarvey authored and mgehre-amd committed Jun 6, 2024
1 parent 7ec27a9 commit f1e7ed2
Showing 3 changed files with 151 additions and 96 deletions.
2 changes: 1 addition & 1 deletion lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2180,7 +2180,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
           llvm::ArrayRef(selectSizes), operandType.getOptionalDtype());
 
       MLIRContext *context = binder.op->getContext();
-      for (int i = sizes[0] - 2; i < sizes[0]; i++) {
+      for (int i = 2; i < sizes[0]; i++) {
         Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
             rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
151 changes: 67 additions & 84 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -2589,68 +2589,58 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
 };
 } // namespace
 
-static Value NearestInterpolate(OpBuilder &b, Location loc, Value outputSizeH,
-                                Value outputSizeW, Value input,
-                                Value inputSizeH, Value inputSizeW) {
+static Value NearestInterpolate(OpBuilder &b, Location loc,
+                                SmallVector<Value> outputSizes, Value input,
+                                SmallVector<Value> inputSizes) {
 
   auto inputType = input.getType().cast<RankedTensorType>();
   auto inputRank = inputType.getRank();
 
-  Value yOut = b.create<linalg::IndexOp>(loc, 2);
-  Value xOut = b.create<linalg::IndexOp>(loc, 3);
-
-  Value inputHFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeH);
-  Value inputWFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeW);
+  SmallVector<Value> indices;
+  for (unsigned i = 0; i < inputRank; i++) {
+    indices.push_back(b.create<linalg::IndexOp>(loc, i));
+  }
 
-  Value outputSizeHFP =
-      b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeH);
-  Value outputSizeWFP =
-      b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeW);
+  for (unsigned i = 2; i < inputRank; i++) {
+    Value outIndex = indices[i];
 
-  // scale = length_resized / length_original
-  // x_original = x_resized / scale
-  Value hScale = b.create<arith::DivFOp>(loc, outputSizeHFP, inputHFP);
-  Value wScale = b.create<arith::DivFOp>(loc, outputSizeWFP, inputWFP);
+    Value inputSizeFP =
+        b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i - 2]);
 
-  Value yOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), yOut);
-  Value yOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), yOutInt);
-  Value yProj = b.create<arith::DivFOp>(loc, yOutFP, hScale);
+    Value outputSizeFP =
+        b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizes[i - 2]);
 
-  Value xOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), xOut);
-  Value xOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), xOutInt);
-  Value xProj = b.create<arith::DivFOp>(loc, xOutFP, wScale);
+    // scale = length_resized / length_original
+    // x_original = x_resized / scale
+    Value scale = b.create<arith::DivFOp>(loc, outputSizeFP, inputSizeFP);
 
-  // get nearest pixel using floor
-  Value yNearestFP = b.create<math::FloorOp>(loc, yProj);
-  Value xNearestFP = b.create<math::FloorOp>(loc, xProj);
+    Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), outIndex);
+    Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);
+    Value proj = b.create<arith::DivFOp>(loc, outFP, scale);
 
-  Value yNearestInt =
-      b.create<arith::FPToSIOp>(loc, b.getI64Type(), yNearestFP);
-  Value yNearest =
-      b.create<arith::IndexCastOp>(loc, b.getIndexType(), yNearestInt);
+    // get nearest pixel using floor
+    Value nearestFP = b.create<math::FloorOp>(loc, proj);
 
-  Value xNearestInt =
-      b.create<arith::FPToSIOp>(loc, b.getI64Type(), xNearestFP);
-  Value xNearest =
-      b.create<arith::IndexCastOp>(loc, b.getIndexType(), xNearestInt);
+    Value nearestInt =
+        b.create<arith::FPToSIOp>(loc, b.getI64Type(), nearestFP);
+    Value nearest =
+        b.create<arith::IndexCastOp>(loc, b.getIndexType(), nearestInt);
 
-  SmallVector<Value> indices;
-  for (unsigned i = 0; i < inputRank; i++) {
-    indices.push_back(b.create<linalg::IndexOp>(loc, i));
+    indices[i] = nearest;
   }
 
-  int hDimOffset = 2;
-  indices[hDimOffset] = yNearest;
-  indices[hDimOffset + 1] = xNearest;
   Value retVal = b.create<tensor::ExtractOp>(loc, input, indices);
   return retVal;
 }
 
 static Value BilinearInterpolate(OpBuilder &b,
                                  Aten__InterpolateSizeListScaleListOp op,
-                                 Location loc, Value outputSizeH,
-                                 Value outputSizeW, Value input,
-                                 Value inputSizeH, Value inputSizeW) {
+                                 Location loc, SmallVector<Value> outputSizes,
+                                 Value input, SmallVector<Value> inputSizes) {
+  Value inputSizeH = inputSizes[0];
+  Value inputSizeW = inputSizes[1];
+  Value outputSizeH = outputSizes[0];
+  Value outputSizeW = outputSizes[1];
+
   int hDimOffset = 2;
   auto inputType = input.getType().cast<RankedTensorType>();
   auto inputRank = inputType.getRank();
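As a sanity check on the per-dimension math in the hunk above, a small Python sketch of the same nearest-neighbor coordinate transform (a hypothetical helper, not part of the patch): for each spatial dim, scale = out_size / in_size, and the source index is floor(out_index / scale).

    import math

    def nearest_source_index(out_index: int, in_size: int, out_size: int) -> int:
        # scale = length_resized / length_original
        scale = out_size / in_size
        # x_original = x_resized / scale, snapped down to the nearest pixel.
        return int(math.floor(out_index / scale))

    # Resizing a length-4 axis to length 8: output index 5 maps to input index 2.
    assert nearest_source_index(5, in_size=4, out_size=8) == 2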
@@ -2805,7 +2795,6 @@ static Value BilinearInterpolate(OpBuilder &b,
   rhs = b.create<arith::MulFOp>(loc, w1, xInter1);
 
   Value retVal = b.create<arith::AddFOp>(loc, lhs, rhs);
-
   return retVal;
 }
 

@@ -2828,46 +2817,43 @@ class ConvertInterpolateOp
     Value input = adaptor.getInput();
     auto inputType = input.getType().cast<RankedTensorType>();
     auto inputRank = inputType.getRank();
+    if (mode == "bilinear" && inputRank != 4)
+      return rewriter.notifyMatchFailure(
+          op,
+          "cannot perform bilinear interpolation when input spatial dims != 2");
 
-    SmallVector<Value, 2> outputSizeIntValues;
-    Value inputSizeH = getDimOp(rewriter, loc, input, 2);
-    inputSizeH = rewriter.create<arith::IndexCastOp>(
-        loc, rewriter.getIntegerType(64), inputSizeH);
-    Value inputSizeW = getDimOp(rewriter, loc, input, 3);
-    inputSizeW = rewriter.create<arith::IndexCastOp>(
-        loc, rewriter.getIntegerType(64), inputSizeW);
+    SmallVector<Value> outputSizeIntValues;
+    SmallVector<Value> inputSizes;
+    for (unsigned i = 2; i < inputRank; i++) {
+      Value inputSize = getDimOp(rewriter, loc, input, 2);
+      inputSizes.push_back(rewriter.create<arith::IndexCastOp>(
+          loc, rewriter.getIntegerType(64), inputSize));
+    }
 
     if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
-      SmallVector<Value, 2> ScaleFactorTorchFloat;
+      SmallVector<Value> ScaleFactorTorchFloat;
       if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat))
         return rewriter.notifyMatchFailure(
             op, "unimplemented: the output_size is not constructed from "
                 "ListConstruct");
-      SmallVector<Value, 2> ScaleFactorFloatValues;
+      SmallVector<Value> ScaleFactorFloatValues;
       ScaleFactorFloatValues = getTypeConvertedValues(
           rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
-      Value inputHFP = rewriter.create<arith::SIToFPOp>(
-          loc, rewriter.getF32Type(), inputSizeH);
-      Value scale = rewriter.create<arith::TruncFOp>(loc, inputHFP.getType(),
-                                                     ScaleFactorFloatValues[0]);
-      Value outputSizeH = rewriter.create<arith::MulFOp>(loc, inputHFP, scale);
-      Value outputH = rewriter.create<math::FloorOp>(loc, outputSizeH);
-      outputH =
-          rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputH);
-
-      Value inputWFP = rewriter.create<arith::SIToFPOp>(
-          loc, rewriter.getF32Type(), inputSizeW);
-      scale = rewriter.create<arith::TruncFOp>(loc, inputWFP.getType(),
-                                               ScaleFactorFloatValues[1]);
-      Value outputSizeW = rewriter.create<arith::MulFOp>(loc, inputWFP, scale);
-      Value outputW = rewriter.create<math::FloorOp>(loc, outputSizeW);
-      outputW =
-          rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputW);
-
-      outputSizeIntValues.push_back(outputH);
-      outputSizeIntValues.push_back(outputW);
+      for (unsigned i = 0; i < inputRank - 2; i++) {
+        Value inputSizeFP = rewriter.create<arith::SIToFPOp>(
+            loc, rewriter.getF32Type(), inputSizes[i]);
+        Value scale = rewriter.create<arith::TruncFOp>(
+            loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]);
+        Value outputSize =
+            rewriter.create<arith::MulFOp>(loc, inputSizeFP, scale);
+        outputSize = rewriter.create<math::FloorOp>(loc, outputSize);
+        outputSize = rewriter.create<arith::FPToSIOp>(
+            loc, rewriter.getI64Type(), outputSize);
+
+        outputSizeIntValues.push_back(outputSize);
+      }
     } else {
-      SmallVector<Value, 2> outputSizeTorchInt;
+      SmallVector<Value> outputSizeTorchInt;
       if (!getListConstructElements(op.getSize(), outputSizeTorchInt))
         return rewriter.notifyMatchFailure(
             op, "unimplemented: the output_size is not constructed from "
@@ -2876,8 +2862,9 @@ class ConvertInterpolateOp
           rewriter, loc, getTypeConverter(), outputSizeTorchInt);
     }
     SmallVector<Value> dims = getTensorSizesUntilDim(rewriter, loc, input, 1);
-    dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[0]));
-    dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[1]));
+    for (unsigned i = 2; i < inputRank; i++) {
+      dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[i - 2]));
+    }
 
     Value outTensor = rewriter.create<tensor::EmptyOp>(
         loc, getAsOpFoldResult(dims), inputType.getElementType());
@@ -2894,17 +2881,13 @@ class ConvertInterpolateOp
             /*indexingMaps=*/idMap,
             /*iteratorTypes=*/iteratorTypes,
             [&](OpBuilder &b, Location loc, ValueRange args) {
-              Value outputSizeH = outputSizeIntValues[0];
-              Value outputSizeW = outputSizeIntValues[1];
               Value retVal;
               if (mode == "nearest") {
-                retVal =
-                    NearestInterpolate(b, loc, outputSizeH, outputSizeW,
-                                       input, inputSizeH, inputSizeW);
+                retVal = NearestInterpolate(b, loc, outputSizeIntValues,
+                                            input, inputSizes);
              } else if (mode == "bilinear") {
-                retVal = BilinearInterpolate(b, op, loc, outputSizeH,
-                                             outputSizeW, input, inputSizeH,
-                                             inputSizeW);
+                retVal = BilinearInterpolate(
+                    b, op, loc, outputSizeIntValues, input, inputSizes);
              }
               b.create<linalg::YieldOp>(loc, retVal);
             })
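When scale factors are given instead of output sizes, the scale-factor branch of ConvertInterpolateOp (the @@ -2828 hunk above) computes each output size as floor(input_size * scale). A Python equivalent of just that arithmetic, with made-up values for illustration:

    import math

    input_sizes = [4, 8, 8]      # spatial dims of a hypothetical 3d input
    scales = [2.0, 1.5, 1.5]     # one scale factor per spatial dim
    output_sizes = [math.floor(n * s) for n, s in zip(input_sizes, scales)]
    print(output_sizes)          # [8, 12, 12]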
94 changes: 83 additions & 11 deletions test/Conversion/TorchToLinalg/resize.mlir
@@ -94,31 +94,29 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
 
 func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[GENERIC:.*]] = linalg.generic
+  // CHECK: %[[x11:.*]] = linalg.index 0 : index
+  // CHECK: %[[x12:.*]] = linalg.index 1 : index
   // CHECK: %[[x13:.*]] = linalg.index 2 : index
   // CHECK: %[[x14:.*]] = linalg.index 3 : index
   // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
-  // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
   // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
-  // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
   // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
-  // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32
   // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64
   // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32
   // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32
+  // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32
+  // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64
+  // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index
+  // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
+  // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
+  // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32
   // CHECK: %[[x26:.*]] = arith.index_cast %[[x14]] : index to i64
   // CHECK: %[[x27:.*]] = arith.sitofp %[[x26]] : i64 to f32
   // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x22]] : f32
-  // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32
   // CHECK: %[[x30:.*]] = math.floor %[[x28]] : f32
-  // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64
-  // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index
   // CHECK: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64
   // CHECK: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index
-  // CHECK: %[[x35:.*]] = linalg.index 0 : index
-  // CHECK: %[[x36:.*]] = linalg.index 1 : index
-  // CHECK: %[[x37:.*]] = linalg.index 2 : index
-  // CHECK: %[[x38:.*]] = linalg.index 3 : index
-  // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x35]], %[[x36]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32>
+  // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32>
   // CHECK: linalg.yield %[[extracted]] : f32
   %none = torch.constant.none
   %none_0 = torch.constant.none
@@ -136,3 +134,77 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
   %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   return %5 : !torch.vtensor<[?,?,?,?],f32>
 }
+
+// -----
+
+func.func @test_resize_nearest_1d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> {
+  // CHECK: %[[GENERIC:.*]] = linalg.generic
+  // CHECK: %[[x11:.*]] = linalg.index 0 : index
+  // CHECK: %[[x12:.*]] = linalg.index 1 : index
+  // CHECK: %[[x13:.*]] = linalg.index 2 : index
+  // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
+  // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
+  // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
+  // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64
+  // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32
+  // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32
+  // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32
+  // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64
+  // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index
+  // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor<?x?x?xf32>
+  // CHECK: linalg.yield %[[extracted]] : f32
+  %none = torch.constant.none
+  %none_0 = torch.constant.none
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %str = torch.constant.str "nearest"
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list<int>
+  %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32>
+  return %5 : !torch.vtensor<[?,?,?],f32>
+}
+
+// -----
+
+func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[5],si64>) -> !torch.vtensor<[?,?,?,?,?],f32> {
+  // CHECK: %[[GENERIC:.*]] = linalg.generic
+  // CHECK: %[[x11:.*]] = linalg.index 0 : index
+  // CHECK: %[[x12:.*]] = linalg.index 1 : index
+  // CHECK: %[[x13:.*]] = linalg.index 2 : index
+  // CHECK: %[[x14:.*]] = linalg.index 3 : index
+  // CHECK: %[[index4:.*]] = linalg.index 4 : index
+  // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
+  // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
+  // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
+  // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64
+  // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32
+  // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32
+  // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32
+  // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64
+  // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index
+  // CHECK: %[[x34:.*]] = arith.index_cast %[[Wfptosi:.*]] : i64 to index
+  // CHECK: %[[x35:.*]] = arith.index_cast %[[Dfptosi:.*]] : i64 to index
+  // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]], %[[x35]]] : tensor<?x?x?x?x?xf32>
+  // CHECK: linalg.yield %[[extracted]] : f32
+  %none = torch.constant.none
+  %none_0 = torch.constant.none
+  %int0 = torch.constant.int 0
+  %false = torch.constant.bool false
+  %true = torch.constant.bool true
+  %str = torch.constant.str "nearest"
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  %int3 = torch.constant.int 3
+  %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
+  %int4 = torch.constant.int 4
+  %4 = torch.aten.select.int %arg1, %int0, %int4 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  %5 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
+  %6 = torch.prim.ListConstruct %1, %3, %5: (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %7 = torch.aten.__interpolate.size_list_scale_list %arg0, %6, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32>
+  return %7 : !torch.vtensor<[?,?,?,?,?],f32>
+}
