|
13 | 13 | from executorch.backends.qualcomm._passes.qnn_pass_manager import ( |
14 | 14 | get_qnn_pass_manager_cls, |
15 | 15 | ) |
| 16 | +from executorch.backends.qualcomm.builders.qnn_constants import ( |
| 17 | + is_context_loader_node, |
| 18 | + is_context_loader_target, |
| 19 | + OpContextLoader, |
| 20 | +) |
| 21 | +from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner |
| 22 | +from executorch.backends.qualcomm.qnn_preprocess import QnnBackend |
16 | 23 | from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype |
17 | 24 | from executorch.backends.qualcomm.serialization.qc_schema import ( |
18 | 25 | QcomChipset, |
|
28 | 35 | generate_qnn_executorch_compiler_spec, |
29 | 36 | to_edge_transform_and_lower_to_qnn, |
30 | 37 | ) |
31 | | -from executorch.exir import to_edge |
| 38 | +from executorch.exir import EdgeCompileConfig, to_edge |
32 | 39 | from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY |
33 | 40 | from executorch.exir.dialects._ops import ops as exir_ops |
| 41 | +from torch.library import Library |
34 | 42 | from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e |
35 | 43 |
|
36 | 44 |
|
37 | 45 | class TestPasses(unittest.TestCase): |
def test_context_loader_edge_op_is_delegated(self):
    """Check that QnnPartitioner tags a context-loader op in edge dialect.

    A throwaway custom op is registered in ``OpContextLoader.namespace``,
    a one-op model calling it is exported and lowered with ``to_edge``,
    the resulting node is stamped with a fake context binary, and the
    partitioner (configured with ``is_from_context_binary=True``) is
    expected to give that node a ``delegation_tag`` that is listed in
    ``partition_tags``.
    """
    op_name = "ctx_loader_delegation"
    graph_name = "forward"
    ctx_bin = b"qnn_context_binary"

    # Register the op in the loader namespace; the library fragment is torn
    # down again when the test finishes.
    loader_lib = Library(OpContextLoader.namespace, "FRAGMENT")
    self.addCleanup(loader_lib._destroy)
    loader_lib.define(f"{op_name}(Tensor[] inputs) -> Any")

    @torch.library.impl(
        loader_lib, op_name, dispatch_key="CompositeExplicitAutograd"
    )
    def op_impl(inputs):
        placeholder = torch.zeros((1, 2), device="meta", dtype=inputs[0].dtype)
        return (placeholder,)

    class Model(torch.nn.Module):
        def forward(self, x):
            op_packet = getattr(
                getattr(torch.ops, OpContextLoader.namespace), op_name
            )
            return op_packet.default((x,))

    exported_program = torch.export.export(
        Model(), (torch.ones(1, 2),), strict=True
    )
    # IR validity checking is switched off for this lowering.
    manager = to_edge(
        {graph_name: exported_program},
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    edge_program = manager._edge_programs[graph_name]

    # Exactly one node should be recognised as the loader op, both with and
    # without the explicit op name.
    matches = [
        candidate
        for candidate in edge_program.graph.nodes
        if is_context_loader_node(candidate, op_name)
    ]
    self.assertEqual(1, len(matches))
    loader = matches[0]
    self.assertTrue(is_context_loader_node(loader))

    # Stamp the fake context binary on the node and read it back.
    meta_key = OpContextLoader.meta_ctx_bin
    loader.meta[meta_key] = ctx_bin
    self.assertEqual(ctx_bin, loader.meta[meta_key])

    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=QcomChipset.SM8650,
        backend_options=generate_htp_compiler_spec(use_fp16=False),
        is_from_context_binary=True,
    )
    partition_result = QnnPartitioner(compiler_specs).partition(edge_program)

    # Locate the loader node again in the tagged program.
    tagged_loader = None
    for candidate in partition_result.tagged_exported_program.graph.nodes:
        if is_context_loader_node(candidate, op_name):
            tagged_loader = candidate
            break
    self.assertIsNotNone(tagged_loader)

    tag = tagged_loader.meta.get("delegation_tag")
    self.assertIsNotNone(tag)
    self.assertIn(tag, partition_result.partition_tags)
