Skip to content

Commit dd5c0fc

Browse files
committed
[WebNN EP] Remove NHWC preferred layout
The WebNN CPU backend now supports the NCHW layout in Chromium, so we can drop the NHWC preferred layout for the CPU backend in the WebNN EP to simplify the code.
1 parent 1d4b161 commit dd5c0fc

10 files changed

+39
-294
lines changed

onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc

+9-12
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@ common::Status ComputeConvPads(const std::vector<int64_t> input_shape,
1919
const std::vector<int64_t>& onnx_strides,
2020
const std::vector<int64_t>& onnx_dilations,
2121
AutoPadType auto_pad_type,
22-
std::vector<int64_t>& pads_out,
23-
bool use_nchw) {
24-
const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1];
25-
const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2];
22+
std::vector<int64_t>& pads_out) {
23+
const int64_t input_size_y = input_shape[2];
24+
const int64_t input_size_x = input_shape[3];
2625
const int64_t stride_y = onnx_strides[0];
2726
const int64_t stride_x = onnx_strides[1];
2827
const int64_t dilation_y = onnx_dilations[0];
@@ -54,16 +53,15 @@ common::Status HandleAutoPad(const std::vector<int64_t> input_shape,
5453
const std::vector<int64_t>& onnx_strides,
5554
const std::vector<int64_t>& onnx_dilations,
5655
AutoPadType auto_pad_type,
57-
std::vector<int64_t>& pads_out,
58-
bool use_nchw) {
56+
std::vector<int64_t>& pads_out) {
5957
if (AutoPadType::SAME_UPPER == auto_pad_type) {
6058
ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x,
6159
onnx_pads, onnx_strides, onnx_dilations,
62-
AutoPadType::SAME_UPPER, pads_out, use_nchw));
60+
AutoPadType::SAME_UPPER, pads_out));
6361
} else {
6462
ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x,
6563
onnx_pads, onnx_strides, onnx_dilations,
66-
AutoPadType::SAME_LOWER, pads_out, use_nchw));
64+
AutoPadType::SAME_LOWER, pads_out));
6765
}
6866
return Status::OK();
6967
}
@@ -111,10 +109,9 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t>
111109
const std::vector<int64_t>& onnx_output_padding,
112110
AutoPadType auto_pad_type,
113111
std::vector<int64_t>& pads_out,
114-
std::vector<int64_t>& output_shape_out,
115-
bool use_nchw) {
116-
const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1];
117-
const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2];
112+
std::vector<int64_t>& output_shape_out) {
113+
const int64_t input_size_y = input_shape[2];
114+
const int64_t input_size_x = input_shape[3];
118115
const int64_t stride_y = onnx_strides[0];
119116
const int64_t stride_x = onnx_strides[1];
120117
const int64_t dilation_y = onnx_dilations[0];

onnxruntime/core/providers/webnn/builders/impl/builder_utils.h

+2-4
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ common::Status HandleAutoPad(const std::vector<int64_t> input_shape,
2121
const std::vector<int64_t>& onnx_strides,
2222
const std::vector<int64_t>& onnx_dilations,
2323
AutoPadType auto_pad_type,
24-
std::vector<int64_t>& pads_out,
25-
bool use_nchw) ORT_MUST_USE_RESULT;
24+
std::vector<int64_t>& pads_out) ORT_MUST_USE_RESULT;
2625

2726
// Compute pads and output shape for ConvTranspose.
2827
common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t> input_shape,
@@ -34,8 +33,7 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t>
3433
const std::vector<int64_t>& onnx_output_padding,
3534
AutoPadType auto_pad_type,
3635
std::vector<int64_t>& pads_out,
37-
std::vector<int64_t>& output_shape_out,
38-
bool use_nchw) ORT_MUST_USE_RESULT;
36+
std::vector<int64_t>& output_shape_out) ORT_MUST_USE_RESULT;
3937

4038
} // namespace webnn
4139
} // namespace onnxruntime

onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc

