decompose_layernorm_pass.py

# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import operator

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import PassResult


def get_layer_norm_decomposition(op) -> tuple:
if op == exir_ops.edge.aten.native_layer_norm.default:
return (
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.var.correction,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.view_copy.default,
)
if op == torch.ops.aten.layer_norm.default:
return (
torch.ops.aten.mean.dim,
torch.ops.aten.sub.Tensor,
torch.ops.aten.var.correction,
torch.ops.aten.full.default,
torch.ops.aten.add.Tensor,
torch.ops.aten.rsqrt.default,
torch.ops.aten.mul.Tensor,
torch.ops.aten.view_copy.default,
)
    raise RuntimeError(f"Can't get layer_norm decomposition for op {op}")


class DecomposeLayerNormPass(ArmPass):
"""
layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
mean = op_mean(x, dims) # E[x]
var = op_var(x, dims) # Var[x]
numerator = op_sub(x, mean) # (x - E[x])
add = op_add(var, eps) # Var[x] + eps
rsqrt = op_rsqrt(add) # 1 / sqrt(Var[x] + eps)
mul = op_mul(numerator, rsqrt) # ((x - E[x]) / sqrt(Var[x] + eps))
weigths = op_mul(mul, weigths) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths
bias = op_add(weigths, bias) # ((x - E[x]) / sqrt(Var[x] + eps)) * weigths + bias
Source: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html
"""
def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target not in (
exir_ops.edge.aten.native_layer_norm.default,
torch.ops.aten.layer_norm.default,
):
continue

            # Default epsilon: machine eps of torch's default dtype (float32
            # unless torch.set_default_dtype has been changed)
            epsilon = torch.finfo().eps
weights = None
bias = None
args = node.args
meta = node.meta
match len(args):
case 5:
x, normalized_shape, weights, bias, epsilon = args
case 4:
x, normalized_shape, weights, bias = args
case 3:
x, normalized_shape, weights = args
case _:
x, normalized_shape = args
n_dims = len(normalized_shape)
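            # native_layer_norm returns a (output, mean, rstd) tuple, so meta["val"]
            # can be a tuple; take shape/dtype from the first element (the output).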
if isinstance(meta["val"], tuple):
shape = meta["val"][0].size()
dtype = meta["val"][0].dtype
else:
shape = meta["val"].size()
dtype = meta["val"].dtype
rank = len(shape)
dims = list(range(-1, -1 * (n_dims + 1), -1))
dims = [dim % rank for dim in dims]
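            # Worked example (illustrative): shape (2, 8, 16) with normalized_shape
            # (16,) gives n_dims == 1, so dims == [-1] -> [2] after modulo rank == 3.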
weights_reshaped_shape = [shape[i] if i in dims else 1 for i in range(rank)]
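            # Continuing the example: weights/bias of shape (16,) are view-copied to
            # (1, 1, 16) so they broadcast over the non-normalized dimensions.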
epsilon_reshaped_shape = [1] * rank
(
mean_op,
sub_op,
var_op,
full_op,
add_op,
rsqrt_op,
mul_op,
view_op,
) = get_layer_norm_decomposition(node.target)

            with graph_module.graph.inserting_before(node):
keepdim = True
mean = create_node(graph_module.graph, mean_op, args=(x, dims, keepdim))
sub = create_node(graph_module.graph, sub_op, args=(x, mean))
var = create_node(
graph_module.graph,
var_op,
args=(x, dims),
kwargs={"correction": 0, "keepdim": keepdim},
from_node=node,
)
full = create_node(
graph_module.graph,
full_op,
args=(epsilon_reshaped_shape, epsilon),
kwargs={"dtype": dtype},
from_node=node,
)
add0 = create_node(
graph_module.graph, add_op, args=(var, full), from_node=node
)
rsqrt = create_node(
graph_module.graph, rsqrt_op, args=(add0,), from_node=node
)
mul0 = create_node(
graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
)
if weights is not None:
weights_reshaped = create_node(
graph_module.graph,
view_op,
args=(weights, weights_reshaped_shape),
from_node=node,
)
mul1 = create_node(
graph_module.graph,
mul_op,
args=(
mul0,
weights_reshaped,
),
from_node=node,
)
else:
mul1 = mul0
output = mul1
if bias is not None:
bias_reshaped_shape = weights_reshaped_shape
bias_reshaped = create_node(
graph_module.graph,
view_op,
args=(bias, bias_reshaped_shape),
from_node=node,
)
output = create_node(
graph_module.graph,
add_op,
args=(mul1, bias_reshaped),
from_node=node,
)
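
            # Downstream getitem nodes unpacked native_layer_norm's tuple output;
            # rewire them to the single decomposed output before erasing the node.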
users = [user for user in node.users if node != user]
node.replace_all_uses_with(output)
for user in users:
if user.target == operator.getitem:
user.replace_all_uses_with(output)
graph_module.graph.erase_node(node)

        graph_module.graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)
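
# Standalone numerical sketch of the decomposition implemented above (not part
# of the pass; shapes and eps are illustrative assumptions):
#
#   import torch
#   x = torch.randn(2, 8, 16)
#   w, b, eps = torch.randn(16), torch.randn(16), 1e-5
#   ref = torch.nn.functional.layer_norm(x, [16], w, b, eps)
#   mean = x.mean(dim=-1, keepdim=True)
#   var = x.var(dim=-1, correction=0, keepdim=True)  # correction=0, as in the pass
#   out = (x - mean) * torch.rsqrt(var + eps) * w.view(1, 1, 16) + b.view(1, 1, 16)
#   torch.testing.assert_close(ref, out)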