From 445c845553a0e376d0206029396d45b71245dff6 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 28 Feb 2025 08:47:26 +0100 Subject: [PATCH] Support `prims.sum` (#4052) This is supported by decomposing into `aten.sum`. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 19 ++++++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 18 +++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 12 +++++++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/reduction.py | 23 +++++++++++++++++ 7 files changed, 99 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 906f77c39e0f..ec9dd518db73 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -18369,6 +18369,31 @@ def Torch_PrimsSqueezeOp : Torch_Op<"prims.squeeze", [ }]; } +def Torch_PrimsSumOp : Torch_Op<"prims.sum", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `prims::sum : (Tensor, int[]?, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$inp, + AnyTorchOptionalListOfTorchIntType:$dims, + AnyTorchOptionalIntType:$output_dtype + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult PrimsSumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void PrimsSumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 2c70913fb5d3..c1092171fcd6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7542,6 +7542,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.prims.sum\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %0 = torch.derefine %arg2 : !torch.optional to !torch.any\n" +" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %false, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.prod.dim_int\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.optional) -> !torch.list {\n" " %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list\n" " %1 = torch.derefine %0 : !torch.list to !torch.optional>\n" @@ -11889,6 +11895,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.prims.sum\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional) -> !torch.int {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" torch.prim.If %0 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %1#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.abs\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %int9 = torch.constant.int 9\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 377c4e94a928..9384390cb216 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -10495,6 +10495,23 @@ class DecomposeAtenScatterValueOp }; } // namespace +namespace { +// Decompose prims.sum into aten.sum +class DecomposePrimsSumOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimsSumOp op, + PatternRewriter &rewriter) const override { + Value cstFalse = rewriter.create(op.getLoc(), false); + + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getInp(), op.getDims(), /*keepdim=*/cstFalse, + op.getOutputDtype()); + return success(); + } +}; +} // namespace + namespace { // Decompose `aten.sgn` op into comparisons and aten.where. class DecomposeAtenSgnOp : public OpRewritePattern { @@ -11812,6 +11829,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 315226856cc0..e7d6e1ea31d3 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -3017,6 +3017,7 @@ "PrimsConvertElementTypeModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", "PrimsSqueezeModule_basic", + "PrimsSumFloatModule_basic", "PrimsViewOfModule_basic", "PrimsViewOfZeroRankModule_basic", "QuantizedReluInt8_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 6d64313283c1..cb10a8aead79 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -809,6 +809,9 @@ def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim def aten〇sum〇dim_IntList〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) +def prims〇sum〡shape(inp: List[int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> List[int]: + return upstream_shape_functions.sum_mean_dim(inp, dims, False, output_dtype) + def aten〇prod〇dim_int〡shape(self: List[int], dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, [dim], keepdim, dtype) @@ -2892,6 +2895,15 @@ def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return self_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0])) +def prims〇sum〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> int: + # When invoking prims.sum() with the output_dtype argument, pytorch + # complains that the argument is not known. + # See https://github.com/pytorch/pytorch/issues/102610 + assert output_dtype is None + inp_rank, inp_dtype = inp_rank_dtype + return inp_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index f7d88bf91fef..46485b3173cc 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1277,6 +1277,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("prims::collapse : (Tensor, int, int) -> (Tensor)") emit("prims::split_dim : (Tensor, int, int) -> (Tensor)") emit("prims::squeeze : (Tensor, int[]) -> (Tensor)") + emit("prims::sum : (Tensor, int[]?, int?) -> (Tensor)") emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True) emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 3e379deacb79..b72bf64dbcfa 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -81,6 +81,29 @@ def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils): # ============================================================================== +class PrimsSumFloatModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.prims.sum(a, (0, 1)) + + +@register_test_case(module_factory=lambda: PrimsSumFloatModule()) +def PrimsSumFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + + +# ============================================================================== + + class ReduceProdFloatModule(torch.nn.Module): def __init__(self): super().__init__()