
Commit 4c54bab

Arm backend: Add initial Llama model test case (#8679)

Adds a Llama model test case for TOSA-0.80+MI. Handles Add and Mul where the inputs have different ranks. Adds a new unit test parameter, --llama_inputs; without it the test is skipped. Tested with the smaller stories models, see examples/models/llama/UTILS.md. Adds get_llama_model() to export_llama_lib, which the test case uses.

1 parent e433e61 commit 4c54bab
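
For context, a hedged usage sketch: the --llama_inputs flag is registered in backends/arm/test/conftest.py in this commit, and the stories110M file names come from the new test's docstring. The test-file path below is an assumption, since the diff view elides the new test's filename:

pytest -s backends/arm/test -k test_llama_tosa_MI --llama_inputs stories110M/stories110M.pt stories110M/params.json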

File tree

14 files changed: +368 -26 lines

backends/arm/operator_support/__init__.py

+1

@@ -11,6 +11,7 @@
     pool_2d_support,
     reduce_sum_support,
     right_shift_support,
+    slice_copy_support,
     to_copy_support,
     tosa_supported_operators,
 )
backends/arm/operator_support/slice_copy_support.py

+39

@@ -0,0 +1,39 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa_utils import getNodeArgs
+from executorch.exir.dialects._ops import ops as exir_ops
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+
+@register_tosa_support_check
+class SliceCopySupported(SupportedTOSAOperatorCheck):
+    targets = [exir_ops.edge.aten.slice_copy.Tensor]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:  # type: ignore[override, misc]
+        if tosa_spec not in self.tosa_specs:
+            return False
+
+        inputs = getNodeArgs(node)
+        if len(inputs) == 5 and (step := inputs[4].number) != 1:
+            logger.warning(f"{node.target} with step size of {step} not supported.")
+            return False
+        return True
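
For reference, a minimal sketch of the aten-level op this check inspects, written against the public torch.ops entry point; the edge-dialect node carries the same five arguments, which getNodeArgs() is assumed to return in (input, dim, start, end, step) order:

import torch

x = torch.arange(10)
# step == 1: the five-argument form the check accepts
supported = torch.ops.aten.slice_copy.Tensor(x, 0, 2, 8, 1)    # tensor([2, 3, 4, 5, 6, 7])
# step != 1: rejected, so the node is not delegated to TOSA
unsupported = torch.ops.aten.slice_copy.Tensor(x, 0, 2, 8, 2)  # tensor([2, 4, 6])
print(supported, unsupported)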

backends/arm/operator_support/tosa_supported_operators.py

-2

@@ -75,7 +75,6 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
 def get_registered_tosa_support_checks(
     tosa_spec: TosaSpecification,
 ) -> list[Type[SupportedTOSAOperatorCheck]]:
-
     if tosa_spec not in _tosa_spec_support:
         raise RuntimeError(
             f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}"
@@ -165,7 +164,6 @@ def is_node_supported(
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.select_copy.int,
             exir_ops.edge.aten._log_softmax.default,
-            exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.tanh.default,
             exir_ops.edge.aten.upsample_nearest2d.vec,

backends/arm/operators/op_add.py

+14 -5

@@ -45,6 +45,12 @@ def define_node(
         # Handle int8 (quantized) and int32
         assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]

+        dim_order = (
+            inputs[0].dim_order
+            if len(inputs[0].shape) > len(inputs[1].shape)
+            else inputs[1].dim_order
+        )
+
         if inputs[0].dtype == ts.DType.INT8:
             rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                 tosa_graph, inputs, node
@@ -61,13 +67,14 @@
             # output.dtype == ts.DType.INT32
             add_output = output

+        input1, input2 = tutils.reshape_for_broadcast(
+            tosa_graph, rescaled_inputs, dim_order
+        )
+
         # Do the INT32 Add
         tosa_graph.addOperator(
             TosaOp.Op().ADD,
-            [
-                rescaled_inputs[0].name,
-                rescaled_inputs[1].name,
-            ],
+            [input1.name, input2.name],
             [add_output.name],
             None,
         )
@@ -108,10 +115,12 @@
         assert inputs[0].dtype == ts.DType.FP32
         assert output.dtype == ts.DType.FP32

+        input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)
+
         # MI lowering
         tosa_graph.addOperator(
             TosaOp.Op().ADD,
-            [inputs[0].name, inputs[1].name],
+            [input1.name, input2.name],
             [output.name],
             None,
         )
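
The diff does not show tutils.reshape_for_broadcast itself; the standalone sketch below illustrates the rank alignment it presumably performs, since TOSA elementwise operators require equal-rank operands and broadcast only across size-1 dimensions. rank_align is a hypothetical name, and the real helper additionally has to honor the dim_order argument:

import torch

def rank_align(a: torch.Tensor, b: torch.Tensor):
    # Prepend singleton dimensions to the lower-rank operand so both
    # operands end up with equal rank, mirroring PyTorch broadcasting.
    if a.dim() < b.dim():
        a = a.reshape((1,) * (b.dim() - a.dim()) + tuple(a.shape))
    elif b.dim() < a.dim():
        b = b.reshape((1,) * (a.dim() - b.dim()) + tuple(b.shape))
    return a, b

x, y = rank_align(torch.randn(1, 1, 4, 4), torch.randn(4, 1))
print(x.shape, y.shape)  # torch.Size([1, 1, 4, 4]) torch.Size([1, 1, 4, 1])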

backends/arm/operators/op_mul.py

+21 -5

@@ -24,6 +24,7 @@
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa_utils import reshape_for_broadcast
 from serializer.tosa_serializer import TosaOp


@@ -43,6 +44,12 @@ def define_node(
         output: TosaArg,
     ) -> None:
         assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8
+
+        dim_order = (
+            inputs[0].dim_order
+            if len(inputs[0].shape) > len(inputs[1].shape)
+            else inputs[1].dim_order
+        )
         input_A = inputs[0]
         input_B = inputs[1]
         input_qparams = get_input_qparams(node)  # pyre-ignore[16]
@@ -68,15 +75,21 @@
         output_shape = tutils.tosa_shape(output.shape, output.dim_order)
         mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)

+        input1, input2 = tutils.reshape_for_broadcast(
+            tosa_graph,
+            [
+                input_A_rescaled,
+                input_B_rescaled,
+            ],
+            dim_order,
+        )
+
         # Do the INT32 Mul
         attr = ts.TosaSerializerAttribute()
         attr.MulAttribute(shift=0)
         tosa_graph.addOperator(
             TosaOp.Op().MUL,
-            [
-                input_A_rescaled.name,
-                input_B_rescaled.name,
-            ],
+            [input1.name, input2.name],
             [mul_output.name],
             attr,
         )
@@ -101,8 +114,11 @@
     ) -> None:
         if inputs[0].dtype == ts.DType.INT8:
             return super().define_node(node, tosa_graph, inputs, output)
+
+        input1, input2 = reshape_for_broadcast(tosa_graph, inputs)
+
         attr = ts.TosaSerializerAttribute()
         attr.MulAttribute(shift=0)
         tosa_graph.addOperator(
-            TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr
+            TosaOp.Op().MUL, [input1.name, input2.name], [output.name], attr
         )

backends/arm/operators/op_slice.py

+5 -2

@@ -32,9 +32,12 @@ def define_node(
         output: TosaArg,
     ) -> None:

+        # See slice_copy_support.py
+        if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)):
+            raise ValueError("Unsupported combination of inputs")
+
         # aten.slice_copy supports slicing one dimension at a time.
-        # The arguments are dimension of slicing, start index and end index.
-        assert len(inputs) == 4
+        # The arguments are the actual input, dimension of slicing, start index, end index and an optional step or stride.
         input_node, dim, start, end = inputs

         # Translate and check parameters in Pytorch dim order.

backends/arm/test/conftest.py

+6 -1

@@ -44,7 +44,7 @@ def pytest_configure(config):
     )
     # Only enable if we also have the TOSA reference model available.
     pytest._test_options["corstone_fvp"] = True  # type: ignore[attr-defined]
-
+    pytest._test_options["llama_inputs"] = config.option.llama_inputs  # type: ignore[attr-defined]
     pytest._test_options["fast_fvp"] = False  # type: ignore[attr-defined]
     if getattr(config.option, "fast_fvp", False):
         pytest._test_options["fast_fvp"] = config.option.fast_fvp  # type: ignore[attr-defined]
@@ -70,6 +70,11 @@ def try_addoption(*args, **kwargs):
     try_addoption("--arm_quantize_io", action="store_true", help="Deprecated.")
    try_addoption("--arm_run_corstoneFVP", action="store_true", help="Deprecated.")
     try_addoption("--fast_fvp", action="store_true")
+    try_addoption(
+        "--llama_inputs",
+        nargs="+",
+        help="Two files: a .pt checkpoint followed by a .json params file.",
+    )


 def pytest_sessionstart(session):
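
A short sketch of how a test consumes the new option; is_option_enabled and get_option are the conftest helpers the Llama test below uses to read pytest._test_options:

from executorch.backends.arm.test import conftest

# Skip-or-run guard, then unpack the two file paths passed on the command line.
if conftest.is_option_enabled("llama_inputs"):
    checkpoint, params_file = conftest.get_option("llama_inputs")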
+120

@@ -0,0 +1,120 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import os
+import sys
+import unittest
+
+import torch
+
+from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.examples.models.llama.export_llama_lib import (
+    build_args_parser,
+    get_llama_model,
+)
+
+
+# Add project dir to sys path to workaround importlib.import_module() conditions in model_factory.py
+this_files_dir = os.path.dirname(os.path.abspath(__file__))
+project_dir = os.path.abspath(os.path.join(this_files_dir, "../../../.."))
+sys.path.append(project_dir)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class TestLlama(unittest.TestCase):
+    """
+    Test class of Llama models. Type of Llama model depends on command line parameters:
+    --llama_inputs <path to .pt file> <path to .json file>
+    Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json
+    """
+
+    def prepare_model(self):
+
+        checkpoint = None
+        params_file = None
+        if conftest.is_option_enabled("llama_inputs"):
+            param_list = conftest.get_option("llama_inputs")
+            assert (
+                isinstance(param_list, list) and len(param_list) == 2
+            ), "invalid number of inputs for --llama_inputs"
+            checkpoint = param_list[0]
+            params_file = param_list[1]
+            assert isinstance(checkpoint, str) and isinstance(
+                params_file, str
+            ), "invalid input for --llama_inputs"
+        else:
+            logger.warning(
+                "Skipping Llama test because of lack of input. To run use --llama_inputs <.pt> <.json>"
+            )
+            return None, None, None
+
+        assert os.path.isfile(checkpoint) and os.path.isfile(
+            params_file
+        ), "Invalid file paths"
+
+        # TODO: Enable key value cache
+        args = [
+            "--disable_dynamic_shape",
+            "-c",
+            checkpoint,
+            "-p",
+            params_file,
+            "--model",
+            "stories110m",
+        ]
+        parser = build_args_parser()
+        args = parser.parse_args(args)
+
+        llama_model, llama_inputs, llama_meta = get_llama_model(args)
+
+        # TODO: Remove workaround since attention mask should not be persistent,
+        # it only works if input shape is always the same
+        freqs_c = "freqs_cos"
+        freqs_s = "freqs_sin"
+        for i in range(llama_model.n_layers):
+            val = llama_model.layers[i].attention.get_buffer("mask")
+            llama_model.layers[i].attention.register_buffer(
+                "mask", val, persistent=True
+            )
+            val = llama_model.layers[i].attention.rope.get_buffer(freqs_c)
+            llama_model.layers[i].attention.rope.register_buffer(
+                freqs_c, val, persistent=True
+            )
+            val = llama_model.layers[i].attention.rope.get_buffer(freqs_s)
+            llama_model.layers[i].attention.rope.register_buffer(
+                freqs_s, val, persistent=True
+            )
+
+        return llama_model, llama_inputs, llama_meta
+
+    def test_llama_tosa_MI(self):
+        llama_model, llama_inputs, llama_meta = self.prepare_model()
+
+        if llama_model is None and llama_inputs is None and llama_meta is None:
+            return
+
+        with torch.no_grad():
+            (
+                ArmTester(
+                    llama_model,
+                    example_inputs=llama_inputs,
+                    compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
+                    constant_methods=llama_meta,
+                )
+                .export()
+                .to_edge_transform_and_lower()
+                .check_count({"torch.ops.higher_order.executorch_call_delegate": 14})
+                .to_executorch()
+                .run_method_and_compare_outputs(
+                    inputs=llama_inputs, atol=1.8, rtol=0.01  # TODO: decrease tolerance
+                )
+            )
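
One note on the persistent-buffer workaround in prepare_model(): buffers registered with persistent=False are excluded from state_dict(), which is the observable property the test flips. A minimal standalone illustration:

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("mask", torch.ones(4), persistent=False)

m = M()
# Non-persistent buffers are dropped from state_dict(); re-registering them
# with persistent=True (as the test does) keeps them in the exported state.
print("mask" in m.state_dict())  # False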

backends/arm/test/ops/test_add.py

+23 -1

@@ -5,7 +5,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-
 from typing import Tuple

 import torch
@@ -61,6 +60,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
     }


+class Add3(torch.nn.Module):
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        return x + y
+
+    test_data: list[input_t2] = {
+        "3d_randn_diff_rank": (torch.randn(1, 4, 5), torch.randn(4, 1)),
+        "4d_randn_diff_rank": (torch.randn(1, 1, 4, 4), torch.randn(4, 1)),
+        "4d_randn_diff_rank_2": (torch.randn(4, 1), torch.randn(1, 1, 4, 5)),
+    }
+
+
 @common.parametrize("test_data", Add.test_data)
 def test_add_tosa_MI(test_data: input_t1):
     pipeline = TosaPipelineMI[input_t1](Add(), test_data, aten_op, exir_op)
@@ -129,6 +139,18 @@ def test_add_2_tosa_MI(test_data: input_t2):
     pipeline.run()


+@common.parametrize("test_data", Add3.test_data)
+def test_add3_tosa_MI(test_data: input_t2):
+    pipeline = TosaPipelineMI[input_t2](Add3(), test_data, aten_op, exir_op)
+    pipeline.run()
+
+
+@common.parametrize("test_data", Add3.test_data)
+def test_add3_tosa_BI(test_data: input_t2):
+    pipeline = TosaPipelineBI[input_t2](Add3(), test_data, aten_op, exir_op)
+    pipeline.run()
+
+
 @common.parametrize("test_data", Add2.test_data)
 def test_add_2_tosa_BI(test_data: input_t2):
     pipeline = TosaPipelineBI[input_t2](Add2(), test_data, aten_op, exir_op)
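
As a quick eager-mode reference for the new different-rank cases, PyTorch broadcasting gives the shapes the TOSA lowering must reproduce:

import torch

# Add3's "3d_randn_diff_rank" case: (1, 4, 5) + (4, 1) broadcasts to (1, 4, 5)
print((torch.randn(1, 4, 5) + torch.randn(4, 1)).shape)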
