Commit 445c845

Support prims.sum (#4052)
This is supported by decomposing into `aten.sum`.
1 parent d91e1ac commit 445c845
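For orientation: `prims.sum` always drops the reduced dimensions, so it maps one-to-one onto `aten.sum.dim_IntList` with `keepdim` pinned to false. A minimal eager-mode sketch of that equivalence (illustration only, not part of the diff):

```python
import torch

x = torch.rand(3, 4, 5)
# prims.sum reduces over dims (0, 1) and drops them; aten.sum.dim_IntList
# does the same when keepdim=False, which is exactly the decomposition below.
via_prims = torch.ops.prims.sum(x, [0, 1])
via_aten = torch.ops.aten.sum.dim_IntList(x, [0, 1], False)
assert via_prims.shape == torch.Size([5])
assert torch.allclose(via_prims, via_aten)
```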

File tree: 7 files changed, +99 -0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 25 additions & 0 deletions
@@ -18369,6 +18369,31 @@ def Torch_PrimsSqueezeOp : Torch_Op<"prims.squeeze", [
   }];
 }
 
+def Torch_PrimsSumOp : Torch_Op<"prims.sum", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `prims::sum : (Tensor, int[]?, int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$inp,
+    AnyTorchOptionalListOfTorchIntType:$dims,
+    AnyTorchOptionalIntType:$output_dtype
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult PrimsSumOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void PrimsSumOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [
   AllowsTypeRefinement,
   ReadOnly

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 19 additions & 0 deletions
@@ -7542,6 +7542,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
 "    return %1 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.prims.sum\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
+"    %false = torch.constant.bool false\n"
+"    %0 = torch.derefine %arg2 : !torch.optional<int> to !torch.any\n"
+"    %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %false, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
+"    return %1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.prod.dim_int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
 "    %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>\n"
 "    %1 = torch.derefine %0 : !torch.list<int> to !torch.optional<list<int>>\n"
@@ -11889,6 +11895,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %2 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.prims.sum\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>) -> !torch.int {\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %none = torch.constant.none\n"
+"    %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+"    torch.prim.If %0 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %1#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.abs\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int6 = torch.constant.int 6\n"
 "    %int9 = torch.constant.int 9\n"

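The two transfer functions added above encode simple rules: shapes follow `sum_mean_dim` with `keepdim` fixed to false, and the result dtype is just the input dtype (with `output_dtype` required to be absent). A quick eager-mode check of both rules (a sketch, not part of the commit):

```python
import torch

x = torch.rand(3, 4, 5, dtype=torch.float32)
out = torch.ops.prims.sum(x, [0, 1])
assert out.shape == torch.Size([5])  # shape fn: [3, 4, 5] -> [5]
assert out.dtype == torch.float32    # dtype fn: input dtype propagates
```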
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 18 additions & 0 deletions
@@ -10495,6 +10495,23 @@ class DecomposeAtenScatterValueOp
 };
 } // namespace
 
+namespace {
+// Decompose prims.sum into aten.sum
+class DecomposePrimsSumOp : public OpRewritePattern<PrimsSumOp> {
+public:
+  using OpRewritePattern<PrimsSumOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimsSumOp op,
+                                PatternRewriter &rewriter) const override {
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);
+
+    rewriter.replaceOpWithNewOp<AtenSumDimIntListOp>(
+        op, op.getType(), op.getInp(), op.getDims(), /*keepdim=*/cstFalse,
+        op.getOutputDtype());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose `aten.sgn` op into comparisons and aten.where.
 class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
@@ -11812,6 +11829,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSgnOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposePrimsSumOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
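Note the pattern forwards `dims` and `output_dtype` unchanged and only materializes the constant-false `keepdim`. The dtype-forwarding half can be checked on the ATen side in eager mode (shown there only, since eager `prims.sum` rejects `output_dtype` per the PyTorch issue cited elsewhere in this diff; a sketch, not part of the commit):

```python
import torch

x = torch.rand(3, 4, 5)
# This is what the rewritten op computes when output_dtype is set:
# sum over dims (0, 1), keepdim=False, accumulate into float64.
decomposed = torch.ops.aten.sum.dim_IntList(x, [0, 1], False, dtype=torch.float64)
assert decomposed.shape == torch.Size([5])
assert decomposed.dtype == torch.float64
```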

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
@@ -3017,6 +3017,7 @@
     "PrimsConvertElementTypeModule_basic",
     "PrimsSqueezeEmptyDimensionsModule_basic",
     "PrimsSqueezeModule_basic",
+    "PrimsSumFloatModule_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
     "QuantizedReluInt8_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 12 additions & 0 deletions
@@ -809,6 +809,9 @@ def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim
 def aten〇sum〇dim_IntList〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)
 
+def prims〇sum〡shape(inp: List[int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> List[int]:
+    return upstream_shape_functions.sum_mean_dim(inp, dims, False, output_dtype)
+
 def aten〇prod〇dim_int〡shape(self: List[int], dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.sum_mean_dim(self, [dim], keepdim, dtype)
 
@@ -2892,6 +2895,15 @@ def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
         return self_dtype
     return _get_dtype_of_floating_point_op(self_dtype)
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0]))
+def prims〇sum〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> int:
+    # When invoking prims.sum() with the output_dtype argument, pytorch
+    # complains that the argument is not known.
+    # See https://github.com/pytorch/pytorch/issues/102610
+    assert output_dtype is None
+    inp_rank, inp_dtype = inp_rank_dtype
+    return inp_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
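For intuition about what `upstream_shape_functions.sum_mean_dim` returns here: with `keepdim` hard-coded to `False`, every reduced dimension simply disappears from the output shape. A tiny hand-rolled check (`expected_sum_shape` is a hypothetical helper for illustration, not part of the commit):

```python
from typing import List, Sequence

def expected_sum_shape(shape: List[int], dims: Sequence[int]) -> List[int]:
    # Drop each reduced dimension, mirroring sum_mean_dim with keepdim=False.
    reduced = set(dims)
    return [s for i, s in enumerate(shape) if i not in reduced]

# Summing a [3, 4, 5] tensor over dims (0, 1) leaves shape [5].
assert expected_sum_shape([3, 4, 5], (0, 1)) == [5]
```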

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -1277,6 +1277,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("prims::collapse : (Tensor, int, int) -> (Tensor)")
     emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
     emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
+    emit("prims::sum : (Tensor, int[]?, int?) -> (Tensor)")
     emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
     emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 23 additions & 0 deletions
@@ -81,6 +81,29 @@ def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class PrimsSumFloatModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.ops.prims.sum(a, (0, 1))
+
+
+@register_test_case(module_factory=lambda: PrimsSumFloatModule())
+def PrimsSumFloatModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4, 5))
+
+
+# ==============================================================================
+
+
 class ReduceProdFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
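As a sanity check outside the e2e harness, the new module's computation can be reproduced directly in eager mode (a sketch, not part of the commit):

```python
import torch

# Mirrors PrimsSumFloatModule_basic: a [3, 4, 5] float32 input reduced
# over dims (0, 1), compared against an equivalent two-step aten sum.
a = torch.rand(3, 4, 5)
out = torch.ops.prims.sum(a, (0, 1))
assert torch.allclose(out, a.sum(dim=0).sum(dim=0))
```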
