From cdf2b90b0e5741d4e1a58616d8d344e604ee44e9 Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Mon, 15 Jun 2026 15:16:03 +0200 Subject: [PATCH 1/2] Arm backend: Support integer division. torch.div() with two integer tensors and a rounding mode gives integer output. The decomposition for this case instead yielded floating point output, causing issues in indexing operations which is typically where such integer division happens. When supported, utilize the TOSA operator INTDIV. It directly corresponds to the trunc case, and can be adjusted in the floor case. When not supported, use float path by first casting int tensors to float, and then casting the output back. Additionally - Improve scalar handling. - Add cast op to u55 testrunner to match u85 better. Signed-off-by: Erik Lundell Change-Id: Ia5e3c956d5b83b4183171a8a230a510fc7a52149 --- .../arm/_passes/decompose_div_tensor_mode.py | 236 ++++++++++++++++-- .../tosa_profile_supported_op_lists.py | 1 + backends/arm/test/ops/test_div_tensor_mode.py | 76 +++++- backends/arm/test/setup_testing.sh | 2 +- 4 files changed, 279 insertions(+), 36 deletions(-) diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index cc5440b4e5b..b21b907284e 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -4,23 +4,33 @@ # LICENSE file in the root directory of this source tree. -from typing import Set, Type +from typing import cast, Literal, Set, Type import torch from executorch.backends.arm._passes.arm_pass import ArmOpTargetedPass from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass +from executorch.backends.arm.tosa.specification import get_context_spec from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass edge_div_mode_ops = (exir_ops.edge.aten.div.Tensor_mode,) aten_div_mode_ops = (torch.ops.aten.div.Tensor_mode,) +RoundingMode = Literal["trunc", "floor"] edge_unary = { "div": exir_ops.edge.aten.div.Tensor, "floor": exir_ops.edge.aten.floor.default, "ceil": exir_ops.edge.aten.ceil.default, + "eq": exir_ops.edge.aten.eq.Tensor, "full": exir_ops.edge.aten.full.default, "gt": exir_ops.edge.aten.gt.Tensor, + "logical_and": exir_ops.edge.aten.logical_and.default, + "logical_not": exir_ops.edge.aten.logical_not.default, + "logical_xor": exir_ops.edge.aten.logical_xor.default, + "intdiv": exir_ops.backend.tosa.INTDIV.default, + "mul": exir_ops.edge.aten.mul.Tensor, + "sub": exir_ops.edge.aten.sub.Tensor, + "to": exir_ops.edge.dim_order_ops._to_dim_order_copy.default, "where": exir_ops.edge.aten.where.self, } @@ -28,8 +38,15 @@ "div": torch.ops.aten.div.Tensor, "floor": torch.ops.aten.floor.default, "ceil": torch.ops.aten.ceil.default, + "eq": torch.ops.aten.eq.Tensor, "full": torch.ops.aten.full.default, "gt": torch.ops.aten.gt.Tensor, + "logical_and": torch.ops.aten.logical_and.default, + "logical_not": torch.ops.aten.logical_not.default, + "logical_xor": torch.ops.aten.logical_xor.default, + "mul": torch.ops.aten.mul.Tensor, + "sub": torch.ops.aten.sub.Tensor, + "to": torch.ops.dim_order_ops._to_dim_order_copy.default, "where": torch.ops.aten.where.self, } @@ -43,9 +60,9 @@ def _get_opset(op): class DecomposeDivTensorModePass(ArmOpTargetedPass): - """Rewrites aten.div.Tensor_mode into. + """Rewrites aten.div.Tensor_mode into supported arithmetic ops. - Example: + Floating-point flow: rounding_mode=None -> div(a, b) rounding_mode="floor" -> floor(div(a, b)) rounding_mode="trunc" -> where( @@ -54,12 +71,159 @@ class DecomposeDivTensorModePass(ArmOpTargetedPass): floor(div(a, b)), ) + Integer flow: + During transform-for-annotation, keep div.Tensor_mode intact, don't quantize it. + During backend lowering, rewrite the div to a TOSA INTDIV (corresponding to trunc rounding_mode) + + correcting factor for floor mode. + """ _passes_required_after: Set[Type[ExportPass]] = {DecomposeDivPass} target_ops = edge_div_mode_ops + aten_div_mode_ops check_allowed_to_transform = True + def _is_integer_tensor(self, arg) -> bool: + data = getattr(arg, "data", None) + if data is not None: + return arg.data.dtype in { + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + } + return isinstance(arg, int) + + def _cast(self, opset, arg, dtype: torch.dtype, meta): + if isinstance(arg, int): + if dtype.is_floating_point: + return float(arg) + else: + return arg + if isinstance(arg, float): + if dtype.is_floating_point: + return arg + else: + return int(arg) + data = getattr(arg, "data", None) + if data is not None and data.dtype == dtype: + return arg + return super().call_operator( + opset["to"], + (arg,), + {"dtype": dtype}, + meta, + updated=True, + ) + + def _full(self, opset, value, dtype: torch.dtype, meta): + return super().call_operator( + opset["full"], + args=((1,) * len(meta["val"].size()), value), + kwargs={"dtype": dtype, "device": meta["val"].device}, + meta=meta, + updated=True, + ) + + def _correct_intdiv_floor( + self, opset, numerator, denominator, trunced_quotient, meta + ): + """Apply a correcting factor for converting the truncated division to + floored division. + + Done by subtracting one from the result when, elementwise, + - The remainder is nonzero (otherwise the division is even and the rounding trivial) + - The numerator and denominator have different signs (causing a negative quotient) + The sign of the quotient can't be checked directly, there are cases when it is 0 and still needs correction. + + """ + # Condition 1: non-zero remainder + product = super().call_operator( + opset["mul"], (trunced_quotient, denominator), {}, meta, updated=True + ) + remainder = super().call_operator( + opset["sub"], (numerator, product), {}, meta, updated=True + ) + zero = self._full(opset, 0, torch.int32, meta) + remainder_is_zero = super().call_operator( + opset["eq"], (remainder, zero), {}, meta, updated=True + ) + remainder_is_nonzero = super().call_operator( + opset["logical_not"], (remainder_is_zero,), {}, meta, updated=True + ) + # Condition 2: un-rounded quotient is negative + a_is_negative = super().call_operator( + opset["gt"], (zero, numerator), {}, meta, updated=True + ) + b_is_negative = super().call_operator( + opset["gt"], (zero, denominator), {}, meta, updated=True + ) + signs_differ = super().call_operator( + opset["logical_xor"], + (a_is_negative, b_is_negative), + {}, + meta, + updated=True, + ) + # Use conditions to correct quotient. + needs_correction = super().call_operator( + opset["logical_and"], + (remainder_is_nonzero, signs_differ), + {}, + meta, + updated=True, + ) + # (TOSA spec enforces that int(bool_var) == 1 ? bool_var : 0) + correction = self._cast(opset, needs_correction, torch.int32, meta) + return super().call_operator( + opset["sub"], (trunced_quotient, correction), {}, meta, updated=True + ) + + def _call_integer_div(self, opset, a, b, rounding_mode: RoundingMode, meta): + """Cast inputs to int32, do TOSA INTDIV, and apply correcting factor for + floor rounding mode. + """ + + a_int32 = self._cast(opset, a, torch.int32, meta) + b_int32 = self._cast(opset, b, torch.int32, meta) + intdiv = super().call_operator( + opset["intdiv"], + (a_int32, b_int32), + {}, + meta, + updated=True, + ) + if rounding_mode == "floor": + intdiv = self._correct_intdiv_floor(opset, a_int32, b_int32, intdiv, meta) + + output_dtype = meta["val"].dtype + return self._cast(opset, intdiv, output_dtype, meta) + + def _call_fp_div(self, opset, a, b, rounding_mode: RoundingMode | None, meta): + q = super().call_operator(opset["div"], (a, b), {}, meta, updated=True) + + match rounding_mode: + case None: + return q + case "floor": + return super().call_operator( + opset["floor"], (q,), {}, meta, updated=True + ) + case "trunc": + zero = self._full(opset, 0.0, torch.float32, meta) + is_neg = super().call_operator( + opset["gt"], (zero, q), {}, meta, updated=True + ) + ceilq = super().call_operator( + opset["ceil"], (q,), {}, meta, updated=True + ) + floorq = super().call_operator( + opset["floor"], (q,), {}, meta, updated=True + ) + return super().call_operator( + opset["where"], (is_neg, ceilq, floorq), {}, meta, updated=True + ) + def call_operator(self, op, args, kwargs, meta): if op not in self.target_ops or not self.allowed_to_transform(meta): return super().call_operator(op, args, kwargs, meta) @@ -67,35 +231,53 @@ def call_operator(self, op, args, kwargs, meta): opset = _get_opset(op) a, b = args[0], args[1] + a_is_int = self._is_integer_tensor(a) + b_is_int = self._is_integer_tensor(b) rounding_mode = kwargs.get("rounding_mode", None) if rounding_mode is None and len(args) > 2: rounding_mode = args[2] + if rounding_mode not in ("floor", "trunc", None): + raise RuntimeError( + "Integer div.Tensor_mode requires rounding_mode floor, trunc, or None." + f"got {rounding_mode!r}" + ) + rounding_mode = cast(RoundingMode | None, rounding_mode) - q = super().call_operator(opset["div"], (a, b), {}, meta, updated=True) + int_operation = rounding_mode is not None and a_is_int and b_is_int + sufficient_int_support = ( + rounding_mode == "trunc" or get_context_spec().support_integer() + ) + sufficient_int_support &= not get_context_spec().is_U55_subset - if rounding_mode is None: - return q + if int_operation and sufficient_int_support: + """Integer operation and necessary int ops supported -> pure integer + path. + """ + if self.is_tfa_pass: + # No quantization neccessary, so don't do anything in TFA. + return super().call_operator(op, args, kwargs, meta) + return self._call_integer_div(opset, a, b, rounding_mode, meta) + else: + """Otherwise floating point operation -> do fp path. - if rounding_mode == "floor": - return super().call_operator(opset["floor"], (q,), {}, meta, updated=True) - - if rounding_mode == "trunc": - zero = super().call_operator( - opset["full"], - args=((1,) * len(meta["val"].size()), 0.0), - kwargs={"dtype": torch.float32, "device": meta["val"].device}, - meta=meta, - updated=True, - ) - is_neg = super().call_operator( - opset["gt"], (zero, q), {}, meta, updated=True - ) - ceilq = super().call_operator(opset["ceil"], (q,), {}, meta, updated=True) - floorq = super().call_operator(opset["floor"], (q,), {}, meta, updated=True) - return super().call_operator( - opset["where"], (is_neg, ceilq, floorq), {}, meta, updated=True + Cast to and from fp if neccessary. + + """ + if a_is_int: + a = self._cast(opset, a, torch.float32, meta) + if b_is_int: + b = self._cast(opset, b, torch.float32, meta) + + result = self._call_fp_div( + opset, + a, + b, + rounding_mode, + meta, ) - raise RuntimeError( - f"Unsupported rounding_mode for div.Tensor_mode: {rounding_mode!r}" - ) + output_dtype = meta["val"].dtype + if output_dtype != torch.float32: + result = self._cast(opset, result, output_dtype, meta) + + return result diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index dc448ba0d5f..4495ff90450 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -126,6 +126,7 @@ exir_ops.edge.aten.celu.default, exir_ops.edge.aten.bitwise_not.default, exir_ops.edge.aten.copy.default, + exir_ops.edge.aten.div.Tensor_mode, exir_ops.edge.aten.tan.default, exir_ops.edge.aten.silu.default, exir_ops.edge.aten.detach_copy.default, diff --git a/backends/arm/test/ops/test_div_tensor_mode.py b/backends/arm/test/ops/test_div_tensor_mode.py index d9d058fccc6..9c9606eef10 100644 --- a/backends/arm/test/ops/test_div_tensor_mode.py +++ b/backends/arm/test/ops/test_div_tensor_mode.py @@ -35,6 +35,18 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return torch.div(x, y, rounding_mode=self.mode) +def _is_integer_rounded_div(mode, inputs) -> bool: + if mode is None: + return False + for input in inputs: + if isinstance(input, torch.Tensor): + if input.dtype.is_floating_point: + return False + if not isinstance(input, int): + return False + return True + + test_data = { "mode_none": lambda: (None, (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3)), "mode_floor": lambda: ( @@ -46,6 +58,48 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: (torch.randn(4, 8), torch.randn(4, 8).abs() + 1e-3), ), "int_denominator": lambda: (None, (torch.randn(4, 8), 2)), + "int8_floor": lambda: ( + "floor", + ( + (torch.randn(4, 8) * 100).to(dtype=torch.int8), + (torch.rand(4, 8) * 100 + 10).to(dtype=torch.int8), + ), + ), + "int8_int_scalar": lambda: ( + "floor", + ( + (torch.randn(4, 8) * 100).to(dtype=torch.int8), + 9, + ), + ), + "int8_float_scalar": lambda: ( + "floor", + ( + (torch.randn(4, 8) * 100).to(dtype=torch.int8), + 9.5, + ), + ), + "int16_trunc": lambda: ( + "trunc", + ( + (torch.randn(4, 8) * 100).to(dtype=torch.int8), + (torch.rand(4, 8) * 100 + 10).to(dtype=torch.int16), + ), + ), + "int32_floor": lambda: ( + "floor", + ( + (torch.randn(4, 8) * 100).to(dtype=torch.int32), + (torch.rand(4, 8) * 100 + 10).to(dtype=torch.int32), + ), + ), + "int32_trunc": lambda: ( + "trunc", + ( + (torch.randn(4, 8) * 100).to(dtype=torch.int32), + (torch.rand(4, 8) * 100 + 10).to(dtype=torch.int32), + ), + ), } @@ -61,7 +115,6 @@ def test_div_tensor_mode_tosa_FP(data): exir_op=[], use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check_count.exir") pipeline.run() @@ -73,17 +126,22 @@ def test_div_tensor_mode_tosa_INT(data): pipeline = TosaPipelineINT[input_tt]( model, inputs, - aten_op=model.aten_ops_int, + aten_op=[], exir_op=[], use_to_edge_transform_and_lower=True, ) - pipeline.pop_stage("check_count.exir") pipeline.run() @common.XfailIfNoCorstone300 @common.parametrize( - "data", test_data, xfails={"mode_trunc": "CPU op missing in unittests"} + "data", + test_data, + xfails={ + "mode_trunc": "CPU op missing in unittests", + "int16_trunc": "CPU op missing in unittests", + "int32_trunc": "CPU op missing in unittests", + }, ) def test_div_tensor_mode_u55_INT(data): mode, inputs = data() @@ -92,10 +150,12 @@ def test_div_tensor_mode_u55_INT(data): pipeline = EthosU55PipelineINT[input_tt]( model, inputs, - aten_ops=model.aten_ops_int, + aten_ops=[], exir_ops=[], use_to_edge_transform_and_lower=True, ) + pipeline.tester.use_portable_ops = True + pipeline.pop_stage("check_count.exir") pipeline.run() @@ -108,10 +168,11 @@ def test_div_tensor_mode_u85_INT(data): pipeline = EthosU85PipelineINT[input_tt]( model, inputs, - aten_ops=model.aten_ops_int, + aten_ops=[], exir_ops=[], use_to_edge_transform_and_lower=True, ) + pipeline.tester.use_portable_ops = True pipeline.run() @@ -124,12 +185,11 @@ def test_div_tensor_mode_vgf_quant(data): pipeline = VgfPipeline[input_tt]( model, inputs, - aten_op=model.aten_ops_int, + aten_op=[], exir_op=[], use_to_edge_transform_and_lower=True, quantize=True, ) - pipeline.pop_stage("check_count.exir") pipeline.run() diff --git a/backends/arm/test/setup_testing.sh b/backends/arm/test/setup_testing.sh index 39d8335a26e..fcbdec25043 100755 --- a/backends/arm/test/setup_testing.sh +++ b/backends/arm/test/setup_testing.sh @@ -26,7 +26,7 @@ ${build_executor_runner} --pte=semihosting --target=ethos-u85-128 --system_confi # test setup to make sure models that are not fully delegated can still be tested and run OK # To use this you can set use_portable_ops=True when creating ArmTester() -portable_ops_list_u55="aten::permute_copy.out,aten::convolution.out,aten::relu.out,aten::_native_batch_norm_legit_no_training.out,aten::as_strided_copy.out,aten::mean.out,aten::squeeze_copy.dims,dim_order_ops::_clone_dim_order.out" +portable_ops_list_u55="aten::permute_copy.out,aten::convolution.out,aten::relu.out,aten::_native_batch_norm_legit_no_training.out,aten::as_strided_copy.out,aten::mean.out,aten::squeeze_copy.dims,dim_order_ops::_clone_dim_order.out,dim_order_ops::_to_dim_order_copy.out" portable_ops_list_u65="${portable_ops_list_u55}" portable_ops_list_u85="aten::permute_copy.out,aten::convolution.out,aten::relu.out,aten::_native_batch_norm_legit_no_training.out,aten::as_strided_copy.out,aten::mean.out,aten::full_like.out,aten::bmm.out,aten::scalar_tensor.out,aten::index.Tensor_out,aten::where.self_out,dim_order_ops::_to_dim_order_copy.out" From e55e8785a3202ce9f44dc3a1ceae69e73a8eca7f Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Wed, 17 Jun 2026 12:48:00 +0200 Subject: [PATCH 2/2] Arm backend: Support quantizing no-op fp casts. If a int to fp cast goes directly into a quantized operator, the cast is a noop if it is quantized with unit quantization parameters (scale=1, zp=0). This means it can be handled, even if the backend doesn't support fp. - Add Fixed unit qparam annotation to such a cast. - Allow it in the partitioner. - Finally, make sure the cast dtype is correct after folding quant nodes. This pattern can appear in the wild in a model, or in a decomposition where you want to quantize an integer tensor. Signed-off-by: Erik Lundell Change-Id: I93bfd17ba51e25f121f61cbe3003e9c6b9891401 --- .../arm/_passes/decompose_div_tensor_mode.py | 2 +- .../fold_qdq_with_annotated_qparams_pass.py | 27 +++- .../to_dim_order_copy_support.py | 19 +++ .../arm/quantizer/quantization_annotator.py | 21 +++ backends/arm/test/ops/test_div_tensor_mode.py | 1 - backends/arm/test/ops/test_to_copy.py | 132 ++++++++++++++++-- 6 files changed, 185 insertions(+), 17 deletions(-) diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index b21b907284e..e0e0c219135 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -46,7 +46,7 @@ "logical_xor": torch.ops.aten.logical_xor.default, "mul": torch.ops.aten.mul.Tensor, "sub": torch.ops.aten.sub.Tensor, - "to": torch.ops.dim_order_ops._to_dim_order_copy.default, + "to": torch.ops.aten.to.dtype, "where": torch.ops.aten.where.self, } diff --git a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py index 09e90b88e36..713d6ef354a 100644 --- a/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py +++ b/backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py @@ -305,6 +305,25 @@ def is_foldable(node: Node) -> bool: return False return True + @staticmethod + def _correct_output_dtype(node: torch.fx.Node): + if node.target not in { + exir_ops.edge.aten.sum.dim_IntList, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + }: + return + if len(node.meta["output_qparams"]) == 0: + return + output_qparams = cast(QuantArgs, node.meta["output_qparams"][0]) + + if node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default: + if output_qparams.scale != 1.0 or output_qparams.zp != 0.0: + raise ValueError( + f"Expected quantized {node.target} '{node.name}' to have unit scale and zero point." + ) + + set_node_arg(node, "dtype", output_qparams.dtype) + def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 # Loop over the graph nodes and find any node in the 'targeted_ops' list. @@ -355,13 +374,7 @@ def call(self, graph_module: GraphModule) -> PassResult: # noqa: C901 # Some op(s) contain a "dtype" key in their node kwargs. Set this # to the type of output qparams. - output_qparams = n.meta["output_qparams"] - if ( - n.target in {exir_ops.edge.aten.sum.dim_IntList} - and len(output_qparams) > 0 - ): - output_dtype = output_qparams[0].dtype - set_node_arg(n, "dtype", output_dtype) + FoldAndAnnotateQParamsPass._correct_output_dtype(n) if n.target in ( torch.ops.higher_order.cond, diff --git a/backends/arm/operator_support/to_dim_order_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py index a02a8e16276..b7062ebbb97 100644 --- a/backends/arm/operator_support/to_dim_order_copy_support.py +++ b/backends/arm/operator_support/to_dim_order_copy_support.py @@ -139,6 +139,20 @@ def _merge_supported_types( torch.float8_e5m2: [torch.bfloat16], } + @staticmethod + def _is_quantized_identity_cast(node: torch.fx.Node) -> bool: + for user in node.users: + if ( + not user.target + == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ): + return False + scale = user.args[1] + zp = user.args[2] + if scale != 1.0 or zp != 0.0: + return False + return True + def is_node_tosa_supported( # noqa: C901 self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: @@ -228,6 +242,11 @@ def is_node_tosa_supported( # noqa: C901 ) return False if output_val.dtype not in supported_dtypes[input_dtype]: + if ( + tosa_spec.support_integer() + and ToCopySupported._is_quantized_identity_cast(node) + ): + return True self.reporter.report_reject( node, ( diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 13693bd235d..ad4c85c5030 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -956,6 +956,27 @@ def any_or_hardtanh_min_zero(n: Node): shared_qspec = SharedQuantizationSpec(input_node) quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] quant_properties.quant_output = _QuantProperty(0, shared_qspec) + elif node.target == torch.ops.aten.to.dtype: + # If we quantize a cast(fp32) with unit scale and same dtype as input, we can handle it as a no-op in the backend. + input_val = node.all_input_nodes[0].meta.get("val", None) + if input_val is None: + return None + + if input_val.dtype not in (torch.int8, torch.int16, torch.int32): + return None + + quant_properties.quant_output = _QuantProperty( + 0, + FixedQParamsQuantizationSpec( + dtype=input_val.dtype, + scale=1.0, + zero_point=0, + quant_max=torch.iinfo(input_val.dtype).max, + quant_min=torch.iinfo(input_val.dtype).min, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + ), + ) elif node.target in ( torch.ops.higher_order.cond, torch.ops.higher_order.while_loop, diff --git a/backends/arm/test/ops/test_div_tensor_mode.py b/backends/arm/test/ops/test_div_tensor_mode.py index 9c9606eef10..9b9b95650b0 100644 --- a/backends/arm/test/ops/test_div_tensor_mode.py +++ b/backends/arm/test/ops/test_div_tensor_mode.py @@ -172,7 +172,6 @@ def test_div_tensor_mode_u85_INT(data): exir_ops=[], use_to_edge_transform_and_lower=True, ) - pipeline.tester.use_portable_ops = True pipeline.run() diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py index 6718fedea04..16f5ff0e36d 100644 --- a/backends/arm/test/ops/test_to_copy.py +++ b/backends/arm/test/ops/test_to_copy.py @@ -20,6 +20,7 @@ ) input_t1 = Tuple[torch.Tensor] # Input x +input_t2 = Tuple[torch.Tensor, torch.Tensor] # Input x, y class Cast(torch.nn.Module): @@ -40,6 +41,40 @@ def forward(self, x: torch.Tensor): return x.to(dtype=self.target_dtype) + x.to(dtype=self.target_dtype) +class CastAddTensor(torch.nn.Module): + def __init__(self, target_dtype): + super().__init__() + self.target_dtype = target_dtype + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x.to(dtype=self.target_dtype) + y + + +class AddModule(torch.nn.Module): + def forward(self, x: torch.Tensor, y: torch.Tensor): + return x + y + + +class CastToAddModule(torch.nn.Module): + def __init__(self, target_dtype): + super().__init__() + self.target_dtype = target_dtype + self.add = AddModule() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return self.add(x.to(dtype=self.target_dtype), y) + + +class CastCatTensor(torch.nn.Module): + def __init__(self, target_dtype, dim: int): + super().__init__() + self.target_dtype = target_dtype + self.dim = dim + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return torch.cat((x.to(dtype=self.target_dtype), y), dim=self.dim) + + """ Tests the _to_copy operation. @@ -262,14 +297,6 @@ def test_to_vgf_no_quant(test_data: Tuple): in ToCopySupported::is_node_tosa_supported() before it goes into the delegated graph. """ _TO_COPY_TEST_DATA_INT = { - "rand_int8_fp32": lambda: ( - torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int8), - torch.float32, - ), - "rand_int16_fp32": lambda: ( - torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int16), - torch.float32, - ), "rand_int32_fp32": lambda: ( torch.randint(-127, 128, (1, 2, 3, 4), dtype=torch.int32), torch.float32, @@ -300,6 +327,95 @@ def test_to_tosa_INT_not_delegated(test_data: Tuple): pipeline.run() +_TO_COPY_QUANTIZED_IDENTITY_CAST_DATA = { + "int8_cast_add": lambda: ( + (torch.randn(1, 3, 4, 4) * 10).to(dtype=torch.int8), + torch.randn(1, 3, 4, 4), + torch.float32, + ), + "int16_cast_add": lambda: ( + (torch.randn(1, 3, 4, 4) * 10).to(dtype=torch.int16), + torch.randn(1, 3, 4, 4), + torch.float32, + ), + "int32_cast_add": lambda: ( + (torch.randn(1, 3, 4, 4) * 10).to(dtype=torch.int32), + torch.randn(1, 3, 4, 4), + torch.float32, + ), +} + + +_TO_COPY_QUANTIZED_IDENTITY_CAST_CAT_DATA = { + "int8_cast_cat": lambda: ( + (torch.randn(1, 2, 4, 4) * 10).to(dtype=torch.int8), + torch.randn(1, 2, 4, 1), + torch.float32, + 3, + ), + "int16_cast_cat": lambda: ( + (torch.randn(1, 2, 4, 4) * 10).to(dtype=torch.int16), + torch.randn(1, 2, 4, 1), + torch.float32, + 3, + ), +} + + +@common.parametrize("test_data", _TO_COPY_QUANTIZED_IDENTITY_CAST_DATA) +def test_to_tosa_INT_quantized_identity_cast_add(test_data: Tuple): + x, y, new_dtype = test_data() + pipeline = TosaPipelineINT[input_t2]( + CastAddTensor(new_dtype), + (x, y), + aten_op=["torch.ops.aten.add.Tensor"], + exir_op=["executorch_exir_dialects_edge__ops_aten_add_Tensor"], + qtol=1, + ) + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 1, + }, + ) + pipeline.run() + + +@common.parametrize("test_data", _TO_COPY_QUANTIZED_IDENTITY_CAST_CAT_DATA) +def test_to_tosa_INT_quantized_identity_cast_cat(test_data: Tuple): + x, y, new_dtype, dim = test_data() + pipeline = TosaPipelineINT[input_t2]( + CastCatTensor(new_dtype, dim), + (x, y), + aten_op=["torch.ops.aten.cat.default"], + exir_op=["executorch_exir_dialects_edge__ops_aten_cat_default"], + ) + pipeline.run() + + +@common.parametrize("test_data", _TO_COPY_QUANTIZED_IDENTITY_CAST_DATA) +def test_to_tosa_INT_quantized_identity_cast_to_unquantized_add_delegated( + test_data: Tuple, +): + x, y, new_dtype = test_data() + pipeline = TosaPipelineINT[input_t2]( + CastToAddModule(new_dtype), + (x, y), + aten_op=["torch.ops.aten.add.Tensor"], + exir_op=["executorch_exir_dialects_edge__ops_aten_add_Tensor"], + ) + pipeline.quantizer.set_module_name("add", None) + pipeline.pop_stage("check_not.exir") + pipeline.change_args( + "check_count.exir", + { + "torch.ops.higher_order.executorch_call_delegate": 1, + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 0, + }, + ) + pipeline.run() + + @common.parametrize("test_data", _TO_COPY_TEST_DATA_INT) @common.SkipIfNoModelConverter def test_to_vgf_quant(test_data: Tuple):