Commit 7d35c68
[Bug Fix] Fix padding when running in NHWC (#9729)
### Summary

There is a bug when a constant_pad falls between two convolutions. To minimize the permutes associated with memory-format changes, we sometimes compute ops in NHWC. This is the case for ConstantPad when it sits between two convs:

```
a = conv(a)
a = constant_pad(a, paddings=[1, 2, 3, 4])
a = conv(a)
```

In this case we need to make sure the paddings given to constant_pad are also permuted to NHWC.

### Test plan

python install_executorch.py --editable
python -m unittest backends.xnnpack.test.ops.test_static_constant_pad.TestStaticConstantPad.test_fp32_static_constant_pad_nhwc
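To make the failure mode concrete, here is a minimal standalone sketch (illustrative only, not part of this commit; the variable names are hypothetical). Pad amounts are tied to dimension order, so paddings computed for NCHW pad the wrong dimensions once the tensor is permuted to NHWC:

```python
import torch

x = torch.randn(1, 2, 3, 4)                  # NCHW: N=1, C=2, H=3, W=4
nchw_pad = (1, 2, 3, 4)                      # torch pad order: W pre/post, then H pre/post
ref = torch.nn.functional.pad(x, nchw_pad)   # shape (1, 2, 10, 7)

x_nhwc = x.permute(0, 2, 3, 1)               # NHWC: shape (1, 3, 4, 2)
# Reusing the NCHW paddings now pads C and W instead of W and H:
wrong = torch.nn.functional.pad(x_nhwc, nchw_pad)  # shape (1, 3, 11, 5)

# Permuting the per-dim paddings along with the tensor restores the intent:
nhwc_pad = (0, 0, 1, 2, 3, 4)                # C pre/post, W pre/post, H pre/post
right = torch.nn.functional.pad(x_nhwc, nhwc_pad)  # shape (1, 10, 7, 2)
assert right.shape == ref.permute(0, 2, 3, 1).shape
```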
1 parent ebe8522 commit 7d35c68

File tree

2 files changed: +58 -1 lines


backends/xnnpack/operators/op_static_constant_pad.py

+13 -1
```diff
@@ -7,6 +7,7 @@
 from typing import cast, Dict, List
 
 import torch
+
 from executorch.backends.xnnpack.operators.node_visitor import (
     get_tensor_value,
     NodeVisitor,
@@ -17,7 +18,11 @@
     XNNStaticConstantPad,
     XNode,
 )
-from executorch.backends.xnnpack.utils.utils import check_or_raise, get_input_node
+from executorch.backends.xnnpack.utils.utils import (
+    check_or_raise,
+    get_input_node,
+    PERM_NCHW_TO_NHWC,
+)
 
 
 @register_node_visitor
@@ -113,8 +118,15 @@ def define_node(
         # b)
         # tuple[0] = prepadding dim[-1]
         # tuple[1] = postpadding dim[-1]
+        is_channels_last = node.meta.get("XNN_NHWC_NODE", False)
         pre_paddings = all_paddings[-2::-2]  # even index elements in reverse order
         post_paddings = all_paddings[::-2]  # odd index elements in reverse order
+        if is_channels_last:
+            check_or_raise(len(pre_paddings) == 4, "Expecting prepaddings to be 4D")
+            check_or_raise(len(post_paddings) == 4, "Expecting postpaddings to be 4D")
+
+            pre_paddings = [pre_paddings[i] for i in PERM_NCHW_TO_NHWC]
+            post_paddings = [post_paddings[i] for i in PERM_NCHW_TO_NHWC]
 
         # the padding value, which defaults to 0.0
         padding_value = cast(float, node.args[2]) if len(node.args) > 2 else 0.0
```
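A small worked example of what the new branch computes, assuming `PERM_NCHW_TO_NHWC = [0, 2, 3, 1]` (its definition in `backends/xnnpack/utils/utils.py`); the pad amounts reuse `pad_6` from the test added in this commit:

```python
# torch pad lists are ordered innermost dim first; for a 4D NCHW tensor:
# (w_pre, w_post, h_pre, h_post, c_pre, c_post, n_pre, n_post)
all_paddings = [1, 2, 3, 4, 5, 6, 0, 0]  # pad_6 from the test, extended to 4D

pre_paddings = all_paddings[-2::-2]   # per-dim pre-paddings, NCHW order:  [0, 5, 3, 1]
post_paddings = all_paddings[::-2]    # per-dim post-paddings, NCHW order: [0, 6, 4, 2]

PERM_NCHW_TO_NHWC = [0, 2, 3, 1]  # assumed to match the imported constant
# When the node is tagged XNN_NHWC_NODE, reorder the paddings to NHWC:
pre_nhwc = [pre_paddings[i] for i in PERM_NCHW_TO_NHWC]    # [0, 3, 1, 5]
post_nhwc = [post_paddings[i] for i in PERM_NCHW_TO_NHWC]  # [0, 4, 2, 6]
```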

backends/xnnpack/test/ops/test_static_constant_pad.py

+45
```diff
@@ -14,6 +14,30 @@ class TestStaticConstantPad(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
 
+    class NHWCStaticConstantPad(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(in_channels=2, out_channels=2, kernel_size=1)
+            self.conv2 = torch.nn.Conv2d(in_channels=13, out_channels=13, kernel_size=1)
+
+        def forward(self, x):
+            a = self.conv1(x)
+            pad_6 = (1, 2, 3, 4, 5, 6)
+            a = torch.nn.functional.pad(
+                input=a,
+                pad=pad_6,
+                mode="constant",
+                value=3.1,
+            )
+            # tensorshape = [1, 13, 10, 7]
+            a = self.conv2(a)
+
+            return a
+
+        def sample_inputs(self):
+            # NCHW
+            return (torch.randn(1, 2, 3, 4),)
+
     class StaticConstantPadFunctional(torch.nn.Module):
         def __init__(self):
             super().__init__()
@@ -205,3 +229,24 @@ def test_qs8_static_constant_pad_2d(self):
             .serialize()
             .run_method_and_compare_outputs()
         )
+
+    def test_fp32_static_constant_pad_nhwc(self):
+        conv = self.NHWCStaticConstantPad()
+        inputs = conv.sample_inputs()
+        (
+            Tester(conv, inputs)
+            .export()
+            .check_count({"torch.ops.aten.pad.default": 1})
+            .dump_artifact()
+            .to_edge_transform_and_lower()
+            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+            .check_not(
+                [
+                    "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
+                    "executorch_exir_dialects_edge__ops_aten_convolution_default",
+                ]
+            )
+            .to_executorch()
+            .serialize()
+            .run_method_and_compare_outputs()
+        )
```
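As a quick sanity check on the `# tensorshape = [1, 13, 10, 7]` comment in the test: the 1x1 `conv1` preserves the input shape (1, 2, 3, 4), and `pad_6` grows W by 1+2, H by 3+4, and C by 5+6. A few standalone lines (not part of the commit) confirm the arithmetic:

```python
import torch

x = torch.randn(1, 2, 3, 4)  # conv1 is 1x1 with 2 -> 2 channels, so shape is unchanged
y = torch.nn.functional.pad(x, (1, 2, 3, 4, 5, 6), mode="constant", value=3.1)
print(y.shape)  # torch.Size([1, 13, 10, 7]), matching the in-test comment
```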
