Skip to content

Arm backend: Add TOSA support for GroupNorm #10198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
from .decompose_div_pass import DecomposeDivPass # noqa
from .decompose_gelu_pass import DecomposeGeluPass # noqa
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
from .decompose_linear_pass import DecomposeLinearPass # noqa
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DecomposeBatchNormPass,
DecomposeDivPass,
DecomposeGeluPass,
DecomposeGroupNormPass,
DecomposeLayerNormPass,
DecomposeLeakyReLUPass,
DecomposeLinearPass,
Expand Down Expand Up @@ -127,6 +128,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
self.add_pass(DecomposeLinearPass())
self.add_pass(DecomposeLeakyReLUPass())
self.add_pass(DecomposeBatchNormPass())
self.add_pass(DecomposeGroupNormPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeVarPass())
self.add_pass(DecomposeMeanDimPass())
Expand Down Expand Up @@ -180,6 +182,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeGroupNormPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeVarPass())
self.add_pass(DecomposeMeanDimPass())
Expand Down
208 changes: 208 additions & 0 deletions backends/arm/_passes/decompose_groupnorm_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import operator

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


def get_group_norm_decomposition(op) -> tuple:
    """Return the primitive ops used to decompose the group norm op ``op``.

    The ops are taken from the same dialect as ``op``: the edge dialect for
    ``exir_ops.edge.aten.native_group_norm.default`` and the aten dialect for
    ``torch.ops.aten.group_norm.default``.

    Returns:
        A tuple ``(mean, sub, var, full, add, rsqrt, mul, view_copy)`` of op
        overloads, in exactly that order.

    Raises:
        RuntimeError: If ``op`` is neither of the two group norm ops above.
    """
    # Pick the op namespace first; the set of primitives is identical for both
    # dialects, only the namespace they live in differs.
    if op == exir_ops.edge.aten.native_group_norm.default:
        aten = exir_ops.edge.aten
    elif op == torch.ops.aten.group_norm.default:
        aten = torch.ops.aten
    else:
        raise RuntimeError(f"Can't get group_norm composition for op {op}")

    return (
        aten.mean.dim,
        aten.sub.Tensor,
        aten.var.correction,
        aten.full.default,
        aten.add.Tensor,
        aten.rsqrt.default,
        aten.mul.Tensor,
        aten.view_copy.default,
    )


class DecomposeGroupNormPass(ArmPass):
"""
groupnorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
Decompose groupnorm(x, weight, bias, N, C, HxW, group, eps) to a sequence of:
mean = op_mean(x, dims) # E[x]
var = op_var(x, dims) # Var[x]
numerator = op_sub(x, mean) # (x - E[x])
add = op_add(var, eps) # Var[x] + eps
rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps)
mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps))
weigths = op_mul(mul, weigths) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
bias = op_add(weigths, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
where x can viewed with shape [N, group, C//group, HxW] dims=[C//group, HxW]

Source: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html
"""

def call(self, graph_module: torch.fx.GraphModule):
modified = False
for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target not in (
exir_ops.edge.aten.native_group_norm.default,
torch.ops.aten.group_norm.default,
):
continue

