Skip to content

NXP Backend: Add eIQ Neutron Backend #10196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
1 change: 1 addition & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Copyright 2023 Arm Limited and/or its affiliates.
Copyright (c) Qualcomm Innovation Center, Inc.
Copyright (c) 2023 Apple Inc.
Copyright (c) 2024 MediaTek Inc.
Copyright 2023 NXP

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.pass_manager import PassManager
from torch.fx import GraphModule
Expand Down
40 changes: 40 additions & 0 deletions backends/nxp/backend/edge_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2024 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.fx import Node


def input_tensor(node: Node, input_index: int) -> torch.Tensor:
if len(node.all_input_nodes) <= input_index:
raise IndexError

return node.all_input_nodes[input_index].meta["val"]


def output_tensor(node: Node) -> torch.Tensor:
return node.meta["val"]


def tensor_rank(tensor: torch.Tensor) -> int:
return len(tensor.size())


def input_rank(node: Node, input_index: int) -> int:
return tensor_rank(input_tensor(node, input_index))


def input_tensor_safe(node: Node, input_index: int) -> torch.Tensor | None:
"""Return the input tensor of 'node' at index 'input_index', or None if the node doesn't have that input.

:param node: Edge node to get the input tensor from.
:param input_index: Index of the input tensor to get.
:return: The input tensor at index 'input_index', or None.
"""

if len(node.all_input_nodes) <= input_index:
return None

return input_tensor(node, input_index)
194 changes: 194 additions & 0 deletions backends/nxp/backend/edge_program_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright 2024 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.nxp.backend.ir.logger as logger
import flatbuffers
from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
from executorch.backends.nxp.backend.ir.conversion_context import ConversionContext
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
AtenModelBuilderDirector,
)
from torch.export import ExportedProgram
from torch.export.graph_signature import InputKind
from torch.fx import Node
from torch.nn.parameter import Parameter
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters import * # noqa F403
from executorch.backends.nxp.backend.node_format_inference import (
NodeFormat,
NodeFormatInference,
)
from executorch.exir.dialects._ops import ops as exir_ops

# noinspection PyProtectedMember
functions_converters = {
exir_ops.edge.aten.addmm.default: AddMMConverter, # noqa F405
exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405
exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter, # noqa F405
exir_ops.edge.aten.convolution.default: ConvolutionConverter, # noqa F405
exir_ops.edge.aten.max_pool2d.default: MaxPool2dConverter, # noqa F405
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
}


class EdgeProgramToIRConverter:
"""
Converter from convertion of ExportedProgram in Edge dialect to IR (TFLite Flatbuffers).
"""

_default_conversion_config = ConversionConfig()

def convert_program(
self,
edge_program: ExportedProgram,
conversion_config=_default_conversion_config,
) -> (bytes, dict):
"""
Convert ExportedProgram in Edge dialect to IR (TFLite flatbuffers) as bytes.

:param edge_program: Converter ExportedProgram.
:param conversion_config: ConversionConfig instance.
:return: TFLite flatbuffers as bytes.
"""
node_formats = NodeFormatInference(edge_program).identify_node_formats()
parameters_mapping = self.map_inputs_to_parameters(edge_program)

cc = self.build_conversion_context(
parameters_mapping, node_formats, conversion_config
)

# Program conversion
self.append_placeholders_and_tensors(edge_program.graph.nodes, cc)
self._convert_qdq_cluster_q_dq_nodes(edge_program.graph.nodes, cc)
self._process_nodes(edge_program.graph.nodes, cc)

# Assign output
io_formats = cc.tflite_builder.assign_model_io_to_subgraph_and_get_io_formats(
edge_program.graph_signature
)

# TFLite model generation
internal_tflite_model = cc.tflite_builder.finish()
flatbuffers_builder = flatbuffers.Builder()
internal_tflite_model.gen_tflite(flatbuffers_builder)

return bytes(flatbuffers_builder.Output()), io_formats

@staticmethod
def append_placeholders_and_tensors(nodes: list[Node], context: ConversionContext):
for node in nodes:
if node.op == "placeholder":
node_format = context.node_formats[node]

if node.name in context.parameters_mapping:
# Node is placeholder and has data -> append as static tensor with data
tensor = context.parameters_mapping[node.name]
context.tflite_builder.append_as_static_tensor(
node, node_format, tensor
)
else:
# Node is placeholder and doesn't have data (user input) -> append as fake tensor
context.tflite_builder.append_as_fake_tensor(node, node_format)
elif node.op == "call_function":
# Node is call function -> append only output as a tensor
node_format = context.node_formats[node]
context.tflite_builder.append_as_fake_tensor(node, node_format)
elif node.op == "output":
# Nothing to do
pass
else:
logger.e(
logger.Code.INTERNAL_ERROR, f"Unexpected node op type: '{node.op}'!"
)

def _process_nodes(self, nodes: list[Node], conversion_context: ConversionContext):
"""
Go through program nodes and append their TFLite siblings into ModelBuilder.

:param nodes: Program's nodes.
:param conversion_context: ConversionContext instance.
"""

qdq_related_functions = [
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
]

