Skip to content

Commit e6bdeda

Browse files
Qualcomm AI Engine Direct - OSS models breakage fix (#10191)
### Summary
- Fastvit breakage fix
- ConvFormer breakage fix
- Changed the dataset for Edsr because the previous dataset link is unavailable
- Add a test case for the ConvertSquareToPow pass

### Test plan
```bash
python ./examples/qualcomm/oss_scripts/fastvit.py -m ${soc} -b build-android -H ${host_id} -s ${device_id} --oss_repo ${Path_to_oss_repo} --pretrained_weight ${Path_to_pretrained_weight} -d ${Path_to_dataset_dir}
```
```bash
python ./examples/qualcomm/oss_scripts/conv_former.py -m ${soc} -b build-android -H ${host_id} -s ${device_id} -d ${Path_to_dataset_dir}
```
cc @cccclai @cbilgin
1 parent f692ff5 commit e6bdeda

File tree

8 files changed

+89
-10
lines changed

8 files changed

+89
-10
lines changed

backends/qualcomm/_passes/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .annotate_unbind import AnnotateUnbind
1010
from .convert_bmm_to_matmul import ConvertBmmToMatmul
1111
from .convert_conv1d_to_conv2d import ConvertConv1dToConv2d
12+
from .convert_square_to_pow import ConvertSquareToPow
1213
from .convert_upsample_bicubic2d import ConvertUpsampleBicubicWithBilinear
1314
from .decompose_any import DecomposeAny
1415
from .decompose_cdist import DecomposeCDist
@@ -42,6 +43,7 @@
4243
AnnotateUnbind,
4344
ConvertBmmToMatmul,
4445
ConvertConv1dToConv2d,
46+
ConvertSquareToPow,
4547
ConvertUpsampleBicubicWithBilinear,
4648
DecomposeAny,
4749
DecomposeCDist,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
from executorch.exir.pass_base import ExportPass, PassResult
8+
9+
from .utils import copy_meta
10+
11+
12+
class ConvertSquareToPow(ExportPass):
    """
    Rewrite every aten.square node as aten.pow.Tensor_Scalar with exponent 2.

    Performing the rewrite in this pass makes the exponent an explicit scalar
    argument, which allows LiftConstantScalarOperands to lift that scalar
    operand. Otherwise, the square op will be converted to pow.tensor_scalar
    after to_edge.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Replace square nodes in ``graph_module`` in place.

        Args:
            graph_module: The FX graph module whose graph is rewritten.

        Returns:
            A PassResult wrapping the same ``graph_module``; the modified flag
            is always reported as True.
        """
        graph = graph_module.graph
        pow_op = torch.ops.aten.pow.Tensor_Scalar
        for node in graph.nodes:
            if node.target != torch.ops.aten.square.default:
                continue

            input_node = node.args[0]
            # Insert next to the node being replaced (its input is already
            # defined at this point) so the new call_function never lands in
            # the middle of the graph's leading placeholder nodes, which could
            # happen when inserting directly after a placeholder input.
            with graph.inserting_before(node):
                pow_node = graph.create_node(
                    "call_function", pow_op, (input_node, 2)
                )
            pow_node.meta = copy_meta(node.meta)
            # Redirect every consumer of the square node to the new pow node;
            # the now-unused square node is removed by dead code elimination.
            node.replace_all_uses_with(pow_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
AnnotateUnbind,
1515
ConvertBmmToMatmul,
1616
ConvertConv1dToConv2d,
17+
ConvertSquareToPow,
1718
ConvertUpsampleBicubicWithBilinear,
1819
DecomposeAny,
1920
DecomposeCDist,
@@ -199,6 +200,7 @@ def transform_for_export_pipeline(self, exported_program: ExportedProgram):
199200
self.add_pass(DecomposeScaledDotProductAttention())
200201
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
201202
self.add_pass(DecomposeExpM1())
203+
self.add_pass(ConvertSquareToPow())
202204
self.add_pass(LiftConstantScalarOperands())
203205
self._transform(exported_program.graph_module)
204206
ep = lift_constant_tensor_pass(exported_program)

backends/qualcomm/tests/models.py

+9
Original file line numberDiff line numberDiff line change
@@ -1436,6 +1436,15 @@ def forward(self, x):
14361436
return x / torch.sqrt(torch.tensor([64.0]))
14371437

14381438

1439+
class SquaredReLU(torch.nn.Module):
    """Test model that applies ReLU and then squares the result with torch.square."""

    def __init__(self, inplace=False):
        super().__init__()
        # `inplace` is forwarded unchanged to torch.nn.ReLU.
        self.relu = torch.nn.ReLU(inplace=inplace)

    def forward(self, x):
        activated = self.relu(x)
        # Keep the explicit torch.square call: it is the op this model exists
        # to exercise.
        return torch.square(activated)
1446+
1447+
14391448
class Squeeze(torch.nn.Module):
14401449
def __init__(self):
14411450
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -843,6 +843,11 @@ def test_qnn_backend_softmax(self):
843843
sample_input = (torch.randn([1, 4, 8, 8]),)
844844
self.lower_module_and_test_output(module, sample_input)
845845

846+
def test_qnn_backend_squared_relu(self):
    # Lower the SquaredReLU test model and check its output through the
    # shared test helper.
    sample_input = (torch.randn([2, 5, 1, 3]),)
    module = SquaredReLU()  # noqa: F405
    self.lower_module_and_test_output(module, sample_input)
850+
846851
def test_qnn_backend_squeeze(self):
847852
module = Squeeze() # noqa: F405
848853
sample_input = (torch.randn([1, 3, 3]),)
@@ -2001,6 +2006,12 @@ def test_qnn_backend_softmax(self):
20012006
module = self.get_qdq_module(module, sample_input)
20022007
self.lower_module_and_test_output(module, sample_input)
20032008

2009+
def test_qnn_backend_squared_relu(self):
    # Same SquaredReLU model as above, but passed through get_qdq_module
    # before lowering.
    sample_input = (torch.randn([2, 5, 1, 3]),)
    float_module = SquaredReLU()  # noqa: F405
    qdq_module = self.get_qdq_module(float_module, sample_input)
    self.lower_module_and_test_output(qdq_module, sample_input)
2014+
20042015
def test_qnn_backend_squeeze(self):
20052016
module = Squeeze() # noqa: F405
20062017
sample_input = (torch.randn([1, 3, 3]),)
@@ -3642,7 +3653,7 @@ def test_efficientSAM(self):
36423653
self.skipTest("missing required envs")
36433654
cmds = [
36443655
"python",
3645-
f"{self.executorch_root}/examples/qualcomm/oss_scripts/efficientSAM.py",
3656+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/efficientSAM/efficientSAM.py",
36463657
"--dataset",
36473658
self.image_dataset,
36483659
"--artifact",

examples/qualcomm/oss_scripts/conv_former.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
import numpy as np
1313
import timm
1414
import torch
15-
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
16-
from executorch.backends.qualcomm.utils.constants import (
17-
QCOM_PASS_EXPAND_BROADCAST_SHAPE,
15+
from executorch.backends.qualcomm._passes.expand_broadcast_tensor_shape import (
16+
ExpandBroadcastTensorShape,
17+
)
18+
from executorch.backends.qualcomm._passes.qnn_pass_manager import (
19+
get_capture_program_passes,
1820
)
21+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
22+
from executorch.backends.qualcomm.utils.constants import QCOM_PASS_ACTIVATE_KEY
1923
from executorch.examples.qualcomm.utils import (
2024
build_executorch_binary,
2125
get_imagenet_dataset,
@@ -55,14 +59,17 @@ def main(args):
5559

5660
model = model.eval()
5761

62+
# lower to QNN
63+
passes_job = get_capture_program_passes()
64+
passes_job[ExpandBroadcastTensorShape][QCOM_PASS_ACTIVATE_KEY] = True
5865
build_executorch_binary(
5966
model,
6067
inputs[0],
6168
args.model,
6269
f"{args.artifact}/{pte_filename}",
6370
inputs,
6471
quant_dtype=QuantDtype.use_8a8w,
65-
custom_pass_config={QCOM_PASS_EXPAND_BROADCAST_SHAPE},
72+
passes_job=passes_job,
6673
)
6774

6875
if args.compile_only:

examples/qualcomm/oss_scripts/fastvit.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,16 @@ def main(args):
101101
),
102102
)
103103
# rewrite default per-channel ptq config
104-
quantizer.per_channel_quant_config = QuantizationConfig(
104+
quantizer.default_quant_config.per_channel_quant_config = QuantizationConfig(
105105
input_activation=act_qspec,
106106
output_activation=act_qspec,
107107
weight=weight_qspec,
108108
bias=_derived_bias_quant_spec,
109109
)
110110

111111
# rewrite default ptq config
112-
q_config = quantizer.quant_config
113-
quantizer.quant_config = QuantizationConfig(
112+
q_config = quantizer.default_quant_config.quant_config
113+
quantizer.default_quant_config.quant_config = QuantizationConfig(
114114
input_activation=act_qspec,
115115
output_activation=act_qspec,
116116
weight=q_config.weight,

examples/qualcomm/scripts/edsr.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from PIL import Image
2626
from torch.utils.data import Dataset
27-
from torchsr.datasets import B100
27+
from torchsr.datasets import B100, Div2K
2828
from torchvision.transforms.functional import to_pil_image, to_tensor
2929

3030

@@ -75,6 +75,16 @@ def get_b100(
7575
return SrDataset(hr_dir, lr_dir)
7676

7777

78+
def get_Div2K(
    dataset_dir: str,
):
    """
    Build an SrDataset from the DIV2K directories located under ``dataset_dir``.

    The high-resolution images are read from ``DIV2K_valid_HR`` and the
    low-resolution images from ``DIV2K_valid_LR_bicubic/X2``. If either
    directory does not exist yet, torchsr's ``Div2K`` is invoked with
    ``download=True`` and ``scale=2`` before the dataset is built.
    """
    hr_dir = f"{dataset_dir}/sr_bm_dataset/DIV2K/DIV2K_valid_HR"
    lr_dir = f"{dataset_dir}/sr_bm_dataset/DIV2K/DIV2K_valid_LR_bicubic/X2"
    already_present = os.path.exists(hr_dir) and os.path.exists(lr_dir)
    if not already_present:
        Div2K(root=f"{dataset_dir}/sr_bm_dataset", scale=2, download=True)
    return SrDataset(hr_dir, lr_dir)
86+
87+
7888
def get_dataset(hr_dir: str, lr_dir: str, default_dataset: str, dataset_dir: str):
7989
if not (lr_dir and hr_dir) and not default_dataset:
8090
raise RuntimeError(
@@ -85,7 +95,7 @@ def get_dataset(hr_dir: str, lr_dir: str, default_dataset: str, dataset_dir: str
8595
raise RuntimeError("Either use custom dataset, or use default dataset.")
8696

8797
if default_dataset:
88-
return get_b100(dataset_dir)
98+
return get_Div2K(dataset_dir)
8999

90100
return SrDataset(hr_dir, lr_dir)
91101

0 commit comments

Comments
 (0)