def test_is_context_loader_target_predicate(self):
    """Exercise ``is_context_loader_target`` on matching and non-matching ops.

    Covers a plain ``OpOverload`` in the context-loader namespace (with the
    right and the wrong op name) and ops from the ``aten`` namespace in
    both torch and edge-dialect form.
    """
    op_name = "ctx_loader_predicate"
    loader_lib = Library(OpContextLoader.namespace, "FRAGMENT")
    self.addCleanup(loader_lib._destroy)
    loader_lib.define(f"{op_name}(Tensor[] inputs) -> Any")

    # A plain OpOverload in the context-loader namespace must match (the
    # _op unwrap must not break the non-edge-dialect target case).
    loader_namespace = getattr(torch.ops, OpContextLoader.namespace)
    qaisw_op = getattr(loader_namespace, op_name).default
    self.assertTrue(is_context_loader_target(qaisw_op, op_name))
    self.assertFalse(is_context_loader_target(qaisw_op, "different_op"))

    # Ops in other namespaces must not match, including an edge op
    # (unwrapped via _op) whose namespace is not the loader's.
    torch_add = torch.ops.aten.add.default
    self.assertFalse(is_context_loader_target(torch_add))
    edge_add = exir_ops.edge.aten.add.Tensor
    self.assertFalse(is_context_loader_target(edge_add))
def test_build_op_wrappers_returns_context_binary(self):
    """Check that ``_build_op_wrappers`` hands back the stamped context binary.

    A throwaway custom op is registered in ``OpContextLoader.namespace``,
    a one-op model calling it is exported and lowered with ``to_edge``,
    and the loader node is stamped with a fake context binary. The test
    then expects ``QnnBackend._build_op_wrappers`` to return exactly those
    bytes.
    """
    op_name = "ctx_loader_build"
    graph_name = "forward"
    ctx_bin = b"qnn_context_binary"
    custom_op = Library(OpContextLoader.namespace, "FRAGMENT")
    # Tear the library fragment down again when the test finishes.
    self.addCleanup(custom_op._destroy)
    custom_op.define(f"{op_name}(Tensor[] inputs) -> Any")

    @torch.library.impl(
        custom_op, op_name, dispatch_key="CompositeExplicitAutograd"
    )
    def op_impl(inputs):
        return (torch.zeros((1, 2), device="meta", dtype=inputs[0].dtype),)

    class Model(torch.nn.Module):
        def forward(self, x):
            return getattr(
                getattr(torch.ops, OpContextLoader.namespace), op_name
            ).default((x,))

    exported_program = torch.export.export(
        Model(), (torch.ones(1, 2),), strict=True
    )
    edge_program = to_edge(
        {graph_name: exported_program},
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )._edge_programs[graph_name]

    # Fail here, with a clear message, if lowering did not leave exactly one
    # loader node; otherwise nothing would be stamped and the final
    # assertion would fail for a reason unrelated to _build_op_wrappers.
    context_loader_nodes = [
        node
        for node in edge_program.graph.nodes
        if is_context_loader_node(node, op_name)
    ]
    self.assertEqual(
        1,
        len(context_loader_nodes),
        "expected exactly one context-loader node in the edge program",
    )
    context_loader_nodes[0].meta[OpContextLoader.meta_ctx_bin] = ctx_bin

    # For a graph whose only op is the context-binary loader, the test
    # expects _build_op_wrappers to return the stamped context binary
    # directly.
    result = QnnBackend._build_op_wrappers(
        edge_program,
        enable_tensor_dump=False,
        op_package_infos=[],
        use_mha2sha=False,
        backend_type=QnnExecuTorchBackendType.kHtpBackend,
    )
    self.assertEqual(ctx_bin, result)
38 | 168 | def _build_quantized_graph(self): |
39 | 169 | """Build a quantized graph through AnnotateQuantAttrs + FoldQDQ.""" |
40 | 170 |
|
|
0 commit comments