Skip to content

Commit f2016da

Browse files
committed
Arm backend: Add TOSA dialect FFT node visitors
Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com> Change-Id: I16991695a3f1d8828e4c17b3f1d4b4b32157298d
1 parent a8cfb75 commit f2016da

6 files changed

Lines changed: 168 additions & 2 deletions

File tree

backends/arm/_passes/aten_to_tosa_tensor_operators.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ def rewrite_argmax(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec:
2626
)
2727

2828

29+
def rewrite_rfft2(node: Node, pass_: AtenToDialectPass) -> DialectNodeSpec | None:
30+
fft_size = node.args[1] if len(node.args) > 1 else node.kwargs.get("s")
31+
fft_dims = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", [-2, -1])
32+
norm = node.args[3] if len(node.args) > 3 else node.kwargs.get("norm")
33+
if fft_size is not None or fft_dims not in ([-2, -1], (-2, -1)) or norm is not None:
34+
return None
35+
36+
return DialectNodeSpec(
37+
exir_ops.backend.tosa.RFFT2D.default,
38+
(node.args[0],),
39+
{},
40+
)
41+
42+
2943
def rewrite_binary_operator(
3044
node: Node, pass_: AtenToDialectPass
3145
) -> DialectNodeSpec | None:

backends/arm/_passes/exir_to_tosa_pass.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from executorch.backends.arm._passes.aten_to_tosa_tensor_operators import (
1313
rewrite_argmax,
1414
rewrite_binary_operator,
15+
rewrite_rfft2,
1516
)
1617
from executorch.backends.transforms.aten_to_dialect_pass import (
1718
AtenToDialectPass,
@@ -43,13 +44,24 @@ def decorator(func: SubstitutionFn) -> SubstitutionFn:
4344
return decorator
4445

4546

46-
@ExirToTosaPass.register_dialect_substitution(exir_ops.edge.aten.argmax.default)
47+
@register_dialect_substitutions(
48+
exir_ops.edge.aten.argmax.default,
49+
)
4750
def _get_tensor_operators_replacement(
4851
node: Node, pass_: AtenToDialectPass
49-
) -> DialectNodeSpec:
52+
) -> DialectNodeSpec | None:
5053
return rewrite_argmax(node, pass_)
5154

5255

56+
@register_dialect_substitutions(
57+
exir_ops.edge.aten.fft_rfft2.default,
58+
)
59+
def _get_fft_replacement(
60+
node: Node, pass_: AtenToDialectPass
61+
) -> DialectNodeSpec | None:
62+
return rewrite_rfft2(node, pass_)
63+
64+
5365
@register_dialect_substitutions(
5466
exir_ops.edge.aten.add.Tensor,
5567
exir_ops.edge.aten.bitwise_and.Tensor,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
op_tosa_depthwise_conv2d,
4747
op_tosa_eq,
4848
op_tosa_erf,
49+
op_tosa_fft,
4950
op_tosa_gather,
5051
op_tosa_ge,
5152
op_tosa_gt,
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import torch.fx
9+
import tosa_serializer as ts
10+
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.operators.operator_validation_utils import (
16+
validate_num_inputs,
17+
validate_same_dtype,
18+
validate_valid_dtype,
19+
)
20+
from executorch.backends.arm.tosa.mapping import TosaArg
21+
22+
23+
@register_node_visitor
24+
class FFT2dVisitor(NodeVisitor):
25+
target = "tosa.FFT2D.default"
26+
27+
def define_node(
28+
self,
29+
node: torch.fx.Node,
30+
tosa_graph: Any,
31+
inputs: List[TosaArg],
32+
output: TosaArg,
33+
) -> None:
34+
validate_num_inputs(self.target, inputs, 2)
35+
validate_same_dtype(self.target, inputs, ts)
36+
validate_valid_dtype(self.target, inputs, ts.DType.FP32, self.tosa_spec)
37+
38+
attr = ts.TosaSerializerAttribute()
39+
attr.FFT2dAttribute(
40+
node.kwargs.get("inverse", False),
41+
node.kwargs.get("local_bound", False),
42+
)
43+
self._serialize_operator(
44+
node,
45+
tosa_graph,
46+
ts.Op.FFT2D,
47+
[inputs[0].name, inputs[1].name],
48+
output.multiple_output_names,
49+
attr,
50+
)
51+
52+
53+
@register_node_visitor
54+
class RFFT2dVisitor(NodeVisitor):
55+
target = "tosa.RFFT2D.default"
56+
57+
def define_node(
58+
self,
59+
node: torch.fx.Node,
60+
tosa_graph: Any,
61+
inputs: List[TosaArg],
62+
output: TosaArg,
63+
) -> None:
64+
validate_num_inputs(self.target, inputs, 1)
65+
validate_valid_dtype(self.target, inputs, ts.DType.FP32, self.tosa_spec)
66+
67+
attr = ts.TosaSerializerAttribute()
68+
attr.RFFT2dAttribute(node.kwargs.get("local_bound", False))
69+
self._serialize_operator(
70+
node,
71+
tosa_graph,
72+
ts.Op.RFFT2D,
73+
[inputs[0].name],
74+
output.multiple_output_names,
75+
attr,
76+
)

backends/arm/test/ops/test_fft.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
TosaPipelineFP,
12+
VgfPipeline,
13+
)
14+
15+
aten_op = "torch.ops.aten.fft_rfft2.default"
16+
exir_op = "executorch_exir_dialects_edge__ops_aten_fft_rfft2_default"
17+
18+
input_t1 = Tuple[torch.Tensor]
19+
20+
21+
class RFFT2D(torch.nn.Module):
22+
test_parameters = {
23+
"rank2": lambda: (torch.randn(8, 16),),
24+
"rank3": lambda: (torch.randn(2, 8, 16),),
25+
"rank4": lambda: (torch.randn(1, 2, 8, 16),),
26+
"ones": lambda: (torch.ones(2, 8, 16),),
27+
"zeros": lambda: (torch.zeros(2, 8, 16),),
28+
}
29+
30+
def forward(self, x: torch.Tensor):
31+
output = torch.fft.rfft2(x)
32+
return output.real, output.imag
33+
34+
35+
@common.parametrize("test_data", RFFT2D.test_parameters)
36+
def test_rfft2d_tosa_FP(test_data: input_t1):
37+
pipeline = TosaPipelineFP[input_t1](
38+
RFFT2D(),
39+
test_data(),
40+
aten_op,
41+
exir_op,
42+
run_on_tosa_ref_model=False,
43+
tosa_version="1.1",
44+
tosa_extensions=["fft"],
45+
)
46+
pipeline.run()
47+
48+
49+
@common.parametrize("test_data", RFFT2D.test_parameters)
50+
@common.SkipIfNoModelConverter
51+
def test_rfft2d_vgf_no_quant(test_data: input_t1):
52+
pipeline = VgfPipeline[input_t1](
53+
RFFT2D(),
54+
test_data(),
55+
aten_op,
56+
exir_op,
57+
run_on_vulkan_runtime=False,
58+
quantize=False,
59+
tosa_version="TOSA-1.1+FP",
60+
tosa_extensions=["fft"],
61+
)
62+
pipeline.run()

backends/arm/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def define_arm_tests():
3737
"ops/test_cos.py",
3838
"ops/test_to_copy.py",
3939
"ops/test_exp.py",
40+
"ops/test_fft.py",
4041
"ops/test_reciprocal.py",
4142
"ops/test_mean_dim.py",
4243
"ops/test_var.py",

0 commit comments

Comments
 (0)