diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 281637f155e9..d8d5d1cd1e71 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4972,6 +4972,30 @@ def Torch_AtenLogSigmoidOp : Torch_Op<"aten.log_sigmoid", [ }]; } +def Torch_AtenSoftshrinkOp : Torch_Op<"aten.softshrink", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::softshrink : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$lambd + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSoftshrinkOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenSoftshrinkOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index faf3d0a256f6..928c42a9815d 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6514,6 +6514,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.softshrink\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.mish\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9790,6 +9794,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.softshrink\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.logit\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6cb02297d497..cc17b476d14c 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1914,6 +1914,67 @@ class DecomposeAtenLogSigmoidOp : public OpRewritePattern { }; } // namespace +// SoftShrink(x, lambda) function: +// Applies a shrinkage function where: +// - If x > lambda, returns x - lambda +// - If x < -lambda, returns x + lambda +// - Otherwise, returns 0 +namespace { +class DecomposeAtenSoftshrinkOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSoftshrinkOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value lambdValue = op.getLambd(); + + auto resTy = cast(op.getType()); + if (!resTy.hasDtype() || !resTy.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "result should have dtype and size"); + } + + double lambd; + if (!matchPattern(lambdValue, m_TorchConstantFloat(&lambd))) { + return rewriter.notifyMatchFailure( + op, "expected lambd to be a constant float"); + } + + Value zero = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value neglambd = rewriter.create( + loc, rewriter.getF64FloatAttr(-lambd)); + Value poslambd = rewriter.create( + loc, rewriter.getF64FloatAttr(lambd)); + + Value constOneFloat = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + + auto boolResType = + resTy.getWithSizesAndDtype(resTy.getSizes(), rewriter.getI1Type()); + + Value posMask = + rewriter.create(loc, boolResType, self, poslambd); + Value negMask = + rewriter.create(loc, boolResType, self, neglambd); + + Value posValue = rewriter.create(loc, resTy, self, + poslambd, constOneFloat); + Value negValue = rewriter.create(loc, resTy, self, + neglambd, constOneFloat); + + Value result = rewriter.create(loc, resTy, posMask, + posValue, zero); + result = + rewriter.create(loc, resTy, negMask, negValue, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + // Decompose aten.matmul into: aten.mm and aten.bmm according to ranks. namespace { class DecomposeAtenMatmulOp : public OpRewritePattern { @@ -7621,6 +7682,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeConstantTensorAllocLikeOp>(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index c5855a1fa092..32bcae9fd509 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -371,6 +371,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 87344fb99b59..c0efd12f5d4e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1432,6 +1432,7 @@ "ElementwiseTruncIntModule_basic", "ElementwiseTruncModule_basic", "ElementwiseLogSigmoidModule_basic", + "ElementwiseSoftshrinkStaticModule_basic", } STABLEHLO_CRASHING_SET = { @@ -1659,6 +1660,7 @@ "ElementwiseSeluModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSignModule_basic", + "ElementwiseSoftshrinkModule_basic" "ElementwiseSoftshrinkStaticModule_basic", "ElementwiseSqrtIntModule_basic", "ElementwiseSqrtModule_basic", "ElementwiseSubScalarFloatModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 6574c0bdc1db..5e5c6d537696 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -254,6 +254,9 @@ def aten〇log〡shape(self: List[int]) -> List[int]: def aten〇log_sigmoid〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇softshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇mish〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2092,6 +2095,11 @@ def aten〇log_sigmoid〡dtype(self_rank_dtype: Tuple[int, int]) -> int: assert not self_dtype == torch.bool return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, lambd=0.5)) +def aten〇softshrink〡dtype(self_rank_dtype: Tuple[int, int], lambd: Union[int, float, complex] = 0.5) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇logit〡dtype(self_rank_dtype: Tuple[int, int], eps: Optional[float] = None) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 3ebd007535e5..7617cd866a43 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -479,6 +479,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)") emit("aten::glu : (Tensor, int) -> (Tensor)") emit("aten::log_sigmoid : (Tensor) -> (Tensor)") + emit("aten::softshrink : (Tensor, Scalar) -> (Tensor)") # Ops with dynamic number of outputs emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 47f4a64038e9..0a323e33a54a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2133,6 +2133,52 @@ def ElementwiseLogSigmoidModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSoftshrinkModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.softshrink(a) + + +@register_test_case(module_factory=lambda: ElementwiseSoftshrinkModule()) +def ElementwiseSoftshrinkModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseSoftshrinkStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([4, 5, 6], torch.float32, True), + ] + ) + def forward(self, a): + return torch.ops.aten.softshrink(a, 2.0) + + +@register_test_case(module_factory=lambda: ElementwiseSoftshrinkStaticModule()) +def ElementwiseSoftshrinkStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 5, 6)) + + +# ============================================================================== + + class ElementwiseErfModule(torch.nn.Module): def __init__(self): super().__init__()