Commit 28fe6bd

[XNNPACK] Serialize fp16 weights as fp16
1 parent ce74f8e commit 28fe6bd

4 files changed, +26 -17 lines changed

Diff for: backends/xnnpack/operators/node_visitor.py (+9 -9)

@@ -210,7 +210,7 @@ def get_serialized_dtype(
         self,
         quant_params: Optional[QuantParams],
         node: torch.fx.Node,
-        fp32_static_weight: bool = False,
+        force_fp32: bool = False,
     ) -> XNNDatatype:
         # Default initialization
         dtype = XNNDatatype.xnn_datatype_fp32

@@ -267,7 +267,7 @@ def get_per_channel_dtype(
         if node_dtype is not None and node_dtype == torch.float16:
             dtype = (
                 XNNDatatype.xnn_datatype_fp32
-                if fp32_static_weight
+                if force_fp32
                 else XNNDatatype.xnn_datatype_fp16
             )

@@ -348,7 +348,7 @@ def define_tensor( # noqa: C901
         convert_to_nhwc: bool = False,
         swap_in_out_for_weights: bool = False,
         quant_params: Optional[QuantParams] = None,
-        fp32_static_weights: bool = False,
+        force_fp32: bool = False,
         groups: int = 1,
     ) -> None:
         """

@@ -368,7 +368,7 @@ def define_tensor( # noqa: C901
                 constant data. If used along with convert_to_nhwc, this
                 swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantized
-            fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
+            force_fp32: XNN_FLAG_force_fp32 for fp16 conv
             groups: number of groups for swap_in_out_for_weights
         """

@@ -405,7 +405,7 @@ def define_tensor( # noqa: C901
             convert_to_nhwc,
             swap_in_out_for_weights,
             quant_params,
-            fp32_static_weights,
+            force_fp32,
             groups,
         )

@@ -418,7 +418,7 @@ def define_tensor( # noqa: C901
             dims = [dims[i] for i in PERM_NCHW_TO_NHWC]

         dtype = self.get_serialized_dtype(
-            quant_params, tensor, fp32_static_weight=fp32_static_weights
+            quant_params, tensor, force_fp32=force_fp32
         )

         tvalue = XNNTensorValue(

@@ -504,7 +504,7 @@ def get_serialized_buffer_index(
         convert_to_nhwc: bool,
         swap_in_out_for_weights: bool,
         quant_params: Optional[QuantParams],
-        fp32_static_weights: bool = False,
+        force_fp32: bool = False,
         groups: int = 1,
     ) -> int:
         """

@@ -525,7 +525,7 @@ def get_serialized_buffer_index(
                 constant data. If used along with convert_to_nhwc, this
                 swap will happen before converting to nhwc.
             quant_params: Quantization meta data for this tensor, None if it is not quantize
-            fp32_static_weights: bool to indicate whether tensor is fp32 static weights
+            force_fp32: bool to indicate whether tensor is fp32 static weights
             groups: groups for swap_in_out_for_weights

         Returns:

@@ -554,7 +554,7 @@ def get_serialized_buffer_index(
         # Quantize buffer if static data is indeed quantized
         if quant_params is not None and not quant_params.is_dynamic:
             const_val = quant_params.quantize_tensor(const_val).contiguous()
-        elif const_val.dtype != torch.float16 or fp32_static_weights:
+        elif const_val.dtype != torch.float16 or force_fp32:
             # ensure that the const is fp32
             const_val = const_val.to(dtype=torch.float32).contiguous()
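Note on the change above: renaming fp32_static_weight(s) to force_fp32 captures the new default, namely that an fp16 constant is serialized as fp16 unless the caller explicitly forces fp32. Below is a minimal, self-contained sketch of that decision; the enum is a stand-in rather than the generated XNNDatatype flatbuffer class, and the helper name is hypothetical.

# Minimal sketch of the dtype decision in get_serialized_dtype after this change.
# XNNDatatype below is a stand-in enum, not the generated flatbuffer schema class.
from enum import Enum

import torch


class XNNDatatype(Enum):
    xnn_datatype_fp32 = "fp32"
    xnn_datatype_fp16 = "fp16"


def serialized_dtype(node_dtype: torch.dtype, force_fp32: bool = False) -> XNNDatatype:
    # fp16 constants now stay fp16 by default; force_fp32 is the explicit escape hatch
    if node_dtype == torch.float16 and not force_fp32:
        return XNNDatatype.xnn_datatype_fp16
    return XNNDatatype.xnn_datatype_fp32


assert serialized_dtype(torch.float16) is XNNDatatype.xnn_datatype_fp16
assert serialized_dtype(torch.float16, force_fp32=True) is XNNDatatype.xnn_datatype_fp32
assert serialized_dtype(torch.float32) is XNNDatatype.xnn_datatype_fp32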
Diff for: backends/xnnpack/operators/op_conv2d.py (+7 -3)

@@ -82,7 +82,6 @@ def define_node(
         weight_quant_params = QuantParams.from_weights(
             kernel_node, self._exported_program
         )
-        fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16

         if weight_quant_params is not None and weight_quant_params.per_channel:
             if is_transpose:

@@ -102,7 +101,6 @@ def define_node(
             convert_to_nhwc=True,
             swap_in_out_for_weights=is_depthwise_conv or is_transpose,
             quant_params=weight_quant_params,
-            fp32_static_weights=fp32_static_weights,
             groups=groups if is_transpose else 1,
         )
         kwargs["filter_id"] = vals_to_ids[get_input_node(node, 1)]

@@ -127,13 +125,19 @@ def define_node(
             bias_quant_params = QuantParams.from_bias(
                 bias_node, weight_quant_params, input_quant_params
             )
+            # For dynamic quantization, there are no kernels with fp16 bias
+            # So we need to force the fp16 bias to fp32
+            force_fp32 = False
+            if input_quant_params is not None and input_quant_params.is_dynamic:
+                force_fp32 = True
+
             self.define_tensor(
                 get_input_node(node, 2),
                 xnn_graph,
                 vals_to_ids,
                 convert_to_nhwc=False,
                 quant_params=bias_quant_params,
-                fp32_static_weights=fp32_static_weights,
+                force_fp32=force_fp32,
             )
             kwargs["bias_id"] = vals_to_ids[get_input_node(node, 2)]
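The conv change above drops the old rule that any fp16 kernel forced fp32 serialization of the weights and bias; now only a dynamically quantized input forces the fp16 bias up to fp32, because (per the added comment) there are no dynamic-quantization kernels with fp16 bias. A sketch of the new rule, using a minimal stand-in for QuantParams and a hypothetical helper name:

# Sketch of the force_fp32 decision applied to the conv (and linear) bias.
# QuantParams below is a minimal stand-in for the real XNNPACK QuantParams class.
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantParams:
    is_dynamic: bool = False


def bias_force_fp32(input_quant_params: Optional[QuantParams]) -> bool:
    # Only a dynamically quantized input forces the fp16 bias to be serialized as fp32
    return input_quant_params is not None and input_quant_params.is_dynamic


assert bias_force_fp32(None) is False                           # plain fp16 conv: bias stays fp16
assert bias_force_fp32(QuantParams(is_dynamic=False)) is False  # static quantization: no forcing
assert bias_force_fp32(QuantParams(is_dynamic=True)) is True    # dynamic quantization: bias as fp32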
Diff for: backends/xnnpack/operators/op_linear.py (+7 -2)

@@ -59,7 +59,6 @@ def define_node(
             xnn_graph,
             vals_to_ids,
             quant_params=weight_quant_params,
-            fp32_static_weights=True,
         )
         filter_id = vals_to_ids[weight_node]

@@ -69,12 +68,18 @@ def define_node(
             bias_quant_params = QuantParams.from_bias(
                 bias_node, weight_quant_params, input_quant_params
             )
+            # For dynamic quantization, there are no kernels with fp16 bias
+            # So we need to force the fp16 bias to fp32
+            force_fp32 = False
+            if input_quant_params is not None and input_quant_params.is_dynamic:
+                force_fp32 = True
+
             self.define_tensor(
                 get_input_node(node, 2),
                 xnn_graph,
                 vals_to_ids,
                 quant_params=bias_quant_params,
-                fp32_static_weights=True,
+                force_fp32=force_fp32,
             )
             bias_id = vals_to_ids[bias_node]
         else:

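op_linear.py follows the same pattern: the unconditional fp32_static_weights=True is gone for both the weight and the bias, so fp16 linear weights are written out as fp16 and only the dynamically quantized path pays the fp32 bias cost. For illustration, a sketch of the resulting buffer conversion, mirroring the elif in get_serialized_buffer_index (quantized branch omitted, helper name hypothetical):

# Sketch of the constant-buffer conversion after this change (quantized branch omitted).
import torch


def prepare_const_for_serialization(
    const_val: torch.Tensor, force_fp32: bool = False
) -> torch.Tensor:
    # fp16 data is kept as fp16 unless force_fp32 is set; other dtypes are converted to fp32
    if const_val.dtype != torch.float16 or force_fp32:
        const_val = const_val.to(dtype=torch.float32)
    return const_val.contiguous()


w = torch.randn(8, 4).to(torch.float16)
assert prepare_const_for_serialization(w).dtype == torch.float16   # half the payload of fp32
assert prepare_const_for_serialization(w, force_fp32=True).dtype == torch.float32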
Diff for: backends/xnnpack/test/ops/test_linear.py (+3 -3)

@@ -607,7 +607,7 @@ def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.flo
             tester.to_edge()
             tester.partition(
                 Partition(DynamicallyQuantizedPartitioner)
-            ).dump_artifact()
+            )
             # should have [add]mm node
             if uses_bias:
                 tester.check(

@@ -624,7 +624,7 @@ def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.flo
         else:
             tester.to_edge_transform_and_lower(
                 ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
-            ).dump_artifact()
+            )
             # should not have a delegate node
             tester.check_not(
                 [

@@ -717,7 +717,7 @@ def test_fp16_linear(self):
                 num_batch_dims=num_batch_dims,
                 uses_bias=use_bias,
                 dtype=torch.float16,
-                atol=5e-2, # TODO(T212995726): Investigate right atol for rand[n] inputs
+                atol=5e-3, # TODO(T212995726): Investigate right atol for rand[n] inputs
             )

     def test_fp32_linear(self):