for node in nodes:
if node.op == "call_function":
if node.target in qdq_related_functions and "cluster" in node.meta:
# Skip (De)Quantize nodes that were already processed
pass
elif node.target in functions_converters:
functions_converters[node.target](conversion_context).convert(node)
else:
logger.e(
logger.Code.NOT_IMPLEMENTED,
f"Converter for '{node.target.__name__}' not implemented!",
)

@staticmethod
def map_inputs_to_parameters(edge_program: ExportedProgram) -> dict[str, Parameter]:
"""
Create mapping between program parameters (input nodes & static data nodes) and their names.

:param edge_program: EdgeProgram instance.
:return: Mapping from parameter name to parameter instance.
"""
result_map = {}

for input_spec in edge_program.graph_signature.input_specs:
if input_spec.kind in [InputKind.PARAMETER, InputKind.BUFFER]:
result_map[input_spec.arg.name] = edge_program.state_dict[
input_spec.target
]

return result_map

@staticmethod
def build_conversion_context(
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
conversion_config: ConversionConfig = _default_conversion_config,
) -> ConversionContext:
tflite_builder = AtenModelBuilderDirector(
3, "TFLite from EdgeProgram", conversion_config
)

# Add "sentinel" buffer (defined in schema.fbs)
tflite_builder.build_empty_buffer()

context = ConversionContext(
tflite_builder, conversion_config, parameters_mapping, node_formats
)

return context

def _convert_qdq_cluster_q_dq_nodes(
self, nodes: list[Node], conversion_context: ConversionContext
):
"""
Go through program and convert De(Quantize) nodes that are part of the QDQ cluster into
tensors.

:param nodes: Program's nodes.
:param conversion_context: ConversionContext instance.
"""
qdq_q_ops_converters = {
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: QDQDequantizeConverter, # noqa F405
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: QDQQuantizeConverter, # noqa F405
}

for node in nodes:
part_of_qdq_cluster = "cluster" in node.meta
if (
node.op == "call_function"
and node.target in qdq_q_ops_converters
and part_of_qdq_cluster
):
qdq_q_ops_converters[node.target](conversion_context).convert(node)
64 changes: 64 additions & 0 deletions backends/nxp/backend/ir/conversion_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2024 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


class ConversionConfig:

def __init__(self, args: dict | None = None):
"""
Conversion configuration passed through command line arguments or gathered during
the conversion process.

:param args: Optional dictionary with conversion arguments. Unknown arguments are ignored.
"""
self.keep_io_format: bool = False
self.skip_shape_inference: bool = False
self.allow_inputs_stripping: bool = True
self.qdq_aware_conversion: bool = True
self.symbolic_dimensions_mapping: dict[str, int] | None = None
self.input_shapes_mapping: dict[str, tuple] | None = None
self.dont_skip_nodes_with_known_outputs: bool = False
self.allow_select_ops: bool = True
self.generate_artifacts_after_failed_shape_inference: bool = True

self.optimization_whitelist: list | None = None
self.optimization_blacklist: list | None = None

self.non_negative_indices: bool = False
self.cast_int64_to_int32: bool = False
self.accept_resize_rounding_error: bool = False
self.ignore_opset_version: bool = False

self.tflite_quantization_integrity_check: bool = True

if args is not None:
for key, value in args.items():
if key in self.__dict__:
setattr(self, key, value)

def __repr__(self):
attrs = []
for attr in self.__dict__:
attrs.append(f"{attr}={getattr(self, attr)}")

return "ConversionConfig[" + ", ".join(attrs) + "]"


class SkipShapeInferenceConfig(ConversionConfig):

def __init__(self):
"""
Conversion config shortcut with disabled shape inference.
"""
super().__init__({"skip_shape_inference": True})


class QDQAwareConfig(ConversionConfig):

def __init__(self):
"""
Conversion config shortcut with QDQ aware conversion enabled.
"""
super().__init__({"qdq_aware_conversion": True})
37 changes: 37 additions & 0 deletions backends/nxp/backend/ir/conversion_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2024 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.backends.nxp.backend.ir.conversion_config import ConversionConfig
from executorch.backends.nxp.backend.ir.converter.builder.aten_model_builder_director import (
AtenModelBuilderDirector,
)
from executorch.backends.nxp.backend.node_format_inference import NodeFormat
from torch import Node
from torch.nn import Parameter


class ConversionContext:
tflite_builder: AtenModelBuilderDirector
conversion_config: ConversionConfig
parameters_mapping: dict[str, Parameter]
node_formats: dict[Node, NodeFormat]

def __init__(
self,
tflite_builder: AtenModelBuilderDirector,
conversion_config: ConversionConfig,
parameters_mapping: dict,
node_formats: dict[Node, NodeFormat],
):
"""
Context with data related to current conversion.

:param tflite_builder: TFLite model builder.
:param conversion_config: Conversion configuration flags and metadata.
"""
self.tflite_builder = tflite_builder
self.conversion_config = conversion_config
self.parameters_mapping = parameters_mapping
self.node_formats = node_formats
Empty file.
Empty file.
Loading
Loading