+12-159
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ namespace webnn {
1818

1919
class ConvOpBuilder : public BaseOpBuilder {
2020
// Add operator related.
21-
public:
22-
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
23-
2421
private:
2522
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
2623
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
@@ -33,13 +30,6 @@ class ConvOpBuilder : public BaseOpBuilder {
3330
const logging::Logger& logger) const override;
3431
};
3532

36-
void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
37-
// skip the weight for conv as we need to transpose for preferred layout NHWC.
38-
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
39-
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // W
40-
}
41-
}
42-
4333
// Helper functions
4434
common::Status SetConvBaseOptions(ModelBuilder& model_builder,
4535
const Node& node, emscripten::val& options,
@@ -48,7 +38,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
4838
const std::vector<int64_t>& strides,
4939
const std::vector<int64_t>& dilations,
5040
std::vector<int64_t>& pads,
51-
const bool is_nhwc,
5241
const bool is_conv1d,
5342
const logging::Logger& logger) {
5443
NodeAttrHelper helper(node);
@@ -61,7 +50,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
6150
// Calculate explicit padding for autoPad.
6251
if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) {
6352
ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3],
64-
pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc));
53+
pads, strides, dilations, auto_pad_type, pads_out));
6554
pads = pads_out;
6655
}
6756
} else if (node.OpType() == "ConvTranspose") {
@@ -82,7 +71,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
8271
// Otherwise compute the output shape, as well as the pads if the auto_pad attribute is SAME_UPPER/SAME_LOWER.
8372
ORT_RETURN_IF_ERROR(ComputeConvTransposePadsAndOutputShape(input_shape, weight_shape[2], weight_shape[3],
8473
pads, strides, dilations, output_padding,
85-
auto_pad_type, pads_out, output_shape, !is_nhwc));
74+
auto_pad_type, pads_out, output_shape));
8675

8776
if (output_shape[0] != -1 && output_shape[1] != -1) {
8877
options.set("outputSizes", emscripten::val::array(GetVecUint32FromVecInt64(output_shape)));
@@ -111,89 +100,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
111100
return Status::OK();
112101
}
113102

114-
// Both depthwise Conv and ConvTranspose share the same logic to add the layout.
115-
Status AddInitializerInNewLayout(ModelBuilder& model_builder,
116-
const std::string& name,
117-
bool is_conv,
118-
bool is_conv1d) {
119-
const auto& tensor = *model_builder.GetInitializerTensors().at(name);
120-
auto data_type = tensor.data_type();
121-
122-
const auto& shape = tensor.dims();
123-
std::vector<uint32_t> dims = GetVecUint32FromVecInt64(std::vector<int64_t>(std::begin(shape), std::end(shape)));
124-
125-
if (is_conv1d) {
126-
// Support conv1d by prepending a 1 size dimension.
127-
dims.push_back(1);
128-
}
129-
130-
const uint8_t* src = nullptr;
131-
Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath());
132-
src = unpacked_tensor.DataAsByteSpan().data();
133-
const auto out_t = dims[0], in_t = dims[1],
134-
h_t = dims[2], w_t = dims[3];
135-
std::vector<uint32_t> dest_shape;
136-
if (is_conv == 1)
137-
dest_shape = {out_t, h_t, w_t, in_t}; // L_0231
138-
else
139-
dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv and convTranspose weight
140-
141-
SafeInt<size_t> num_elements = SafeInt<size_t>(Product(dest_shape));
142-
143-
size_t element_size{0};
144-
switch (data_type) {
145-
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
146-
element_size = sizeof(uint8_t);
147-
break;
148-
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
149-
element_size = sizeof(int8_t);
150-
break;
151-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
152-
element_size = sizeof(uint16_t);
153-
break;
154-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
155-
element_size = sizeof(float);
156-
break;
157-
default:
158-
break;
159-
}
160-
std::unique_ptr<uint8_t[]> buffer_holder(new uint8_t[element_size * num_elements]);
161-
uint8_t* buffer = buffer_holder.get();
162-
163-
for (uint32_t out = 0; out < out_t; out++) {
164-
for (uint32_t in = 0; in < in_t; in++) {
165-
for (uint32_t h = 0; h < h_t; h++) {
166-
for (uint32_t w = 0; w < w_t; w++) {
167-
auto onnx_idx = out * in_t * h_t * w_t +
168-
in * h_t * w_t +
169-
h * w_t +
170-
w;
171-
172-
uint32_t nnapi_idx;
173-
if (is_conv == 1) { // L_0231
174-
nnapi_idx = out * h_t * w_t * in_t +
175-
h * w_t * in_t +
176-
w * in_t +
177-
in;
178-
} else { // L_1230 for depthwise conv weight
179-
nnapi_idx = in * h_t * w_t * out_t +
180-
h * w_t * out_t +
181-
w * out_t +
182-
out;
183-
}
184-
185-
for (size_t i = 0; i < element_size; i++) {
186-
buffer[element_size * nnapi_idx + i] = src[element_size * onnx_idx + i];
187-
}
188-
}
189-
}
190-
}
191-
}
192-
ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(name, buffer, num_elements * element_size,
193-
dest_shape, data_type));
194-
return Status::OK();
195-
}
196-
197103
// Add operator related.
198104

