[WIP][cortex-m] init #10200

Closed · wants to merge 1 commit
3 changes: 3 additions & 0 deletions backends/cortex_m/README.md
@@ -0,0 +1,3 @@
# Cortex-M Backend

WIP. This is a temporary backend for Cortex-M CPUs. It is a proof of concept, not intended for production use, and things will change without notice.
21 changes: 21 additions & 0 deletions backends/cortex_m/ops/TARGETS
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")

oncall("executorch")

python_library(
    name = "ops",
    srcs = [
        "operators.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
    ],
)
90 changes: 90 additions & 0 deletions backends/cortex_m/ops/operators.py
@@ -0,0 +1,90 @@
import torch
from torch.library import impl, Library, register_fake
from executorch.exir.dialects._ops import (
    ops as exir_ops,
)  # To provide the implementation of the operators

# New operator library with a custom namespace to allow fusion etc.
lib = Library("cortex_m", "DEF")

###
# quantize_per_tensor
###

lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)

lib.define(
    "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)

@register_fake("cortex_m::quantize_per_tensor")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=dtype)


@impl(lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Delegates to the quantize_per_tensor operator in the edge dialect,
    which has identical semantics.
    """
    return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )


###
# dequantize_per_tensor
###

lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)

lib.define(
    "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)

@register_fake("cortex_m::dequantize_per_tensor")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=torch.float)


@impl(lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Delegates to the dequantize_per_tensor operator in the edge dialect,
    which has identical semantics.
    """
    return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )
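
Because both operators carry CompositeExplicitAutograd implementations, they can be exercised eagerly once the module above is imported. A minimal round-trip sketch, assuming the quantized_decomposed ops are registered (importing torch.ao.quantization.fx._decomposed does this); the scale, zero point, and int8 bounds are illustrative:

import torch
import torch.ao.quantization.fx._decomposed  # noqa -- registers quantized_decomposed ops
import executorch.backends.cortex_m.ops.operators  # noqa -- registers the cortex_m ops

x = torch.randn(4)
scale, zero_point = 0.05, 0

# Quantize to int8 and back; d approximates x up to quantization error.
q = torch.ops.cortex_m.quantize_per_tensor(x, scale, zero_point, -128, 127, torch.int8)
d = torch.ops.cortex_m.dequantize_per_tensor(q, scale, zero_point, -128, 127, torch.int8)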
15 changes: 15 additions & 0 deletions backends/cortex_m/passes/TARGETS
@@ -0,0 +1,15 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("executorch")

python_library(
    name = "cortex_m_passes",
    srcs = ["replace_quant_nodes_pass.py"],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/backends/cortex_m/ops:ops",
    ],
)
65 changes: 65 additions & 0 deletions backends/cortex_m/passes/replace_quant_nodes_pass.py
@@ -0,0 +1,65 @@
from typing import Callable, Dict, Tuple
import torch

import executorch.backends.cortex_m.ops.operators # noqa

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue


class ReplaceQuantNodesPass(ExportPass):
    """
    Replace quantized_decomposed.quantize_per_tensor and
    quantized_decomposed.dequantize_per_tensor nodes with their
    cortex_m counterparts when the arguments fit int8 quantization.
    """

    @staticmethod
    def is_qualified_quantize_per_tensor(args) -> bool:
        return (
            args[3] >= torch.iinfo(torch.int8).min  # qmin
            and args[4] <= torch.iinfo(torch.int8).max  # qmax
            and args[5] == torch.int8  # output dtype
        )

    @staticmethod
    def is_qualified_dequantize_per_tensor(args) -> bool:
        return (
            args[3] >= torch.iinfo(torch.int8).min  # qmin
            and args[4] <= torch.iinfo(torch.int8).max  # qmax
            and args[5] == torch.int8  # input dtype
        )

    def call_operator(
        self,
        op: Callable[..., object],
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
        meta: NodeMetadata,
    ) -> ProxyValue:
        assert isinstance(
            op, EdgeOpOverload
        ), f"Op must be an EdgeOpOverload, got {type(op)} for op {op}. Try running this pass after to_edge()."
        if (
            op == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
            and self.is_qualified_quantize_per_tensor(args)
        ):
            return super().call_operator(
                exir_ops.edge.cortex_m.quantize_per_tensor.default,
                args,
                kwargs,
                meta,
            )
        elif (
            op == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
            and self.is_qualified_dequantize_per_tensor(args)
        ):
            return super().call_operator(
                exir_ops.edge.cortex_m.dequantize_per_tensor.default,
                args,
                kwargs,
                meta,
            )
        else:
            # For all other operators, pass through unchanged.
            return super().call_operator(op, args, kwargs, meta)
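
The assert above points at the intended usage: run the pass on an edge-dialect program, i.e. after to_edge(). A hedged sketch with a placeholder module (on a graph without qualifying int8 quantize/dequantize nodes the pass simply passes everything through; a PT2E-quantized model is the real target):

import torch
from executorch.exir import to_edge
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
    ReplaceQuantNodesPass,
)

class AddOne(torch.nn.Module):  # placeholder module for illustration
    def forward(self, x):
        return x + 1.0

edge = to_edge(torch.export.export(AddOne(), (torch.randn(4),)))
edge = edge.transform([ReplaceQuantNodesPass()])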
12 changes: 12 additions & 0 deletions backends/cortex_m/test/TARGETS
@@ -0,0 +1,12 @@
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

python_unittest(
    name = "test_replace_quant_nodes",
    srcs = ["test_replace_quant_nodes.py"],
    deps = [
        "//pytorch/ao:torchao",  # @manual
        "//caffe2:torch",
        "//executorch/backends/cortex_m/passes:cortex_m_passes",
        "//executorch/backends/cortex_m/ops:ops",
    ],
)
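
test_replace_quant_nodes.py itself is not part of this diff. A hypothetical sketch of the kind of check such a test could make, not the actual file contents; the module and assertion are illustrative:

import unittest

import torch
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
    ReplaceQuantNodesPass,
)
from executorch.exir import to_edge


class TestReplaceQuantNodes(unittest.TestCase):
    def test_pass_through_on_unquantized_graph(self):
        # With no quantize/dequantize nodes present, nothing should be rewritten.
        class AddOne(torch.nn.Module):
            def forward(self, x):
                return x + 1.0

        edge = to_edge(torch.export.export(AddOne(), (torch.randn(4),)))
        transformed = edge.transform([ReplaceQuantNodesPass()])
        op_names = {
            str(node.target)
            for node in transformed.exported_program().graph.nodes
            if node.op == "call_function"
        }
        self.assertFalse(any("cortex_m" in name for name in op_names))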