From 30ee8bd3a16cc1f98f4d2d023be940d07a6ab685 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Tue, 15 Apr 2025 08:21:00 -0700
Subject: [PATCH] init

Summary: Placeholder q/dq AoT ops in a new namespace, with a test.

Differential Revision: D72987759
---
 backends/cortex_m/README.md                     |   3 +
 backends/cortex_m/ops/TARGETS                   |  21 ++
 backends/cortex_m/ops/operators.py              |  90 ++++++++
 backends/cortex_m/passes/TARGETS                |  15 ++
 .../passes/replace_quant_nodes_pass.py          |  65 ++++++
 backends/cortex_m/test/TARGETS                  |  12 +
 .../cortex_m/test/test_replace_quant_nodes.py   | 207 ++++++++++++++++++
 7 files changed, 413 insertions(+)
 create mode 100644 backends/cortex_m/README.md
 create mode 100644 backends/cortex_m/ops/TARGETS
 create mode 100644 backends/cortex_m/ops/operators.py
 create mode 100644 backends/cortex_m/passes/TARGETS
 create mode 100644 backends/cortex_m/passes/replace_quant_nodes_pass.py
 create mode 100644 backends/cortex_m/test/TARGETS
 create mode 100644 backends/cortex_m/test/test_replace_quant_nodes.py

diff --git a/backends/cortex_m/README.md b/backends/cortex_m/README.md
new file mode 100644
index 00000000000..00d20439678
--- /dev/null
+++ b/backends/cortex_m/README.md
@@ -0,0 +1,3 @@
+# Cortex-M Backend
+
+WIP. This is a temporary backend for Cortex-M CPUs. It is a proof of concept, not intended for production use, and things will change without notice.
diff --git a/backends/cortex_m/ops/TARGETS b/backends/cortex_m/ops/TARGETS
new file mode 100644
index 00000000000..11c8307fc09
--- /dev/null
+++ b/backends/cortex_m/ops/TARGETS
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
+load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")
+
+oncall("executorch")
+
+python_library(
+    name = "ops",
+    srcs = [
+        "operators.py",
+    ],
+    deps = [
+        "fbcode//caffe2:torch",
+    ]
+)
diff --git a/backends/cortex_m/ops/operators.py b/backends/cortex_m/ops/operators.py
new file mode 100644
index 00000000000..377aaeef59c
--- /dev/null
+++ b/backends/cortex_m/ops/operators.py
@@ -0,0 +1,90 @@
+import torch
+from torch.library import impl, Library, register_fake
+from executorch.exir.dialects._ops import (
+    ops as exir_ops,
+)  # To provide the implementation of the operators
+
+# New operator library with a custom namespace to allow fusion etc.
+lib = Library("cortex_m", "DEF")
+
+###
+# quantize_per_tensor
+###
+
+lib.define(
+    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+
+lib.define(
+    "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+@register_fake("cortex_m::quantize_per_tensor")
+def quantize_per_tensor_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return torch.empty_like(input, dtype=dtype)
+
+
+@impl(lib, "quantize_per_tensor", "CompositeExplicitAutograd")
+def quantize_per_tensor_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    The implementation of the quantize_per_tensor operator is the same as the
+    quantize_per_tensor operator in the edge dialect.
+    """
+    return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
+        input, scale, zero_point, quant_min, quant_max, dtype
+    )
+
+
+###
+# dequantize_per_tensor
+###
+
+lib.define(
+    "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
+)
+lib.define(
+    "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+
+@register_fake("cortex_m::dequantize_per_tensor")
+def dequantize_per_tensor_meta(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return torch.empty_like(input, dtype=torch.float)
+
+
+@impl(lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
+def dequantize_per_tensor_impl(
+    input: torch.Tensor,
+    scale: float,
+    zero_point: int,
+    quant_min: int,
+    quant_max: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    The implementation of the dequantize_per_tensor operator is the same as the
+    dequantize_per_tensor operator in the edge dialect.
+    """
+    return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
+        input, scale, zero_point, quant_min, quant_max, dtype
+    )
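Once `operators.py` has been imported, the new ops are visible under `torch.ops.cortex_m` and can be smoke-tested eagerly. A minimal sketch (not part of the diff): the scale/zero-point values are illustrative, and the explicit `_decomposed` import is an assumption about where the reference `quantized_decomposed` ops are registered.

```python
import torch

# Assumption: this private module registers the reference quantized_decomposed ops
# that the eager cortex_m implementations delegate to.
import torch.ao.quantization.fx._decomposed  # noqa: F401

# Registers the cortex_m library defined above.
import executorch.backends.cortex_m.ops.operators  # noqa: F401

x = torch.randn(4, 8)
scale, zero_point = 0.05, 0  # illustrative values

q = torch.ops.cortex_m.quantize_per_tensor(x, scale, zero_point, -128, 127, torch.int8)
dq = torch.ops.cortex_m.dequantize_per_tensor(q, scale, zero_point, -128, 127, torch.int8)

assert q.dtype == torch.int8 and dq.dtype == torch.float32
```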
diff --git a/backends/cortex_m/passes/TARGETS b/backends/cortex_m/passes/TARGETS
new file mode 100644
index 00000000000..57b268d1ff7
--- /dev/null
+++ b/backends/cortex_m/passes/TARGETS
@@ -0,0 +1,15 @@
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+
+oncall("executorch")
+
+python_library(
+    name = "cortex_m_passes",
+    srcs = ["replace_quant_nodes_pass.py"],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+        "//executorch/backends/cortex_m/ops:ops",
+    ]
+)
diff --git a/backends/cortex_m/passes/replace_quant_nodes_pass.py b/backends/cortex_m/passes/replace_quant_nodes_pass.py
new file mode 100644
index 00000000000..61f83a20c6f
--- /dev/null
+++ b/backends/cortex_m/passes/replace_quant_nodes_pass.py
@@ -0,0 +1,65 @@
+from typing import Callable, Dict, Tuple
+import torch
+
+import executorch.backends.cortex_m.ops.operators  # noqa
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
+
+
+class ReplaceQuantNodesPass(ExportPass):
+    """
+    Replace quantized_decomposed quantize/dequantize ops with the corresponding
+    cortex_m quantize_per_tensor and dequantize_per_tensor ops, for the int8
+    per-tensor cases the cortex_m ops support.
+    """
+
+    @staticmethod
+    def is_qualified_quantize_per_tensor(args) -> bool:
+        return (
+            args[3] >= torch.iinfo(torch.int8).min  # qmin
+            and args[4] <= torch.iinfo(torch.int8).max  # qmax
+            and args[5] == torch.int8  # output dtype
+        )
+
+    @staticmethod
+    def is_qualified_dequantize_per_tensor(args) -> bool:
+        return (
+            args[3] >= torch.iinfo(torch.int8).min  # qmin
+            and args[4] <= torch.iinfo(torch.int8).max  # qmax
+            and args[5] == torch.int8  # input dtype
+        )
+
+    def call_operator(
+        self,
+        op: Callable[..., object],
+        args: Tuple[object, ...],
+        kwargs: Dict[str, object],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
+        assert isinstance(
+            op, EdgeOpOverload
+        ), f"Op must be an EdgeOpOverload, got {type(op)} for op {op}. Try running this pass after to_edge()."
+        if (
+            op == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+            and self.is_qualified_quantize_per_tensor(args)
+        ):
+            return super().call_operator(
+                exir_ops.edge.cortex_m.quantize_per_tensor.default,
+                args,
+                kwargs,
+                meta,
+            )
+        elif (
+            op == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+            and self.is_qualified_dequantize_per_tensor(args)
+        ):
+            return super().call_operator(
+                exir_ops.edge.cortex_m.dequantize_per_tensor.default,
+                args,
+                kwargs,
+                meta,
+            )
+        # For all other operators, pass through unchanged
+        else:
+            return super().call_operator(op, args, kwargs, meta)
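As the assertion in `call_operator` indicates, the pass expects edge-dialect ops, so it is applied after `to_edge()`. A minimal application sketch, mirroring the test further below; `quantized_module` and `example_inputs` are assumed to come from a prepare_pt2e/convert_pt2e flow set up by the caller:

```python
import executorch.exir as exir
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
    ReplaceQuantNodesPass,
)
from torch.export import export

# Assumption: `quantized_module` has already been quantized with
# prepare_pt2e/convert_pt2e (see the test below for the full recipe),
# and `example_inputs` matches its forward() signature.
exported = export(quantized_module, example_inputs, strict=True)

edge = exir.to_edge(
    exported,
    compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)

# Rewrite eligible int8 q/dq nodes into the cortex_m namespace.
edge = edge.transform([ReplaceQuantNodesPass()])
```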
+ """ + + @staticmethod + def is_qualified_quantize_per_tensor(args) -> bool: + return ( + args[3] >= torch.iinfo(torch.int8).min # qmin + and args[4] <= torch.iinfo(torch.int8).max # qmax + and args[5] == torch.int8 # output dtype + ) + + @staticmethod + def is_qualified_dequantize_per_tensor(args) -> bool: + return ( + args[3] >= torch.iinfo(torch.int8).min # qmin + and args[4] <= torch.iinfo(torch.int8).max # qmax + and args[5] == torch.int8 # input dtype + ) + + def call_operator( + self, + op: Callable[..., object], + args: Tuple[object, ...], + kwargs: Dict[str, object], + meta: NodeMetadata, + ) -> ProxyValue: + assert isinstance( + op, EdgeOpOverload + ), f"Op must be an EdgeOpOverload, got {type(op)} for op {op}. Try running this pass after to_edge()." + if ( + op == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + and self.is_qualified_quantize_per_tensor(args) + ): + return super().call_operator( + exir_ops.edge.cortex_m.quantize_per_tensor.default, + args, + kwargs, + meta, + ) + elif ( + op == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + and self.is_qualified_dequantize_per_tensor(args) + ): + return super().call_operator( + exir_ops.edge.cortex_m.dequantize_per_tensor.default, + args, + kwargs, + meta, + ) + # For all other operators, pass through unchanged + else: + return super().call_operator(op, args, kwargs, meta) diff --git a/backends/cortex_m/test/TARGETS b/backends/cortex_m/test/TARGETS new file mode 100644 index 00000000000..19293263944 --- /dev/null +++ b/backends/cortex_m/test/TARGETS @@ -0,0 +1,12 @@ +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") + +python_unittest( + name = "test_replace_quant_nodes", + srcs = ["test_replace_quant_nodes.py"], + deps = [ + "//pytorch/ao:torchao", # @manual + "//caffe2:torch", + "//executorch/backends/cortex_m/passes:cortex_m_passes", + "//executorch/backends/cortex_m/ops:ops", + ], +) diff --git a/backends/cortex_m/test/test_replace_quant_nodes.py b/backends/cortex_m/test/test_replace_quant_nodes.py new file mode 100644 index 00000000000..da4cfa0b96e --- /dev/null +++ b/backends/cortex_m/test/test_replace_quant_nodes.py @@ -0,0 +1,207 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +import unittest +from dataclasses import dataclass +from typing import Optional + +import executorch +import executorch.backends.cortex_m.ops.operators # noqa + +import torch +from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import ( + ReplaceQuantNodesPass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch.ao.quantization.observer import HistogramObserver +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torch.ao.quantization.quantizer.quantizer import ( + QuantizationAnnotation, + QuantizationSpec, + Quantizer, +) +from torch.export import export, export_for_training +from torch.fx import GraphModule, Node + + +@dataclass(eq=True, frozen=True) +class QuantizationConfig: + input_activation: Optional[QuantizationSpec] + output_activation: Optional[QuantizationSpec] + + +class AddQuantizer(Quantizer): + def __init__(self): + super().__init__() + + @staticmethod + def _get_qspec(): + return QuantizationSpec( + dtype=torch.int8, + quant_min=-128, + quant_max=127, + qscheme=torch.per_tensor_symmetric, + is_dynamic=False, + observer_or_fake_quant_ctr=HistogramObserver.with_args(eps=2**-12), + ) + + @staticmethod + def _get_qconfig(): + return QuantizationConfig( + input_activation=AddQuantizer._get_qspec(), + output_activation=AddQuantizer._get_qspec(), + ) + + @staticmethod + def _is_annotated(nodes: list[Node]): + annotated = False + for node in nodes: + annotated = annotated or ( + "quantization_annotation" in node.meta + and node.meta["quantization_annotation"]._annotated + ) + return annotated + + def annotate(self, model: GraphModule) -> torch.fx.GraphModule: + config = self._get_qconfig() + annotated_partitions = [] + for node in model.graph.nodes: + if node.op != "call_function" or node.target not in [ + torch.ops.aten.add.Tensor, + torch.ops.aten.add_.Tensor, + ]: + continue + + if self._is_annotated([node]): + continue + + input_qspec_map = { + node.args[0]: config.input_activation, + node.args[1]: config.input_activation, + } + output_qspec = config.output_activation + + node.meta["quantization_annotation"] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_qspec, + _annotated=True, + ) + annotated_partitions.append([node]) + return annotated_partitions + + def validate(self, model: GraphModule) -> None: + pass + + +def check_count( + graph_module: GraphModule, op: torch.fx.node.Target, expected_count: int +): + """ + Check that the graph module contains exactly the expected number of nodes with the given op. + """ + actual_count = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == op: + actual_count += 1 + assert ( + actual_count == expected_count + ), f"Expected {expected_count} {op} nodes, got {actual_count}" + + +class TestReplaceQuantOps(unittest.TestCase): + """ + Test suite for the ReplaceQuantNodesPass which replaces quantized_decomposed quant/dequant ops + with cortex_m specific implementations. + """ + + def test_replace_quant_ops(self): + """ + Test that quantize_per_tensor and dequantize_per_tensor nodes are correctly replaced + with their cortex_m equivalents while preserving the same functionality. 
+ """ + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + x + + m = M() + example_inputs = (torch.randn(10, 11, 12),) + + # quantize + captured_graph_module = export_for_training( + m.eval(), example_inputs, strict=True + ).module() + quantizer = AddQuantizer() + prepared_graph_module = prepare_pt2e(captured_graph_module, quantizer) + converted_graph_module = convert_pt2e(prepared_graph_module) + + # export + exported = export(converted_graph_module, example_inputs, strict=True) + + # to edge + epm = executorch.exir.to_edge( + exported, + compile_config=executorch.exir.EdgeCompileConfig(_check_ir_validity=False), + ) + graph_module = epm.exported_program().graph_module + + quant_count_before = 0 + dequant_count_before = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function": + if ( + node.target + == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ): + quant_count_before += 1 + elif ( + node.target + == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ): + dequant_count_before += 1 + edge_output = graph_module(*example_inputs) + + # to cortex_m + epm = epm.transform( + [ + ReplaceQuantNodesPass(), + ] + ) + graph_module = epm.exported_program().graph_module + check_count( + graph_module, + exir_ops.edge.cortex_m.quantize_per_tensor.default, + quant_count_before, + ) + check_count( + graph_module, + exir_ops.edge.cortex_m.dequantize_per_tensor.default, + dequant_count_before, + ) + check_count( + graph_module, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + 0, + ) + check_count( + graph_module, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + 0, + ) + cortex_m_output = graph_module(*example_inputs) + + # check output - numerical equivalence should be preserved + torch.testing.assert_close(edge_output, cortex_m_output) + + # To executorch + expm = epm.to_executorch() + for op in expm.executorch_program.execution_plan[0].operators: + if "quantize_per_tensor" in op.name: + assert op.name in [ + "cortex_m::quantize_per_tensor", + "cortex_m::dequantize_per_tensor", + ], f"Unexpected op {op.name}"