diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 2a55378bc4a9..35a3204b7e36 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2180,7 +2180,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::ArrayRef(selectSizes), operandType.getOptionalDtype()); MLIRContext *context = binder.op->getContext(); - for (int i = sizes[0] - 2; i < sizes[0]; i++) { + for (int i = 2; i < sizes[0]; i++) { Value selectIndex = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e73fb1e88dc4..0648508f75bb 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2589,68 +2589,58 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { }; } // namespace -static Value NearestInterpolate(OpBuilder &b, Location loc, Value outputSizeH, - Value outputSizeW, Value input, - Value inputSizeH, Value inputSizeW) { +static Value NearestInterpolate(OpBuilder &b, Location loc, + SmallVector outputSizes, Value input, + SmallVector inputSizes) { auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); - Value yOut = b.create(loc, 2); - Value xOut = b.create(loc, 3); - - Value inputHFP = b.create(loc, b.getF32Type(), inputSizeH); - Value inputWFP = b.create(loc, b.getF32Type(), inputSizeW); + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } - Value outputSizeHFP = - b.create(loc, b.getF32Type(), outputSizeH); - Value outputSizeWFP = - b.create(loc, b.getF32Type(), outputSizeW); + for (unsigned i = 2; i < inputRank; i++) { + Value outIndex = indices[i]; - // scale = length_resized / length_original - // x_original = 
x_resized / scale - Value hScale = b.create(loc, outputSizeHFP, inputHFP); - Value wScale = b.create(loc, outputSizeWFP, inputWFP); + Value inputSizeFP = + b.create(loc, b.getF32Type(), inputSizes[i - 2]); - Value yOutInt = b.create(loc, b.getI64Type(), yOut); - Value yOutFP = b.create(loc, b.getF32Type(), yOutInt); - Value yProj = b.create(loc, yOutFP, hScale); + Value outputSizeFP = + b.create(loc, b.getF32Type(), outputSizes[i - 2]); - Value xOutInt = b.create(loc, b.getI64Type(), xOut); - Value xOutFP = b.create(loc, b.getF32Type(), xOutInt); - Value xProj = b.create(loc, xOutFP, wScale); + // scale = length_resized / length_original + // x_original = x_resized / scale + Value scale = b.create(loc, outputSizeFP, inputSizeFP); - // get nearest pixel using floor - Value yNearestFP = b.create(loc, yProj); - Value xNearestFP = b.create(loc, xProj); + Value outInt = b.create(loc, b.getI64Type(), outIndex); + Value outFP = b.create(loc, b.getF32Type(), outInt); + Value proj = b.create(loc, outFP, scale); - Value yNearestInt = - b.create(loc, b.getI64Type(), yNearestFP); - Value yNearest = - b.create(loc, b.getIndexType(), yNearestInt); + // get nearest pixel using floor + Value nearestFP = b.create(loc, proj); - Value xNearestInt = - b.create(loc, b.getI64Type(), xNearestFP); - Value xNearest = - b.create(loc, b.getIndexType(), xNearestInt); + Value nearestInt = + b.create(loc, b.getI64Type(), nearestFP); + Value nearest = + b.create(loc, b.getIndexType(), nearestInt); - SmallVector indices; - for (unsigned i = 0; i < inputRank; i++) { - indices.push_back(b.create(loc, i)); + indices[i] = nearest; } - - int hDimOffset = 2; - indices[hDimOffset] = yNearest; - indices[hDimOffset + 1] = xNearest; Value retVal = b.create(loc, input, indices); return retVal; } static Value BilinearInterpolate(OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, - Location loc, Value outputSizeH, - Value outputSizeW, Value input, - Value inputSizeH, Value inputSizeW) { + Location loc, 
SmallVector outputSizes, + Value input, SmallVector inputSizes) { + Value inputSizeH = inputSizes[0]; + Value inputSizeW = inputSizes[1]; + Value outputSizeH = outputSizes[0]; + Value outputSizeW = outputSizes[1]; + int hDimOffset = 2; auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); @@ -2805,7 +2795,6 @@ static Value BilinearInterpolate(OpBuilder &b, rhs = b.create(loc, w1, xInter1); Value retVal = b.create(loc, lhs, rhs); - return retVal; } @@ -2828,46 +2817,43 @@ class ConvertInterpolateOp Value input = adaptor.getInput(); auto inputType = input.getType().cast(); auto inputRank = inputType.getRank(); + if (mode == "bilinear" && inputRank != 4) + return rewriter.notifyMatchFailure( + op, + "cannot perform bilinear interpolation when input spatial dims != 2"); - SmallVector outputSizeIntValues; - Value inputSizeH = getDimOp(rewriter, loc, input, 2); - inputSizeH = rewriter.create( - loc, rewriter.getIntegerType(64), inputSizeH); - Value inputSizeW = getDimOp(rewriter, loc, input, 3); - inputSizeW = rewriter.create( - loc, rewriter.getIntegerType(64), inputSizeW); + SmallVector outputSizeIntValues; + SmallVector inputSizes; + for (unsigned i = 2; i < inputRank; i++) { + Value inputSize = getDimOp(rewriter, loc, input, i); + inputSizes.push_back(rewriter.create( + loc, rewriter.getIntegerType(64), inputSize)); + } if (!op.getScaleFactor().getType().isa()) { - SmallVector ScaleFactorTorchFloat; + SmallVector ScaleFactorTorchFloat; if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat)) return rewriter.notifyMatchFailure( op, "unimplemented: the output_size is not constructed from " "ListConstruct"); - SmallVector ScaleFactorFloatValues; + SmallVector ScaleFactorFloatValues; ScaleFactorFloatValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat); - Value inputHFP = rewriter.create( - loc, rewriter.getF32Type(), inputSizeH); - Value scale = rewriter.create(loc, inputHFP.getType(), - 
ScaleFactorFloatValues[0]); - Value outputSizeH = rewriter.create(loc, inputHFP, scale); - Value outputH = rewriter.create(loc, outputSizeH); - outputH = - rewriter.create(loc, rewriter.getI64Type(), outputH); - - Value inputWFP = rewriter.create( - loc, rewriter.getF32Type(), inputSizeW); - scale = rewriter.create(loc, inputWFP.getType(), - ScaleFactorFloatValues[1]); - Value outputSizeW = rewriter.create(loc, inputWFP, scale); - Value outputW = rewriter.create(loc, outputSizeW); - outputW = - rewriter.create(loc, rewriter.getI64Type(), outputW); - - outputSizeIntValues.push_back(outputH); - outputSizeIntValues.push_back(outputW); + for (unsigned i = 0; i < inputRank - 2; i++) { + Value inputSizeFP = rewriter.create( + loc, rewriter.getF32Type(), inputSizes[i]); + Value scale = rewriter.create( + loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]); + Value outputSize = + rewriter.create(loc, inputSizeFP, scale); + outputSize = rewriter.create(loc, outputSize); + outputSize = rewriter.create( + loc, rewriter.getI64Type(), outputSize); + + outputSizeIntValues.push_back(outputSize); + } } else { - SmallVector outputSizeTorchInt; + SmallVector outputSizeTorchInt; if (!getListConstructElements(op.getSize(), outputSizeTorchInt)) return rewriter.notifyMatchFailure( op, "unimplemented: the output_size is not constructed from " @@ -2876,8 +2862,9 @@ class ConvertInterpolateOp rewriter, loc, getTypeConverter(), outputSizeTorchInt); } SmallVector dims = getTensorSizesUntilDim(rewriter, loc, input, 1); - dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[0])); - dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[1])); + for (unsigned i = 2; i < inputRank; i++) { + dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[i - 2])); + } Value outTensor = rewriter.create( loc, getAsOpFoldResult(dims), inputType.getElementType()); @@ -2894,17 +2881,13 @@ class ConvertInterpolateOp /*indexingMaps=*/idMap, /*iteratorTypes=*/iteratorTypes, 
[&](OpBuilder &b, Location loc, ValueRange args) { - Value outputSizeH = outputSizeIntValues[0]; - Value outputSizeW = outputSizeIntValues[1]; Value retVal; if (mode == "nearest") { - retVal = - NearestInterpolate(b, loc, outputSizeH, outputSizeW, - input, inputSizeH, inputSizeW); + retVal = NearestInterpolate(b, loc, outputSizeIntValues, + input, inputSizes); } else if (mode == "bilinear") { - retVal = BilinearInterpolate(b, op, loc, outputSizeH, - outputSizeW, input, inputSizeH, - inputSizeW); + retVal = BilinearInterpolate( + b, op, loc, outputSizeIntValues, input, inputSizes); } b.create(loc, retVal); }) diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 9850a5fdabd6..1f6b69a50af0 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -94,31 +94,29 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index // CHECK: %[[x13:.*]] = linalg.index 2 : index // CHECK: %[[x14:.*]] = linalg.index 3 : index // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 - // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 - // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 - // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32 // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] 
: i64 to f32 // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32 + // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32 + // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32 // CHECK: %[[x26:.*]] = arith.index_cast %[[x14]] : index to i64 // CHECK: %[[x27:.*]] = arith.sitofp %[[x26]] : i64 to f32 // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x22]] : f32 - // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 // CHECK: %[[x30:.*]] = math.floor %[[x28]] : f32 - // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 - // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index // CHECK: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64 // CHECK: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index - // CHECK: %[[x35:.*]] = linalg.index 0 : index - // CHECK: %[[x36:.*]] = linalg.index 1 : index - // CHECK: %[[x37:.*]] = linalg.index 2 : index - // CHECK: %[[x38:.*]] = linalg.index 3 : index - // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x35]], %[[x36]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32> + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32> // CHECK: linalg.yield %[[extracted]] : f32 %none = torch.constant.none %none_0 = torch.constant.none @@ -136,3 +134,77 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> return %5 : !torch.vtensor<[?,?,?,?],f32> } + +// ----- + +func.func @test_resize_nearest_1d(%arg0: !torch.vtensor<[?,?,?],f32>, 
%arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + // CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1 : (!torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?],f32> +} + +// ----- + +func.func @test_resize_nearest_3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>, %arg1: !torch.vtensor<[5],si64>) -> !torch.vtensor<[?,?,?,?,?],f32> { + // CHECK: %[[GENERIC:.*]] = linalg.generic + // CHECK: %[[x11:.*]] = linalg.index 0 : index + 
// CHECK: %[[x12:.*]] = linalg.index 1 : index + // CHECK: %[[x13:.*]] = linalg.index 2 : index + // CHECK: %[[x14:.*]] = linalg.index 3 : index + // CHECK: %[[index4:.*]] = linalg.index 4 : index + // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32 + // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32 + // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32 + // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64 + // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32 + // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 + // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32 + // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64 + // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index + // CHECK: %[[x34:.*]] = arith.index_cast %[[Wfptosi:.*]] : i64 to index + // CHECK: %[[x35:.*]] = arith.index_cast %[[Dfptosi:.*]] : i64 to index + // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x11]], %[[x12]], %[[x32]], %[[x34]], %[[x35]]] : tensor + // CHECK: linalg.yield %[[extracted]] : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "nearest" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + %int4 = torch.constant.int 4 + %4 = torch.aten.select.int %arg1, %int0, %int4 : !torch.vtensor<[5],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %5 = torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int + %6 = torch.prim.ListConstruct %1, %3, %5: 
(!torch.int, !torch.int, !torch.int) -> !torch.list + %7 = torch.aten.__interpolate.size_list_scale_list %arg0, %6, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32> + return %7 : !torch.vtensor<[?,?,?,?,?],f32> + }