|
| 1 | +# Copyright 2024 NXP |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +import flatbuffers |
| 7 | +from torch.export import ExportedProgram |
| 8 | +from torch.export.graph_signature import InputKind |
| 9 | +from torch.fx import Node |
| 10 | +from torch.nn.parameter import Parameter |
| 11 | + |
| 12 | +import executorch.backends.nxp.backend.ir.logger as logger |
| 13 | +from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig |
| 14 | +from executorch.backends.nxp.backend.ir.conversion_context import ConversionContext |
| 15 | +from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import AtenModelBuilderDirector |
| 16 | +from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * |
| 17 | +from executorch.backends.nxp.backend.node_format_inference import NodeFormatInference, NodeFormat |
| 18 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 19 | + |
| 20 | +# noinspection PyProtectedMember |
| 21 | +functions_converters = { |
| 22 | + exir_ops.edge.aten.addmm.default: AddMMConverter, |
| 23 | + exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, |
| 24 | + exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter, |
| 25 | + exir_ops.edge.aten.convolution.default: ConvolutionConverter, |
| 26 | + exir_ops.edge.aten.max_pool2d.default: MaxPool2dConverter, |
| 27 | + exir_ops.edge.aten.mm.default: MMConverter, |
| 28 | + exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, |
| 29 | + exir_ops.edge.aten.relu.default: ReLUConverter, |
| 30 | + exir_ops.edge.aten._softmax.default: SoftmaxConverter, |
| 31 | + exir_ops.edge.aten.view_copy.default: ViewCopyConverter, |
| 32 | +} |
| 33 | + |
| 34 | + |
| 35 | +class EdgeProgramToIRConverter: |
| 36 | + """ |
| 37 | + Converter from convertion of ExportedProgram in Edge dialect to IR (TFLite Flatbuffers). |
| 38 | + """ |
| 39 | + |
| 40 | + def convert_program(self, edge_program: ExportedProgram, conversion_config=ConversionConfig()) -> (bytes, dict): |
| 41 | + """ |
| 42 | + Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes. |
| 43 | +
|
| 44 | + :param edge_program: Converter ExportedProgram. |
| 45 | + :param conversion_config: ConversionConfig instance. |
| 46 | + :return: TFLite flatbuffers as bytes. |
| 47 | + """ |
| 48 | + node_formats = NodeFormatInference(edge_program).identify_node_formats() |
| 49 | + parameters_mapping = self.map_inputs_to_parameters(edge_program) |
| 50 | + |
| 51 | + cc = self.build_conversion_context(parameters_mapping, node_formats, conversion_config) |
| 52 | + |
| 53 | + # Program conversion |
| 54 | + self.append_placeholders_and_tensors(edge_program.graph.nodes, cc) |
| 55 | + self._convert_qdq_cluster_q_dq_nodes(edge_program.graph.nodes, cc) |
| 56 | + self._process_nodes(edge_program.graph.nodes, cc) |
| 57 | + |
| 58 | + # Assign output |
| 59 | + io_formats = cc.tflite_builder.assign_model_io_to_subgraph_and_get_io_formats(edge_program.graph_signature) |
| 60 | + |
| 61 | + # TFLite model generation |
| 62 | + internal_tflite_model = cc.tflite_builder.finish() |
| 63 | + flatbuffers_builder = flatbuffers.Builder() |
| 64 | + internal_tflite_model.gen_tflite(flatbuffers_builder) |
| 65 | + |
| 66 | + return bytes(flatbuffers_builder.Output()), io_formats |
| 67 | + |
| 68 | + @staticmethod |
| 69 | + def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext): |
| 70 | + for node in nodes: |
| 71 | + if node.op == "placeholder": |
| 72 | + node_format = context.node_formats[node] |
| 73 | + |
| 74 | + if node.name in context.parameters_mapping: |
| 75 | + # Node is placeholder and has data -> append as static tensor with data |
| 76 | + tensor = context.parameters_mapping[node.name] |
| 77 | + context.tflite_builder.append_as_static_tensor(node, node_format, tensor) |
| 78 | + else: |
| 79 | + # Node is placeholder and doesn't have data (user input) -> append as fake tensor |
| 80 | + context.tflite_builder.append_as_fake_tensor(node, node_format) |
| 81 | + elif node.op == "call_function": |
| 82 | + # Node is call function -> append only output as a tensor |
| 83 | + node_format = context.node_formats[node] |
| 84 | + context.tflite_builder.append_as_fake_tensor(node, node_format) |
| 85 | + elif node.op == "output": |
| 86 | + # Nothing to do |
| 87 | + pass |
| 88 | + else: |
| 89 | + logger.e(logger.Code.INTERNAL_ERROR, f"Unexpected node op type: '{node.op}'!") |
| 90 | + |
| 91 | + def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContext): |
| 92 | + """ |
| 93 | + Go through program nodes and append their TFLite siblings into ModelBuilder. |
| 94 | +
|
| 95 | + :param nodes: Program's nodes. |
| 96 | + :param conversion_context: ConversionContext instance. |
| 97 | + """ |
| 98 | + |
| 99 | + qdq_related_functions = [ |
| 100 | + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
| 101 | + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default |
| 102 | + ] |
| 103 | + |
| 104 | + for node in nodes: |
| 105 | + if node.op == "call_function": |
| 106 | + if node.target in qdq_related_functions and "cluster" in node.meta: |
| 107 | + # Skip (De)Quantize nodes that were already processed |
| 108 | + pass |
| 109 | + elif node.target in functions_converters: |
| 110 | + functions_converters[node.target](conversion_context).convert(node) |
| 111 | + else: |
| 112 | + logger.e(logger.Code.NOT_IMPLEMENTED, f"Converter for '{node.target.__name__}' not implemented!") |
| 113 | + |
| 114 | + @staticmethod |
| 115 | + def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Parameter]: |
| 116 | + """ |
| 117 | + Create mapping between program parameters (input nodes & static data nodes) and their names. |
| 118 | +
|
| 119 | + :param edge_program: EdgeProgram instance. |
| 120 | + :return: Mapping from parameter name to parameter instance. |
| 121 | + """ |
| 122 | + result_map = {} |
| 123 | + |
| 124 | + for input_spec in edge_program.graph_signature.input_specs: |
| 125 | + if input_spec.kind in [InputKind.PARAMETER, InputKind.BUFFER]: |
| 126 | + result_map[input_spec.arg.name] = edge_program.state_dict[input_spec.target] |
| 127 | + |
| 128 | + return result_map |
| 129 | + |
| 130 | + @staticmethod |
| 131 | + def build_conversion_context( |
| 132 | + parameters_mapping: dict, |
| 133 | + node_formats: dict[Node, NodeFormat], |
| 134 | + conversion_config: ConversionConfig = ConversionConfig(), |
| 135 | + ) -> ConversionContext: |
| 136 | + tflite_builder = AtenModelBuilderDirector(3, "TFLite from EdgeProgram", conversion_config) |
| 137 | + |
| 138 | + # Add "sentinel" buffer (defined in schema.fbs) |
| 139 | + tflite_builder.build_empty_buffer() |
| 140 | + |
| 141 | + context = ConversionContext(tflite_builder, conversion_config, parameters_mapping, node_formats) |
| 142 | + |
| 143 | + return context |
| 144 | + |
| 145 | + def _convert_qdq_cluster_q_dq_nodes(self, nodes: list[Node], conversion_context: ConversionContext): |
| 146 | + """ |
| 147 | + Go through program and convert De(Quantize) nodes that are part of the QDQ cluster into |
| 148 | + tensors. |
| 149 | +
|
| 150 | + :param nodes: Program's nodes. |
| 151 | + :param conversion_context: ConversionContext instance. |
| 152 | + """ |
| 153 | + qdq_q_ops_converters = { |
| 154 | + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQDequantizeConverter, |
| 155 | + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: QDQQuantizeConverter, |
| 156 | + } |
| 157 | + |
| 158 | + for node in nodes: |
| 159 | + part_of_qdq_cluster = "cluster" in node.meta |
| 160 | + if node.op == "call_function" and node.target in qdq_q_ops_converters and part_of_qdq_cluster: |
| 161 | + qdq_q_ops_converters[node.target](conversion_context).convert(node) |
0 commit comments