199105
Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
@@ -203,7 +109,6 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
203109
const auto& op_type = node.OpType();
204110
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
205111
emscripten::val output = emscripten::val::object();
206-
const auto& initializers(model_builder.GetInitializerTensors());
207112

208113
std::vector<int64_t> input_shape;
209114
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
@@ -216,19 +121,11 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
216121
auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
217122
auto pads = helper.Get("pads", std::vector<int64_t>{0, 0, 0, 0});
218123

219-
const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC;
220124
const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3;
221-
const bool is_constant_weight = Contains(initializers, weight_name);
222125
// Support conv1d by prepending a 1 or 2 size dimensions.
223126
if (is_conv1d) {
224127
// Reshape input.
225-
if (is_nhwc) {
226-
// For NHWC preferred layout, the input has been transposed.
227-
// For conv1d it is NCD1 -> ND1C, so we need to prepend 1 to the index 2.
228-
input_shape.insert(input_shape.begin() + 2, 1);
229-
} else {
230-
input_shape.push_back(1);
231-
}
128+
input_shape.push_back(1);
232129
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(input_shape);
233130
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
234131

@@ -244,63 +141,19 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
244141
emscripten::val options = emscripten::val::object();
245142
options.set("label", node.Name());
246143
ORT_RETURN_IF_ERROR(SetConvBaseOptions(
247-
model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger));
248-
bool depthwise = false;
249-
if (op_type == "Conv" || op_type == "ConvInteger") {
250-
int groups = options["groups"].as<int>();
251-
if (is_nhwc) {
252-
depthwise = (groups == input_shape[3] && groups != 1);
253-
options.set("inputLayout", emscripten::val("nhwc"));
254-
if (is_constant_weight) {
255-
ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d));
256-
}
257-
if (!depthwise) {
258-
options.set("filterLayout", emscripten::val("ohwi"));
259-
} else {
260-
options.set("filterLayout", emscripten::val("ihwo"));
261-
}
262-
}
263-
} else { // ConvTranspose
264-
if (is_nhwc) {
265-
options.set("inputLayout", emscripten::val("nhwc"));
266-
options.set("filterLayout", emscripten::val("ohwi"));
267-
if (is_constant_weight) {
268-
ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, true, is_conv1d));
269-
}
270-
}
271-
}
272-
144+
model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_conv1d, logger));
273145
emscripten::val filter = model_builder.GetOperand(weight_name);
274146

