Skip to content

Commit ce774b0

Browse files
JakeStevensfacebook-github-bot
authored andcommitted
Allow context-binary lowering to use edge dialect ops
Summary: Unblocks the QNN context-binary path from lowering through `to_edge` with `_use_edge_ops=True` (the default). Previously it was pinned to `EdgeCompileConfig(_use_edge_ops=False)` purely to keep the `qaisw` context-loader custom op's original name, because loader detection was name-based. Loader detection now goes through a single `is_context_loader_target()` helper that matches the op namespace (so it works on the edge-dialect wrapper), replacing the three name-dependent checks (`eval`, raw `.namespace`, substring). Differential Revision: D109598309
1 parent 5a920c3 commit ce774b0

5 files changed

Lines changed: 178 additions & 18 deletions

File tree

backends/qualcomm/builders/qnn_constants.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
from dataclasses import dataclass
88
from enum import IntEnum, unique
99

10+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
11+
from executorch.exir.operator.convert import parse_qualified_opname, unwrap_op_overload
12+
from torch._ops import OpOverload
13+
1014
QNN_OP_PACKAGE_NAME_QTI_AISW = "qti.aisw"
1115

1216
# Below constants should be same as those in QNN headers.
@@ -57,6 +61,32 @@ class OpContextLoader:
5761
meta_ctx_bin: str = "qnn_context_binary"
5862

5963

64+
ContextLoaderTarget = EdgeOpOverload | OpOverload
65+
66+
67+
def is_context_loader_target(
68+
target: ContextLoaderTarget,
69+
op_name: str | None = None,
70+
) -> bool:
71+
namespace, name = parse_qualified_opname(
72+
str(unwrap_op_overload(target)._schema.name)
73+
)
74+
if namespace != OpContextLoader.namespace:
75+
return False
76+
if op_name is None:
77+
return True
78+
return name == op_name
79+
80+
81+
def is_context_loader_node(node: object, op_name: str | None = None) -> bool:
82+
if getattr(node, "op", None) != "call_function":
83+
return False
84+
target = getattr(node, "target", None)
85+
if not isinstance(target, (EdgeOpOverload, OpOverload)):
86+
return False
87+
return is_context_loader_target(target, op_name)
88+
89+
6090
@dataclass(init=False, frozen=True)
6191
class OpConv2d:
6292
op_name: str = "Conv2d"

backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import torch
1212
from executorch.backends.qualcomm.builders import node_visitor_manager
13-
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
13+
from executorch.backends.qualcomm.builders.qnn_constants import is_context_loader_node
1414
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
1515
from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
1616
flatbuffer_to_option,
@@ -95,7 +95,7 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
9595
if (
9696
node.target in allow_list_operator
9797
# bypass if custom op appears
98-
or OpContextLoader.namespace == node.target.namespace
98+
or is_context_loader_node(node)
9999
# bypass dequantize op for parameters & buffers
100100
or node.meta.get(QCOM_BYPASS_NODE, False)
101101
):

