From 3ef5cb233c90856e5317c0c6ed2a587621da96db Mon Sep 17 00:00:00 2001 From: Yufeng Shi Date: Thu, 10 Apr 2025 11:33:26 +0100 Subject: [PATCH] Arm backend: Add TOSA support for GroupNorm - Decompose groupnorm into a sequence of supported operators - Have some numerical issues with BI profile - Fix docstring in decompose_layernorm_pass - Add "native_group_norm.default" to CUSTOM_EDGE_OPS Change-Id: I3f70388c12b8d9afd52876840b6c008a1b0bec4e Signed-off-by: Yufeng Shi --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 3 + .../arm/_passes/decompose_groupnorm_pass.py | 208 ++++++++++++++++++ .../arm/_passes/decompose_layernorm_pass.py | 8 +- .../tosa_supported_operators.py | 2 + backends/arm/scripts/parse_test_names.py | 2 +- backends/arm/test/ops/test_group_norm.py | 132 +++++++++++ 7 files changed, 351 insertions(+), 5 deletions(-) create mode 100644 backends/arm/_passes/decompose_groupnorm_pass.py create mode 100644 backends/arm/test/ops/test_group_norm.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 1a1719bf8ae..50590f59403 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -21,6 +21,7 @@ from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa from .decompose_div_pass import DecomposeDivPass # noqa from .decompose_gelu_pass import DecomposeGeluPass # noqa +from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa from .decompose_linear_pass import DecomposeLinearPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 261ee045790..6f7a4561ded 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -26,6 +26,7 @@ DecomposeBatchNormPass, DecomposeDivPass, DecomposeGeluPass, + 
DecomposeGroupNormPass, DecomposeLayerNormPass, DecomposeLeakyReLUPass, DecomposeLinearPass, @@ -127,6 +128,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeLinearPass()) self.add_pass(DecomposeLeakyReLUPass()) self.add_pass(DecomposeBatchNormPass()) + self.add_pass(DecomposeGroupNormPass()) self.add_pass(DecomposeLayerNormPass()) self.add_pass(DecomposeVarPass()) self.add_pass(DecomposeMeanDimPass()) @@ -180,6 +182,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram): def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) self.add_pass(ScalarsToAttributePass()) + self.add_pass(DecomposeGroupNormPass()) self.add_pass(DecomposeLayerNormPass()) self.add_pass(DecomposeVarPass()) self.add_pass(DecomposeMeanDimPass()) diff --git a/backends/arm/_passes/decompose_groupnorm_pass.py b/backends/arm/_passes/decompose_groupnorm_pass.py new file mode 100644 index 00000000000..c6cb1b05e40 --- /dev/null +++ b/backends/arm/_passes/decompose_groupnorm_pass.py @@ -0,0 +1,208 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+# pyre-unsafe
+
+import operator
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import PassResult
+
+
+def get_group_norm_decomposition(op) -> tuple:
+    """Return (mean, sub, var, full, add, rsqrt, mul, view) ops matching op's dialect."""
+    if op == exir_ops.edge.aten.native_group_norm.default:
+        return (
+            exir_ops.edge.aten.mean.dim,
+            exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.var.correction,
+            exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.add.Tensor,
+            exir_ops.edge.aten.rsqrt.default,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.view_copy.default,
+        )
+    if op == torch.ops.aten.group_norm.default:
+        return (
+            torch.ops.aten.mean.dim,
+            torch.ops.aten.sub.Tensor,
+            torch.ops.aten.var.correction,
+            torch.ops.aten.full.default,
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.rsqrt.default,
+            torch.ops.aten.mul.Tensor,
+            torch.ops.aten.view_copy.default,
+        )
+    raise RuntimeError(f"Can't get group_norm composition for op {op}")
+
+
+class DecomposeGroupNormPass(ArmPass):
+    """
+    groupnorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
+    Decompose groupnorm(x, weight, bias, N, C, HxW, group, eps) to a sequence of:
+        mean        = op_mean(x, dims)           # E[x]
+        var         = op_var(x, dims)            # Var[x]
+        numerator   = op_sub(x, mean)            # (x - E[x])
+        add         = op_add(var, eps)           # Var[x] + eps
+        rsqrt       = op_rsqrt(add)              # 1 / sqrt(Var[x] + eps)
+        mul         = op_mul(numerator, rsqrt)   # ((x - E[x]) / sqrt(Var[x] + eps))
+        weights     = op_mul(mul, weights)       # ((x - E[x]) / sqrt(Var[x] + eps)) * weights
+        bias        = op_add(weights, bias)      # ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
+    where x can be viewed with shape [N, group, C//group, HxW] dims=[C//group, HxW]
+
+    Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        """Replace every group_norm node in graph_module with its decomposition."""
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function" or node.target not in (
+                exir_ops.edge.aten.native_group_norm.default,
+                torch.ops.aten.group_norm.default,
+            ):
+                continue
+
+            # eps is 1e-5 whenever the node's args omit it (2- and 4-arg patterns below)
+            eps = 1e-5
+            weights = None
+            bias = None
+            args = node.args
+            meta = node.meta
+            if isinstance(meta["val"], tuple):
+                shape = meta["val"][0].size()
+                dtype = meta["val"][0].dtype
+            else:
+                shape = meta["val"].size()
+                dtype = meta["val"].dtype
+            match len(args):
+                # MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps
+                case 8:
+                    x, weights, bias, N, C, HxW, group, eps = args
+                # BI profile: affine=[True|False], eps!=1e-5
+                case 5:
+                    x, group, weights, bias, eps = args
+                # BI profile: affine=True, eps=1e-5
+                case 4:
+                    x, group, weights, bias = args
+                # BI profile: affine=False, eps=1e-5
+                case 2:
+                    x, group = args
+                # Unsupported args
+                case _:
+                    raise ValueError(
+                        f"Unsupported group_norm argument pattern with {len(args)} args"
+                    )
+            N = shape[0]
+            C = shape[1]
+            HxW = 1
+            for dim in shape[2:]:
+                HxW *= dim
+            channels_per_group = C // group
+            grouped_shape = torch.Size([N, group, channels_per_group, HxW])
+            dims = [2, 3]
+            epsilon_reshaped_shape = torch.Size([1] * len(grouped_shape))
+            weights_reshaped_shape = torch.Size([1, group, channels_per_group, 1])
+            (
+                mean_op,
+                sub_op,
+                var_op,
+                full_op,
+                add_op,
+                rsqrt_op,
+                mul_op,
+                view_op,
+            ) = get_group_norm_decomposition(node.target)
+            with graph_module.graph.inserting_before(node):
+                keepdim = True
+                x_reshaped = create_node(
+                    graph_module.graph,
+                    view_op,
+                    args=(x, grouped_shape),
+                    from_node=node,
+                )
+                mean = create_node(
+                    graph_module.graph, mean_op, args=(x_reshaped, dims, keepdim)
+                )
+                sub = create_node(graph_module.graph, sub_op, args=(x_reshaped, mean))
+                var = create_node(
+                    graph_module.graph,
+                    var_op,
+                    args=(x_reshaped, dims),
+                    kwargs={"correction": 0, "keepdim": keepdim},
+                    from_node=node,
+                )
+                full = create_node(
+                    graph_module.graph,
+                    full_op,
+                    args=(epsilon_reshaped_shape, eps),
+                    kwargs={"dtype": dtype},
+                    from_node=node,
+                )
+                add0 = create_node(
+                    graph_module.graph, add_op, args=(var, full), from_node=node
+                )
+                rsqrt = create_node(
+                    graph_module.graph, rsqrt_op, args=(add0,), from_node=node
+                )
+                mul0 = create_node(
+                    graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
+                )
+                if weights is not None:
+                    weights_reshaped = create_node(
+                        graph_module.graph,
+                        view_op,
+                        args=(weights, weights_reshaped_shape),
+                        from_node=node,
+                    )
+                    mul1 = create_node(
+                        graph_module.graph,
+                        mul_op,
+                        args=(mul0, weights_reshaped),
+                        from_node=node,
+                    )
+                else:
+                    mul1 = mul0
+                if bias is not None:
+                    bias_reshaped_shape = weights_reshaped_shape
+                    bias_reshaped = create_node(
+                        graph_module.graph,
+                        view_op,
+                        args=(bias, bias_reshaped_shape),
+                        from_node=node,
+                    )
+                    output = create_node(
+                        graph_module.graph,
+                        add_op,
+                        args=(mul1, bias_reshaped),
+                        from_node=node,
+                    )
+                else:
+                    output = mul1
+
+                output_reshaped = create_node(
+                    graph_module.graph,
+                    view_op,
+                    args=(output, shape),
+                    from_node=node,
+                )
+
+            # NOTE(review): getitem users are rewired to the output regardless of index.
+            users = [user for user in node.users if node != user]
+            node.replace_all_uses_with(output_reshaped)
+            for user in users:
+                if user.target == operator.getitem:
+                    user.replace_all_uses_with(output_reshaped)
+            graph_module.graph.erase_node(node)
+            graph_module.graph.eliminate_dead_code()
+            modified = True
+        if modified:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
diff --git a/backends/arm/_passes/decompose_layernorm_pass.py b/backends/arm/_passes/decompose_layernorm_pass.py
index a92434faa7d..e6cbdfb91a0 100644
--- a/backends/arm/_passes/decompose_layernorm_pass.py
+++ b/backends/arm/_passes/decompose_layernorm_pass.py
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -47,11 +46,12 @@ class DecomposeLayerNormPass(ArmPass): Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of: mean = op_mean(x, dims) # E[x] var = op_var(x, dims) # Var[x] - denominator = op_sub(x, mean) # (x - E[x]) + numerator = op_sub(x, mean) # (x - E[x]) add = op_add(var, eps) # Var[x] + eps rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps) - mul = op_mul(denominator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths - bias = op_add(mul, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias + mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps)) + weigths = op_mul(mul, weigths) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias = op_add(weigths, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html """ diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index f84bde7fadc..3923184a91f 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -188,6 +188,7 @@ def is_node_supported( exir_ops.edge.aten.div.Scalar, exir_ops.edge.aten._native_batch_norm_legit_no_training.default, exir_ops.edge.aten.native_layer_norm.default, + exir_ops.edge.aten.native_group_norm.default, exir_ops.edge.aten.sigmoid.default, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mm.default, @@ -255,6 +256,7 @@ def is_node_supported( exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten._native_batch_norm_legit_no_training.default, exir_ops.edge.aten.native_layer_norm.default, + exir_ops.edge.aten.native_group_norm.default, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten._softmax.default, exir_ops.edge.aten._log_softmax.default, diff --git a/backends/arm/scripts/parse_test_names.py b/backends/arm/scripts/parse_test_names.py index 8aabf7c2c59..7377b9ac139 100644 --- a/backends/arm/scripts/parse_test_names.py +++ 
b/backends/arm/scripts/parse_test_names.py
@@ -5,7 +5,7 @@
 from executorch.exir.dialects.edge.spec.utils import SAMPLE_INPUT
 
 # Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here.
-CUSTOM_EDGE_OPS = ["linspace.default", "eye.default"]
+CUSTOM_EDGE_OPS = ["linspace.default", "eye.default", "native_group_norm.default"]
 ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS
 
 # Add all targets and TOSA profiles we support here.
diff --git a/backends/arm/test/ops/test_group_norm.py b/backends/arm/test/ops/test_group_norm.py
new file mode 100644
index 00000000000..0559bed0230
--- /dev/null
+++ b/backends/arm/test/ops/test_group_norm.py
@@ -0,0 +1,132 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.test import common
+from executorch.backends.arm.test.tester.test_pipeline import (
+    EthosU55PipelineBI,
+    EthosU85PipelineBI,
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+
+class GroupNorm(torch.nn.Module):
+    """Wrap torch.nn.GroupNorm so it can be used as the module under test."""
+
+    def __init__(
+        self,
+        num_groups: int,
+        num_channels: int,
+        eps: float = 1e-5,
+        affine: bool = True,
+    ):
+        super().__init__()
+        self.group_norm = torch.nn.GroupNorm(
+            num_groups,
+            num_channels,
+            eps=eps,
+            affine=affine,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+    ):
+        return self.group_norm(x)
+
+
+input_t = tuple[torch.Tensor]
+test_data_suite = {
+    "rand_4_6_groups_1": ((torch.rand(4, 6),), GroupNorm(1, 6)),
+    "rand_4_6_groups_2": ((torch.rand(4, 6),), GroupNorm(2, 6)),
+    "rand_4_6_groups_6": ((torch.rand(4, 6),), GroupNorm(6, 6)),
+    "rand_4_6_8_groups_2_eps_no_affine": (
+        (torch.rand(4, 6, 8),),
+        GroupNorm(2, 6, eps=1e-3, affine=False),
+    ),
+    "randn_1_12_8_6_groups_6_eps": (
+        (torch.randn(1, 12, 8, 6),),
+        GroupNorm(6, 12, eps=1e-2),
+    ),
+    "randn_1_12_8_6_groups_12": ((torch.randn(1, 12, 8, 6),), GroupNorm(12, 12)),
+    "rand_6_8_10_12_groups_1": ((torch.rand(6, 8, 10, 12),), GroupNorm(1, 8)),
+    "rand_6_8_10_12_groups_4_no_affine": (
+        (torch.rand(6, 8, 10, 12),),
+        GroupNorm(4, 8, affine=False),
+    ),
+    "rand_6_8_10_12_groups_8": ((torch.rand(6, 8, 10, 12),), GroupNorm(8, 8)),
+}
+
+# Cases expected to fail in the Ethos-U55/U85 BI tests; shared so both stay in sync.
+xfail_reason = "MLETORCH-925: Fix numerical issue for aten.native_group_norm"
+fvp_xfails = {
+    "randn_1_12_8_6_groups_12": xfail_reason,
+    "rand_6_8_10_12_groups_1": xfail_reason,
+    "rand_6_8_10_12_groups_4_no_affine": xfail_reason,
+    "rand_6_8_10_12_groups_8": xfail_reason,
+}
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_native_group_norm_tosa_MI(test_data):
+    aten_op = "torch.ops.aten.group_norm.default"
+    exir_op = "executorch_exir_dialects_edge__ops_aten_native_group_norm_default"
+    pipeline = TosaPipelineMI[input_t](
+        test_data[1],
+        test_data[0],
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_data", test_data_suite)
+def test_native_group_norm_tosa_BI(test_data):
+    aten_op = "torch.ops.aten.sub.Tensor"  # 'sub' op arbitrarily chosen to confirm groupnorm was decomposed
+    exir_op = "executorch_exir_dialects_edge__ops_aten_native_group_norm_default"
+    pipeline = TosaPipelineBI[input_t](
+        test_data[1],
+        test_data[0],
+        aten_op=aten_op,
+        exir_op=exir_op,
+    )
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+    xfails=fvp_xfails,
+    strict=False,
+)
+@common.XfailIfNoCorstone300
+def test_native_group_norm_u55_BI(test_data):
+    pipeline = EthosU55PipelineBI[input_t](
+        test_data[1],
+        test_data[0],
+        "torch.ops.aten.sub.Tensor",  # 'sub' op arbitrarily chosen to confirm groupnorm was decomposed
+        run_on_fvp=True,
+    )
+    pipeline.change_args("run_method_and_compare_outputs", atol=1, qtol=1)
+    pipeline.run()
+
+
+@common.parametrize(
+    "test_data",
+    test_data_suite,
+    xfails=fvp_xfails,
+    strict=False,
+)
+@common.XfailIfNoCorstone320
+def test_native_group_norm_u85_BI(test_data):
+    pipeline = EthosU85PipelineBI[input_t](
+        test_data[1],
+        test_data[0],
+        "torch.ops.aten.sub.Tensor",  # 'sub' op arbitrarily chosen to confirm groupnorm was decomposed
+        run_on_fvp=True,
+    )
+    pipeline.change_args("run_method_and_compare_outputs", atol=1, qtol=1)
+    pipeline.run()