275147
if (is_conv1d) {
276148
// Reshape weight to 4D for conv1d.
277-
if (!is_nhwc || !is_constant_weight) {
278-
// The weight_shape has been appended 1's, reshape weight operand.
279-
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(weight_shape);
280-
emscripten::val reshape_options = emscripten::val::object();
281-
reshape_options.set("label", node.Name() + "_reshape_filter");
282-
filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
283-
filter,
284-
emscripten::val::array(new_shape),
285-
reshape_options);
286-
}
287-
}
288-
289-
emscripten::val transpose_options = emscripten::val::object();
290-
if (is_nhwc && !is_constant_weight) {
291-
// For NHWC preferred layout, if the weight is input:
292-
// - Transpose it from iohw -> ohwi for convTranspose.
293-
// - Transpose it from oihw -> ihwo for depthwise conv.
294-
// - Transpose it from oihw -> ohwi for conv.
295-
std::vector<uint32_t> perm(4);
296-
if (op_type == "ConvTranspose" || depthwise) {
297-
perm = {1, 2, 3, 0}; // L_1230 for depthwise conv and convTranspose weight
298-
} else {
299-
perm = {0, 2, 3, 1}; // L_0231
300-
}
301-
transpose_options.set("permutation", emscripten::val::array(perm));
302-
transpose_options.set("label", node.Name() + "_transpose_filter");
303-
filter = model_builder.GetBuilder().call<emscripten::val>("transpose", filter, transpose_options);
149+
// The weight_shape has been appended 1's, reshape weight operand.
150+
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(weight_shape);
151+
emscripten::val reshape_options = emscripten::val::object();
152+
reshape_options.set("label", node.Name() + "_reshape_filter");
153+
filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
154+
filter,
155+
emscripten::val::array(new_shape),
156+
reshape_options);
304157
}
305158

306159
if (op_type == "Conv") {

onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc

+1-8
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
7979
ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
8080
emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
8181
emscripten::val variance = model_builder.GetOperand(input_defs[4]->Name());
82-
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
83-
options.set("axis", rank - 1);
84-
}
8582

8683
output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
8784
} else if (op_type == "LayerNormalization") {
@@ -104,9 +101,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
104101
std::back_inserter(new_shape),
105102
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
106103

107-
size_t insertion_offset = (model_builder.GetPreferredLayout() == DataLayout::NHWC) ? 2 : 3;
108104
ptrdiff_t excess_rank = new_shape.size() - webnn_shape_rank;
109-
auto insertion_point = new_shape.begin() + insertion_offset;
105+
auto insertion_point = new_shape.begin() + 3;
110106
if (input_shape.size() < webnn_shape_rank) {
111107
// Pad the shape with extra 1's to satisfy WebNN v1's rank requirements.
112108
new_shape.insert(insertion_point, -excess_rank, 1);
@@ -125,9 +121,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
125121
reshape_input_options);
126122
}
127123

128-
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
129-
options.set("layout", emscripten::val("nhwc"));
130-
}
131124
output = model_builder.GetBuilder().call<emscripten::val>("instanceNormalization", input, options);
132125
// Reshape back to the original output shape for 3D input.
133126
if (input_shape.size() != 4) {

onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc

+2-7
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
7070
options.set("strides", emscripten::val::array(strides));
7171
const auto dilations = helper.Get("dilations", std::vector<int32_t>{1, 1});
7272
options.set("dilations", emscripten::val::array(dilations));
73-
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
74-
options.set("layout", emscripten::val("nhwc"));
75-
} else {
76-
options.set("layout", emscripten::val("nchw"));
77-
}
73+
options.set("layout", emscripten::val("nchw"));
7874

7975
// Add Padding.
8076
// Usually using autopadding is more efficient than using explicit padding.
@@ -93,8 +89,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
9389
helper.Get("strides", std::vector<int64_t>{1, 1}),
9490
helper.Get("dilations", std::vector<int64_t>{1, 1}),
9591
auto_pad_type,
96-
pads_out,
97-
model_builder.GetPreferredLayout() == DataLayout::NCHW));
92+
pads_out));
9893
pads = GetVecUint32FromVecInt64(pads_out);
9994
}
10095
// Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width],

0 commit comments

Comments (0)