backends/qualcomm/qnn_preprocess.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
get_qnn_pass_manager_cls,
1414
)
1515
from executorch.backends.qualcomm.builders.node_visitor_manager import get_node_visitors
16-
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
16+
from executorch.backends.qualcomm.builders.qnn_constants import (
17+
is_context_loader_node,
18+
OpContextLoader,
19+
)
1720
from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option
1821
from executorch.backends.qualcomm.serialization.qc_schema import (
1922
QnnExecuTorchBackendType,
@@ -89,16 +92,12 @@ def _build_op_wrappers(
8992
f"For {node}, {node.op}:{node.target.__name__} "
9093
"is not supported in Qnn Delegate"
9194
)
92-
try:
93-
context_loader_target = eval(
94-
f"torch.ops.{OpContextLoader.namespace}.{node.target.__name__}",
95-
globals().update(torch.__dict__),
96-
)
97-
assert node.target == context_loader_target, err_msg
98-
# if graph has context binary loader node, return directly
95+
if (
96+
is_context_loader_node(node)
97+
and OpContextLoader.meta_ctx_bin in node.meta
98+
):
9999
return node.meta[OpContextLoader.meta_ctx_bin]
100-
except:
101-
raise RuntimeError(err_msg)
100+
raise RuntimeError(err_msg)
102101

103102
elif node.op in [
104103
"get_attr",

backends/qualcomm/tests/test_passes.py

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@
1313
from executorch.backends.qualcomm._passes.qnn_pass_manager import (
1414
get_qnn_pass_manager_cls,
1515
)
16+
from executorch.backends.qualcomm.builders.qnn_constants import (
17+
is_context_loader_node,
18+
is_context_loader_target,
19+
OpContextLoader,
20+
)
21+
from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner
22+
from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
1623
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
1724
from executorch.backends.qualcomm.serialization.qc_schema import (
1825
QcomChipset,
@@ -28,13 +35,136 @@
2835
generate_qnn_executorch_compiler_spec,
2936
to_edge_transform_and_lower_to_qnn,
3037
)
31-
from executorch.exir import to_edge
38+
from executorch.exir import EdgeCompileConfig, to_edge
3239
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
3340
from executorch.exir.dialects._ops import ops as exir_ops
41+
from torch.library import Library
3442
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
3543

3644

3745
class TestPasses(unittest.TestCase):
46+
def test_context_loader_edge_op_is_delegated(self):
47+
op_name = "ctx_loader_delegation"
48+
graph_name = "forward"
49+
ctx_bin = b"qnn_context_binary"
50+
custom_op = Library(OpContextLoader.namespace, "FRAGMENT")
51+
self.addCleanup(custom_op._destroy)
52+
custom_op.define(f"{op_name}(Tensor[] inputs) -> Any")
53+
54+
@torch.library.impl(
55+
custom_op, op_name, dispatch_key="CompositeExplicitAutograd"
56+
)
57+
def op_impl(inputs):
58+
return (torch.zeros((1, 2), device="meta", dtype=inputs[0].dtype),)
59+
60+
class Model(torch.nn.Module):
61+
def forward(self, x):
62+
return getattr(
63+
getattr(torch.ops, OpContextLoader.namespace), op_name
64+
).default((x,))
65+
66+
exported_program = torch.export.export(
67+
Model(), (torch.ones(1, 2),), strict=True
68+
)
69+
edge_program_manager = to_edge(
70+
{graph_name: exported_program},
71+
compile_config=EdgeCompileConfig(_check_ir_validity=False),
72+
)
73+
74+
context_loader_nodes = [
75+
node
76+
for node in edge_program_manager._edge_programs[graph_name].graph.nodes
77+
if is_context_loader_node(node, op_name)
78+
]
79+
self.assertEqual(1, len(context_loader_nodes))
80+
self.assertTrue(is_context_loader_node(context_loader_nodes[0]))
81+
context_loader_nodes[0].meta[OpContextLoader.meta_ctx_bin] = ctx_bin
82+
self.assertEqual(
83+
ctx_bin,
84+
context_loader_nodes[0].meta[OpContextLoader.meta_ctx_bin],
85+
)
86+
87+
compiler_specs = generate_qnn_executorch_compiler_spec(
88+
soc_model=QcomChipset.SM8650,
89+
backend_options=generate_htp_compiler_spec(use_fp16=False),
90+
is_from_context_binary=True,
91+
)
92+
edge_program = edge_program_manager._edge_programs[graph_name]
93+
partition_result = QnnPartitioner(compiler_specs).partition(edge_program)
94+
context_loader_node = next(
95+
(
96+
node
97+
for node in partition_result.tagged_exported_program.graph.nodes
98+
if is_context_loader_node(node, op_name)
99+
),
100+
None,
101+
)
102+
self.assertIsNotNone(context_loader_node)
103+
delegation_tag = context_loader_node.meta.get("delegation_tag")
104+
self.assertIsNotNone(delegation_tag)
105+
self.assertIn(delegation_tag, partition_result.partition_tags)
106+
107+
def test_is_context_loader_target_predicate(self):
108+
op_name = "ctx_loader_predicate"
109+
custom_op = Library(OpContextLoader.namespace, "FRAGMENT")
110+
self.addCleanup(custom_op._destroy)
111+
custom_op.define(f"{op_name}(Tensor[] inputs) -> Any")
112+
113+
# Plain OpOverload in the context-loader namespace must match (the
114+
# _op unwrap must not break the non-edge-dialect target case).
115+
qaisw_op = getattr(
116+
getattr(torch.ops, OpContextLoader.namespace), op_name
117+
).default
118+
self.assertTrue(is_context_loader_target(qaisw_op, op_name))
119+
self.assertFalse(is_context_loader_target(qaisw_op, "different_op"))
120+
121+
# Ops in other namespaces must not match, including an edge op
122+
# (unwrapped via _op) whose namespace is not the loader's.
123+
self.assertFalse(is_context_loader_target(torch.ops.aten.add.default))
124+
self.assertFalse(is_context_loader_target(exir_ops.edge.aten.add.Tensor))
125+
126+
def test_build_op_wrappers_returns_context_binary(self):
127+
op_name = "ctx_loader_build"
128+
graph_name = "forward"
129+
ctx_bin = b"qnn_context_binary"
130+
custom_op = Library(OpContextLoader.namespace, "FRAGMENT")
131+
self.addCleanup(custom_op._destroy)
132+
custom_op.define(f"{op_name}(Tensor[] inputs) -> Any")
133+
134+
@torch.library.impl(
135+
custom_op, op_name, dispatch_key="CompositeExplicitAutograd"
136+
)
137+
def op_impl(inputs):
138+
return (torch.zeros((1, 2), device="meta", dtype=inputs[0].dtype),)
139+
140+
class Model(torch.nn.Module):
141+
def forward(self, x):
142+
return getattr(
143+
getattr(torch.ops, OpContextLoader.namespace), op_name
144+
).default((x,))
145+
146+
exported_program = torch.export.export(
147+
Model(), (torch.ones(1, 2),), strict=True
148+
)
149+
edge_program = to_edge(
150+
{graph_name: exported_program},
151+
compile_config=EdgeCompileConfig(_check_ir_validity=False),
152+
)._edge_programs[graph_name]
153+
for node in edge_program.graph.nodes:
154+
if is_context_loader_node(node, op_name):
155+
node.meta[OpContextLoader.meta_ctx_bin] = ctx_bin
156+
157+
# For a graph whose only op is the context-binary loader, _build_op_wrappers
158+
# returns the stamped context binary directly, before any QNN compilation.
159+
result = QnnBackend._build_op_wrappers(
160+
edge_program,
161+
enable_tensor_dump=False,
162+
op_package_infos=[],
163+
use_mha2sha=False,
164+
backend_type=QnnExecuTorchBackendType.kHtpBackend,
165+
)
166+
self.assertEqual(ctx_bin, result)
167+
38168
def _build_quantized_graph(self):
39169
"""Build a quantized graph through AnnotateQuantAttrs + FoldQDQ."""
40170

backends/qualcomm/utils/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
QNN_QUANT_TYPE_MAP,
2727
QNN_TENSOR_TYPE_MAP,
2828
)
29-
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
29+
from executorch.backends.qualcomm.builders.qnn_constants import (
30+
is_context_loader_node,
31+
OpContextLoader,
32+
)
3033
from executorch.backends.qualcomm.partition.qnn_partitioner import (
3134
generate_qnn_executorch_option,
3235
get_skip_decomp_table,
@@ -959,15 +962,13 @@ def preprocess_binary(ctx_bin, compiler_specs):
959962
# temporarily remove the first parameter name.
960963
edge_prog_mgr = to_edge(
961964
{graph_name: bundle_prog["exported_program"]},
962-
# do not alter name for custom op
963-
compile_config=EdgeCompileConfig(_use_edge_ops=False),
965+
compile_config=EdgeCompileConfig(_check_ir_validity=False),
964966
)
965967

966968
# update meta with context binary
967969
for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
968-
if n.op == "call_function" and OpContextLoader.namespace in str(n.target):
970+
if is_context_loader_node(n, op_name):
969971
n.meta[OpContextLoader.meta_ctx_bin] = ctx_bin
970-
break
971972

972973
bundle_prog["edge_program_manager"] = edge_prog_mgr.to_backend(
973974
QnnPartitioner(compiler_specs)

0 commit comments

Comments
 (0)