From ced0ad5c59d87960df68d6ec3da46e99ad671d0b Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Wed, 5 Feb 2025 13:52:32 -0500
Subject: [PATCH] [Task] : Use op.dtype to create EmptyMemoryFormat during decomposition.

---
 .../torch-mlir/Dialect/Torch/Utils/Utils.h    |  3 ++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 52 ++++++++++++-------
 lib/Dialect/Torch/Utils/Utils.cpp             | 30 +++++++++++
 .../test_suite/constant_alloc.py              | 20 +++++++
 test/Dialect/Torch/decompose-complex-ops.mlir | 46 ++++++++++++++++
 5 files changed, 131 insertions(+), 20 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index b0a40e35c652..a000b7ab2f98 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -36,6 +36,9 @@ Type getTypeForTorchType(
     MLIRContext *context, Type type,
     mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
 
+template <typename OpTy>
+FailureOr<Value> getDtypeFromOp(PatternRewriter &rewriter, OpTy op);
+
 FailureOr<Type> getTorchTypeForScalarType(MLIRContext *context,
                                           torch_upstream::ScalarType dtypeInt);
 
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 1226ad2c03e2..e7d412a2ae5d 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7065,9 +7065,16 @@ class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
         Torch::ListType::get(Torch::IntType::get(op.getContext()));
     Value sizeList =
         rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());
+
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
+    }
+
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
+        op, op.getType(), sizeList, *dtype, op.getLayout(), op.getDevice(),
+        op.getPinMemory(), op.getMemoryFormat());
     return success();
   }
 };
@@ -7816,18 +7823,13 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
   LogicalResult matchAndRewrite(AtenNewEmptyOp op,
                                 PatternRewriter &rewriter) const override {
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
-    Value dtype = op.getDtype();
-    if (isa<Torch::NoneType>(dtype.getType())) {
-      BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
-      if (!tensorType.hasDtype()) {
-        return rewriter.notifyMatchFailure(
-            op, "expected input tensor to have a dtype");
-      }
-      dtype =
-          getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
     }
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(),
+        op, op.getType(), op.getSize(), *dtype, op.getLayout(), op.getDevice(),
         op.getPinMemory(), /*memoryFormat=*/noneVal);
     return success();
   }
@@ -9286,12 +9288,12 @@ class DecomposeAtenRandnGeneratorOp
     Location loc = op.getLoc();
     auto resultType = cast<BaseTensorType>(op.getType());
 
-    if (!resultType.hasDtype()) {
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
       return rewriter.notifyMatchFailure(
-          op, "expected result type to have a dtype");
+          op, "could not determine dtype from the op.");
     }
 
-    Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
     Value none = rewriter.create<ConstantNoneOp>(loc);
     Value low = rewriter.create<Torch::ConstantFloatOp>(
         loc, rewriter.getF64FloatAttr((double)0.0));
@@ -9303,12 +9305,12 @@
         loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));
 
     Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/dtype,
+        loc, resultType, op.getSize(), /*dtype=*/*dtype,
         /*layout=*/op.getLayout(), /*device=*/op.getDevice(),
         /*pin_memory=*/op.getPinMemory(), /*memory_format=*/none);
 
     Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/dtype,
+        loc, resultType, op.getSize(), /*dtype=*/*dtype,
         /*layout=*/op.getLayout(), /*device=*/op.getDevice(),
         /*pin_memory=*/op.getPinMemory(), /*memory_format=*/none);
 
@@ -9406,8 +9408,13 @@ class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
         loc, rewriter.getF64FloatAttr((double)0.0));
     Value high = rewriter.create<Torch::ConstantFloatOp>(
         loc, rewriter.getF64FloatAttr((double)1.0));
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
+    }
     Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
-        loc, resultType, op.getSize(), /*dtype=*/op.getDtype(),
+        loc, resultType, op.getSize(), /*dtype=*/*dtype,
         /*layout=*/op.getLayout(), /*device=*/op.getDevice(),
         /*pin_memory=*/op.getPinMemory(), /*memory_format=*/noneVal);
 
@@ -9565,9 +9572,14 @@ class DecomposeAtenEmptyStridedOp
 
     Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
+    FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
+    if (failed(dtype)) {
+      return rewriter.notifyMatchFailure(
+          op, "could not determine dtype from the op.");
+    }
     rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
-        op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
-        op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
+        op, op.getType(), op.getSize(), *dtype, op.getLayout(), op.getDevice(),
+        op.getPinMemory(), /*memoryFormat=*/noneVal);
     return success();
   }
 };
 
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 7f80e84044df..bf6c37962ba7 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -237,6 +237,36 @@ Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
                                         rewriter.getI64IntegerAttr(intType));
 }
 
+template <typename OpTy>
+FailureOr<Value> Torch::getDtypeFromOp(PatternRewriter &rewriter, OpTy op) {
+  Value dtype = op.getDtype();
+  if (isa<Torch::NoneType>(dtype.getType())) {
+    BaseTensorType tensorType = cast<BaseTensorType>(op.getType());
+    if (!tensorType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "expected input tensor to have a dtype");
+    }
+    dtype =
+        getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
+  }
+  return dtype;
+}
+// Explicit template instantiations.
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenEmptyLikeOp>(PatternRewriter &rewriter,
+                                       AtenEmptyLikeOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenNewEmptyOp>(PatternRewriter &rewriter,
+                                      AtenNewEmptyOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenRandOp>(PatternRewriter &rewriter, AtenRandOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenEmptyStridedOp>(PatternRewriter &rewriter,
+                                          AtenEmptyStridedOp op);
+template FailureOr<Value>
+Torch::getDtypeFromOp<AtenRandnGeneratorOp>(PatternRewriter &rewriter,
+                                            AtenRandnGeneratorOp op);
+
 // Helper to convert a tensor to a specific scalar type.
 Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
                                   Value input, Type dtype) {
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
index ab18aeea2a98..8a338e0dd6f2 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
@@ -641,6 +641,26 @@ def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
     module.forward(tu.rand(2, 3, 4))
 
 
+class EmptyLikeDefaultDtypeFloat64InputModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float64, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.empty_like(x).fill_(0)
+
+
+@register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeFloat64InputModule())
+def EmptyLikeDefaultDtypeFloat64InputModule_basic(module, tu: TestUtils):
+    module.forward(torch.ones((200, 200, 26), dtype=torch.float64))
+
+
 # ==============================================================================
 
 
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 4c99f4949a38..6f33eee6c066 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -278,3 +278,49 @@ func.func @torch.aten._assert_scalar(%arg0: !torch.int) -> !torch.int {
   torch.aten._assert_scalar %3, %str : !torch.int, !torch.str
   return %arg0 : !torch.int
 }
+// CHECK-LABEL: func.func @emptyLikeNoneDtype(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[C200:.*]] = torch.constant.int 200
+// CHECK: %[[C26:.*]] = torch.constant.int 26
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
+func.func @emptyLikeNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+  %none = torch.constant.none
+  %none_0 = torch.constant.none
+  %none_1 = torch.constant.none
+  %false = torch.constant.bool false
+  %none_2 = torch.constant.none
+  %0 = torch.aten.empty_like %arg0, %none, %none_0, %none_1, %false, %none_2 : !torch.vtensor<[200,200,26],f64>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
+  return %0 : !torch.vtensor<[200,200,26],f64>
+}
+
+// -----
+// CHECK-LABEL: func.func @randNoneDtype(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
+// CHECK: %[[C1:.*]] = torch.constant.float 1.000000e+00
+// CHECK: %[[C0:.*]] = torch.constant.float 0.000000e+00
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[C200:.*]] = torch.constant.int 200
+// CHECK: %[[C26:.*]] = torch.constant.int 26
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
+// CHECK: %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[CPU]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
+// CHECK: %[[UNIFORM:.*]] = torch.aten.uniform %[[MEM_FMT]], %[[C0]], %[[C1]], %[[NONE]] : !torch.vtensor<[200,200,26],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[200,200,26],f64>
+// CHECK: return %[[UNIFORM]] : !torch.vtensor<[200,200,26],f64>
+func.func @randNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
+  %int200 = torch.constant.int 200
+  %int200_0 = torch.constant.int 200
+  %int26 = torch.constant.int 26
+  %0 = torch.prim.ListConstruct %int200, %int200_0, %int26 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %none = torch.constant.none
+  %none_1 = torch.constant.none
+  %cpu = torch.constant.device "cpu"
+  %false = torch.constant.bool false
+  %1 = torch.aten.rand %0, %none, %none_1, %cpu, %false : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[200,200,26],f64>
+  return %1 : !torch.vtensor<[200,200,26],f64>
+}
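
Usage note: getDtypeFromOp is meant to be reusable by any decomposition whose
op carries an optional dtype operand. The sketch below is hypothetical and not
part of the patch: AtenFooOp is a placeholder for a real op with
dtype/size/layout/device/pin_memory operands, and a real user would also need
its own explicit instantiation of getDtypeFromOp in Utils.cpp (as added
above), since the template is defined out of line.

    // Hypothetical pattern illustrating the new helper (AtenFooOp is a
    // placeholder op name, not part of this patch).
    class DecomposeAtenFooOp : public OpRewritePattern<AtenFooOp> {
    public:
      using OpRewritePattern::OpRewritePattern;
      LogicalResult matchAndRewrite(AtenFooOp op,
                                    PatternRewriter &rewriter) const override {
        // Resolve the dtype operand: if it is `none`, fall back to the dtype
        // of the op's result tensor type; fail when that type has no dtype.
        FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
        if (failed(dtype)) {
          return rewriter.notifyMatchFailure(
              op, "could not determine dtype from the op.");
        }
        Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
        rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
            op, op.getType(), op.getSize(), *dtype, op.getLayout(),
            op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
        return success();
      }
    };

Passing *dtype instead of the raw dtype operand is what keeps the emitted
torch.aten.empty.memory_format consistent with the op's result type when the
incoming dtype is `none`, which is exactly the case the two new FileCheck
tests above exercise.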