Commit 605bfa6

Remove unused pass and test to replace linalg.vector_norm.
Differential Revision: D73235672
Pull Request resolved: #10296

1 parent: 06a4944

File tree

2 files changed: +0 −56 lines


backends/cadence/aot/replace_ops.py (−25 lines)
@@ -1806,30 +1806,6 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(op, tuple(new_args), kwargs, meta)


-@register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass(ExportPass):
-    """
-    Replace the aten.linalg_vector_norm op with a custom op.
-    aten.linalg_vector_norm is not supported by Jarvis, so we
-    need to replace it with native_batch_norm at all optimization levels.
-    """
-
-    def call_operator(self, op, args, kwargs, meta):
-        if op != exir_ops.edge.aten.linalg_vector_norm.default:
-            return super().call_operator(op, args, kwargs, meta)
-
-        assert (
-            len(args) == 1
-        ), "aten.linalg_vector_norm should have 1 argument (a tensor), we do not support any custom variants"
-
-        return super().call_operator(
-            exir_ops.edge.cadence.linalg_vector_norm.default,
-            args,
-            kwargs,
-            meta,
-        )
-
-
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
 class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
     """
@@ -2243,7 +2219,6 @@ class CadenceReplaceOpsInGraph:
         ReplacePT2DequantWithCadenceDequantPass,
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
-        ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
         # ReplaceGeluWithApproximateGeluPass,
     ]
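
For context, the removed pass intercepted call_operator for aten.linalg_vector_norm and re-emitted the node as the cadence.linalg_vector_norm custom op. The same rewrite pattern can be sketched in plain torch.fx, outside the ExecuTorch ExportPass machinery; this stand-in expands the default L2 norm into sqrt(sum(x * x)) instead of targeting a backend custom op, and is only an illustration of the technique, not the removed code:

import torch
import torch.fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.linalg.vector_norm(x)


gm = torch.fx.symbolic_trace(M())

# Rewrite every call to torch.linalg.vector_norm in place. A backend pass
# like the one removed here would emit its custom op instead; this stand-in
# expands the default L2 norm into sqrt(sum(x * x)).
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.linalg.vector_norm:
        with gm.graph.inserting_after(node):
            sq = gm.graph.call_function(torch.mul, (node.args[0], node.args[0]))
            total = gm.graph.call_function(torch.sum, (sq,))
            norm = gm.graph.call_function(torch.sqrt, (total,))
        node.replace_all_uses_with(norm)
        gm.graph.erase_node(node)
gm.recompile()

x = torch.randn(32)
assert torch.allclose(gm(x), torch.linalg.vector_norm(x))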

backends/cadence/aot/tests/test_replace_ops_passes.py (−31 lines)
@@ -23,7 +23,6 @@
     MakeSliceAndCatDimOutermostPass,
     ReplaceAddMMWithLinearPass,
     ReplaceAtenConvolutionWithJarvisConvolutionPass,
-    ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
     ReplaceConstantPadNdWithSlicePass,
     ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
     ReplaceConvWithIm2RowAndLinear,
@@ -1189,36 +1188,6 @@ def forward(self, x):
             count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
         )

-    def test_replace_aten_linalg_vector_norm_with_cadence_linalg_vector_norm(self):
-        class LinalgVectorNorm(torch.nn.Module):
-            def forward(self, x: torch.Tensor):
-                return torch.linalg.vector_norm(x)
-
-        x = torch.randn(32)
-
-        graph_module = (
-            export_to_edge(LinalgVectorNorm(), (x,)).exported_program().graph_module
-        )
-
-        p = ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
-        # Assert that aten.linalg_vector_norm op was replaced by a
-        # cadence.linalg_vector_norm op
-        self.assertEqual(
-            count_node(
-                graph_after_passes,
-                exir_ops.edge.aten.linalg_vector_norm.default,
-            ),
-            0,
-        )
-        self.assertEqual(
-            count_node(
-                graph_after_passes, exir_ops.edge.cadence.linalg_vector_norm.default
-            ),
-            1,
-        )
-
     def test_replace_aten_where_with_cadence_where_Scalar(self):
         class WhereScalarModel(torch.nn.Module):
             def forward(self, cond: torch.Tensor):