Support prims.sum #4052

Merged: 1 commit, Feb 28, 2025
25 changes: 25 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -18242,6 +18242,31 @@ def Torch_PrimsSqueezeOp : Torch_Op<"prims.squeeze", [
}];
}

def Torch_PrimsSumOp : Torch_Op<"prims.sum", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `prims::sum : (Tensor, int[]?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$inp,
AnyTorchOptionalListOfTorchIntType:$dims,
AnyTorchOptionalIntType:$output_dtype
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult PrimsSumOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void PrimsSumOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
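Note: as a reference for the signature mirrored by this generated op, here is a minimal Python sketch (not part of the diff) of calling `prims::sum : (Tensor, int[]?, int?) -> (Tensor)` from PyTorch; the e2e test added later in this PR makes the same call.

```python
import torch

# Illustrative sketch: reduce a float tensor over dims 0 and 1.
# output_dtype is left unset (None), matching the constraint asserted
# by the dtype function added below.
x = torch.rand(3, 4, 5)
y = torch.ops.prims.sum(x, (0, 1))
print(y.shape)  # torch.Size([5])
```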

def Torch_PrimsViewOfOp : Torch_Op<"prims.view_of", [
AllowsTypeRefinement,
ReadOnly
19 changes: 19 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7542,6 +7542,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg2, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prims.sum\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %0 = torch.derefine %arg2 : !torch.optional<int> to !torch.any\n"
" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %false, %0) : (!torch.list<int>, !torch.optional<list<int>>, !torch.bool, !torch.any) -> !torch.list<int>\n"
" return %1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.prod.dim_int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list<int>\n"
" %1 = torch.derefine %0 : !torch.list<int> to !torch.optional<list<int>>\n"
@@ -11879,6 +11885,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.prims.sum\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<int>) -> !torch.int {\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %none = torch.constant.none\n"
" %0 = torch.aten.__is__ %arg2, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
" torch.prim.If %0 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %1#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.abs\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %int9 = torch.constant.int 9\n"
18 changes: 18 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -10470,6 +10470,23 @@ class DecomposeAtenScatterValueOp
};
} // namespace

namespace {
// Decompose `prims.sum` op into `aten.sum.dim_IntList` op.
class DecomposePrimsSumOp : public OpRewritePattern<PrimsSumOp> {
public:
using OpRewritePattern<PrimsSumOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PrimsSumOp op,
PatternRewriter &rewriter) const override {
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(op.getLoc(), false);

rewriter.replaceOpWithNewOp<AtenSumDimIntListOp>(
op, op.getType(), op.getInp(), op.getDims(), /*keepdim=*/cstFalse,
op.getOutputDtype());
return success();
}
};
} // namespace
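As a sanity check on the pattern above, a short sketch (illustrative only, not part of the diff) of the equivalence it relies on at the PyTorch level: `prims.sum(inp, dims)` behaves like `aten.sum.dim_IntList(inp, dims, keepdim=False, dtype)`.

```python
import torch

x = torch.rand(3, 4, 5)
# prims.sum over dims (0, 1) ...
a = torch.ops.prims.sum(x, (0, 1))
# ... should match aten.sum.dim_IntList with keepdim=False and dtype=None,
# which is exactly what DecomposePrimsSumOp emits.
b = torch.ops.aten.sum.dim_IntList(x, [0, 1], False, dtype=None)
assert torch.allclose(a, b)
```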

namespace {
// Decompose `aten.sgn` op into comparisons and aten.where.
class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
@@ -11712,6 +11729,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSgnOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposePrimsSumOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTypeAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTileOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2999,6 +2999,7 @@
"PrimsConvertElementTypeModule_basic",
"PrimsSqueezeEmptyDimensionsModule_basic",
"PrimsSqueezeModule_basic",
"PrimsSumFloatModule_basic",
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"QuantizedReluInt8_basic",
@@ -809,6 +809,9 @@ def aten〇mean〇dim〡shape(self: List[int], dim: Optional[List[int]], keepdim
def aten〇sum〇dim_IntList〡shape(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype)

def prims〇sum〡shape(inp: List[int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> List[int]:
return upstream_shape_functions.sum_mean_dim(inp, dims, False, output_dtype)
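A small worked example (illustrative assumption, not part of the diff) of the shape behaviour this function encodes: `keepdim` is hard-coded to `False`, so the reduced dimensions are dropped from the result shape.

```python
import torch

# Summing dims (0, 1) of a [3, 4, 5] tensor drops both reduced dims,
# leaving shape [5] -- the same result sum_mean_dim(..., False, ...)
# computes for the shape function above.
x = torch.rand(3, 4, 5)
assert list(torch.ops.prims.sum(x, (0, 1)).shape) == [5]
```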

def aten〇prod〇dim_int〡shape(self: List[int], dim: int, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
return upstream_shape_functions.sum_mean_dim(self, [dim], keepdim, dtype)

@@ -2882,6 +2885,15 @@ def prims〇sqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return self_dtype
return _get_dtype_of_floating_point_op(self_dtype)

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dims=[0]))
def prims〇sum〡dtype(inp_rank_dtype: Tuple[int, int], dims: Optional[List[int]], output_dtype: Optional[int] = None) -> int:
# When invoking prims.sum() with the output_dtype argument, pytorch
# complains that the argument is not known.
# See https://github.com/pytorch/pytorch/issues/102610
assert output_dtype is None
inp_rank, inp_dtype = inp_rank_dtype
return inp_dtype
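And a brief sketch (illustrative, not from the diff) of the dtype rule encoded here: with `output_dtype` left as `None`, the result simply keeps the input's dtype.

```python
import torch

x = torch.rand(3, 4, dtype=torch.float32)
y = torch.ops.prims.sum(x, (0,))
assert y.dtype == torch.float32  # result dtype follows the input dtype
```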

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇abs〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
@@ -1263,6 +1263,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("prims::collapse : (Tensor, int, int) -> (Tensor)")
emit("prims::split_dim : (Tensor, int, int) -> (Tensor)")
emit("prims::squeeze : (Tensor, int[]) -> (Tensor)")
emit("prims::sum : (Tensor, int[]?, int?) -> (Tensor)")
emit("prims::view_of : (Tensor) -> (Tensor)", has_folder=True)
emit("prims::iota : (int, int, int, int, Device, bool) -> (Tensor)")

23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -81,6 +81,29 @@ def ReduceSumElementTypeBoolModule_basic(module, tu: TestUtils):
# ==============================================================================


class PrimsSumFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.prims.sum(a, (0, 1))


@register_test_case(module_factory=lambda: PrimsSumFloatModule())
def PrimsSumFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5))


# ==============================================================================


class ReduceProdFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()