# epsilon default value
eps = torch.finfo().eps
weights = None
bias = None
args = node.args
meta = node.meta
if isinstance(meta["val"], tuple):
shape = meta["val"][0].size()
dtype = meta["val"][0].dtype
else:
shape = meta["val"].size()
dtype = meta["val"].dtype
match len(args):
# MI profile always provides all the args: x, weight, bias, N, C, HxW, group, eps
case 8:
x, weights, bias, N, C, HxW, group, eps = args
# BI profile: affine=[True|False], eps!=1e-5
case 5:
x, group, weights, bias, eps = args
# BI profile: affine=True, eps=1e-5
case 4:
x, group, weights, bias = args
# BI profile: affine=False, eps=1e=5
case 2:
x, group = args
# Unsupported args
case _:
raise ValueError(
f"Unsupported group_norm argument pattern with {len(args)} args"
)
N = shape[0]
C = shape[1]
HxW = 1
for dim in shape[2:]:
HxW *= dim
channels_per_group = C // group
grouped_shape = torch.Size([N, group, channels_per_group, HxW])
dims = [2, 3]
epsilon_reshaped_shape = torch.Size([1] * len(grouped_shape))
weights_reshaped_shape = torch.Size([1, group, channels_per_group, 1])
(
mean_op,
sub_op,
var_op,
full_op,
add_op,
rsqrt_op,
mul_op,
view_op,
) = get_group_norm_decomposition(node.target)
with graph_module.graph.inserting_before(node):
keepdim = True
x_reshaped = create_node(
graph_module.graph,
view_op,
args=(x, grouped_shape),
from_node=node,
)
mean = create_node(
graph_module.graph, mean_op, args=(x_reshaped, dims, keepdim)
)
sub = create_node(graph_module.graph, sub_op, args=(x_reshaped, mean))
var = create_node(
graph_module.graph,
var_op,
args=(x_reshaped, dims),
kwargs={"correction": 0, "keepdim": keepdim},
from_node=node,
)
full = create_node(
graph_module.graph,
full_op,
args=(epsilon_reshaped_shape, eps),
kwargs={"dtype": dtype},
from_node=node,
)
add0 = create_node(
graph_module.graph, add_op, args=(var, full), from_node=node
)
rsqrt = create_node(
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
)
mul0 = create_node(
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
)
if weights is not None:
weights_reshaped = create_node(
graph_module.graph,
view_op,
args=(weights, weights_reshaped_shape),
from_node=node,
)
mul1 = create_node(
graph_module.graph,
mul_op,
args=(
mul0,
weights_reshaped,
),
from_node=node,
)
else:
mul1 = mul0
if bias is not None:
bias_reshaped_shape = weights_reshaped_shape
bias_reshaped = create_node(
graph_module.graph,
view_op,
args=(bias, bias_reshaped_shape),
from_node=node,
)
output = create_node(
graph_module.graph,
add_op,
args=(mul1, bias_reshaped),
from_node=node,
)
else:
output = mul1

output_reshaped = create_node(
graph_module.graph,
view_op,
args=(output, shape),
from_node=node,
)

users = [user for user in node.users if node != user]
node.replace_all_uses_with(output_reshaped)
for user in users:
if user.target == operator.getitem:
user.replace_all_uses_with(output_reshaped)
graph_module.graph.erase_node(node)
graph_module.graph.eliminate_dead_code()
modified = True
if modified:
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, modified)
8 changes: 4 additions & 4 deletions backends/arm/_passes/decompose_layernorm_pass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -47,11 +46,12 @@ class DecomposeLayerNormPass(ArmPass):
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
mean = op_mean(x, dims) # E[x]
var = op_var(x, dims) # Var[x]
denominator = op_sub(x, mean) # (x - E[x])
numerator = op_sub(x, mean) # (x - E[x])
add = op_add(var, eps) # Var[x] + eps
rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps)
mul = op_mul(denominator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
bias = op_add(mul, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps))
weigths = op_mul(mul, weigths) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
bias = op_add(weigths, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias

Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
"""
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def is_node_supported(
exir_ops.edge.aten.div.Scalar,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
exir_ops.edge.aten.native_group_norm.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.mm.default,
Expand Down Expand Up @@ -255,6 +256,7 @@ def is_node_supported(
exir_ops.edge.aten.div.Tensor,
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
exir_ops.edge.aten.native_layer_norm.default,
exir_ops.edge.aten.native_group_norm.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten._softmax.default,
exir_ops.edge.aten._log_softmax.default,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/scripts/parse_test_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from executorch.exir.dialects.edge.spec.utils import SAMPLE_INPUT

# Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here.
CUSTOM_EDGE_OPS = ["linspace.default", "eye.default"]
CUSTOM_EDGE_OPS = ["linspace.default", "eye.default", "native_group_norm.default"]
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS

# Add all targets and TOSA profiles we support here.
Expand Down
Loading
Loading