# Copyright 2024 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.nxp.backend.ir.logger as logger
import flatbuffers
from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
from executorch.backends.nxp.backend.ir.conversion_context import ConversionContext
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
    AtenModelBuilderDirector,
)
from torch.export import ExportedProgram
from torch.export.graph_signature import InputKind
from torch.fx import Node
from torch.nn.parameter import Parameter
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import *
from executorch.backends.nxp.backend.node_format_inference import (
    NodeFormat,
    NodeFormatInference,
)
from executorch.exir.dialects._ops import ops as exir_ops

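# Mapping from Edge dialect operators to the converter classes (imported above from
# `ops_converters`) that lower them to their TFLite counterparts.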
# noinspection PyProtectedMember
functions_converters = {
    exir_ops.edge.aten.addmm.default: AddMMConverter,
    exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter,
    exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter,
    exir_ops.edge.aten.convolution.default: ConvolutionConverter,
    exir_ops.edge.aten.max_pool2d.default: MaxPool2dConverter,
    exir_ops.edge.aten.mm.default: MMConverter,
    exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter,
    exir_ops.edge.aten.relu.default: ReLUConverter,
    exir_ops.edge.aten._softmax.default: SoftmaxConverter,
    exir_ops.edge.aten.view_copy.default: ViewCopyConverter,
}


class EdgeProgramToIRConverter:
    """
    Converter of an ExportedProgram in Edge dialect to IR (TFLite FlatBuffers).
    """

    def convert_program(
        self,
        edge_program: ExportedProgram,
        conversion_config: ConversionConfig = ConversionConfig(),
    ) -> tuple[bytes, dict]:
        """
        Convert an ExportedProgram in Edge dialect to IR (TFLite FlatBuffers), serialized as bytes.

        :param edge_program: ExportedProgram to convert.
        :param conversion_config: ConversionConfig instance.
        :return: TFLite FlatBuffers as bytes, together with the I/O formats of the model.
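
        Illustrative usage (assuming `exported_program` is an ExportedProgram that has
        already been lowered to the Edge dialect):

            tflite_model, io_formats = EdgeProgramToIRConverter().convert_program(
                exported_program
            )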
        """
        node_formats = NodeFormatInference(edge_program).identify_node_formats()
        parameters_mapping = self.map_inputs_to_parameters(edge_program)

        cc = self.build_conversion_context(
            parameters_mapping, node_formats, conversion_config
        )

        # Program conversion
        self.append_placeholders_and_tensors(edge_program.graph.nodes, cc)
        self._convert_qdq_cluster_q_dq_nodes(edge_program.graph.nodes, cc)
        self._process_nodes(edge_program.graph.nodes, cc)

        # Assign model inputs/outputs
        io_formats = cc.tflite_builder.assign_model_io_to_subgraph_and_get_io_formats(
            edge_program.graph_signature
        )

        # TFLite model generation
        internal_tflite_model = cc.tflite_builder.finish()
        flatbuffers_builder = flatbuffers.Builder()
        internal_tflite_model.gen_tflite(flatbuffers_builder)

        return bytes(flatbuffers_builder.Output()), io_formats

    @staticmethod
    def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
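        """
        Append a TFLite tensor for every placeholder and call_function node of the program.

        :param nodes: Program's nodes.
        :param context: ConversionContext instance.
        """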
        for node in nodes:
            if node.op == "placeholder":
                node_format = context.node_formats[node]

                if node.name in context.parameters_mapping:
                    # Node is a placeholder with data -> append as a static tensor with data
                    tensor = context.parameters_mapping[node.name]
                    context.tflite_builder.append_as_static_tensor(
                        node, node_format, tensor
                    )
                else:
                    # Node is a placeholder without data (user input) -> append as a fake tensor
                    context.tflite_builder.append_as_fake_tensor(node, node_format)
            elif node.op == "call_function":
                # Node is a call_function -> append only its output as a tensor
                node_format = context.node_formats[node]
                context.tflite_builder.append_as_fake_tensor(node, node_format)
            elif node.op == "output":
                # Nothing to do
                pass
            else:
                logger.e(
                    logger.Code.INTERNAL_ERROR, f"Unexpected node op type: '{node.op}'!"
                )

    def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContext):
        """
        Go through the program nodes and append their TFLite counterparts to the ModelBuilder.

        :param nodes: Program's nodes.
        :param conversion_context: ConversionContext instance.
        """

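        # (De)Quantize nodes that belong to a QDQ cluster were already handled by
        # `_convert_qdq_cluster_q_dq_nodes()`, so they are skipped below.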
        qdq_related_functions = [
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        ]

        for node in nodes:
            if node.op == "call_function":
                if node.target in qdq_related_functions and "cluster" in node.meta:
                    # Skip (De)Quantize nodes that were already processed
                    pass
                elif node.target in functions_converters:
                    functions_converters[node.target](conversion_context).convert(node)
                else:
                    logger.e(
                        logger.Code.NOT_IMPLEMENTED,
                        f"Converter for '{node.target.__name__}' not implemented!",
                    )

    @staticmethod
    def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Parameter]:
        """
        Create a mapping between the program's parameters (input nodes & static data nodes) and their names.

        :param edge_program: ExportedProgram instance.
        :return: Mapping from parameter name to parameter instance.
        """
        result_map = {}

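        # Only parameters and buffers carry static data in the state dict; other
        # input kinds (e.g. user inputs) are left out of the mapping.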
        for input_spec in edge_program.graph_signature.input_specs:
            if input_spec.kind in [InputKind.PARAMETER, InputKind.BUFFER]:
                result_map[input_spec.arg.name] = edge_program.state_dict[
                    input_spec.target
                ]

        return result_map

    @staticmethod
    def build_conversion_context(
        parameters_mapping: dict,
        node_formats: dict[Node, NodeFormat],
        conversion_config: ConversionConfig = ConversionConfig(),
    ) -> ConversionContext:
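        """
        Create a ConversionContext with a freshly initialized TFLite model builder.

        :param parameters_mapping: Mapping from parameter name to parameter data.
        :param node_formats: Mapping from program nodes to their inferred NodeFormat.
        :param conversion_config: ConversionConfig instance.
        :return: ConversionContext instance.
        """
        # The builder is seeded with model version 3 (presumably the TFLite schema
        # version) and a human-readable model description.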
        tflite_builder = AtenModelBuilderDirector(
            3, "TFLite from EdgeProgram", conversion_config
        )

        # Add "sentinel" buffer (defined in schema.fbs)
        tflite_builder.build_empty_buffer()

        context = ConversionContext(
            tflite_builder, conversion_config, parameters_mapping, node_formats
        )

        return context

    def _convert_qdq_cluster_q_dq_nodes(
        self, nodes: list[Node], conversion_context: ConversionContext
    ):
        """
        Go through the program and convert the (De)Quantize nodes that are part of a QDQ
        cluster into tensors.

        :param nodes: Program's nodes.
        :param conversion_context: ConversionContext instance.
        """
        qdq_q_ops_converters = {
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQDequantizeConverter,
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: QDQQuantizeConverter,
        }

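        # Only (De)Quantize nodes flagged with a "cluster" entry in their metadata are
        # converted here.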
        for node in nodes:
            part_of_qdq_cluster = "cluster" in node.meta
            if (
                node.op == "call_function"
                and node.target in qdq_q_ops_converters
                and part_of_qdq_cluster
            ):
                qdq_q_ops_converters[node.target](conversion_context).convert(node)