
Commit cb31420

[XNNPACK] Serialize fp16 weights as fp16
1 parent ce74f8e commit cb31420

4 files changed: +22 −21 lines

Diff for: backends/xnnpack/operators/node_visitor.py

+9 −11

@@ -210,7 +210,7 @@ def get_serialized_dtype(
        self,
        quant_params: Optional[QuantParams],
        node: torch.fx.Node,
-       fp32_static_weight: bool = False,
+       force_fp32: bool = False,
    ) -> XNNDatatype:
        # Default initialization
        dtype = XNNDatatype.xnn_datatype_fp32
@@ -267,7 +267,7 @@ def get_per_channel_dtype(
        if node_dtype is not None and node_dtype == torch.float16:
            dtype = (
                XNNDatatype.xnn_datatype_fp32
-               if fp32_static_weight
+               if force_fp32
                else XNNDatatype.xnn_datatype_fp16
            )

@@ -348,7 +348,7 @@ def define_tensor( # noqa: C901
        convert_to_nhwc: bool = False,
        swap_in_out_for_weights: bool = False,
        quant_params: Optional[QuantParams] = None,
-       fp32_static_weights: bool = False,
+       force_fp32: bool = False,
        groups: int = 1,
    ) -> None:
        """
@@ -368,7 +368,7 @@ def define_tensor( # noqa: C901
                constant data. If used along with convert_to_nhwc, this
                swap will happen before converting to nhwc.
            quant_params: Quantization meta data for this tensor, None if it is not quantized
-           fp32_static_weights: XNN_FLAG_FP32_STATIC_WEIGHTS for fp16 conv
+           force_fp32: forces tensor to be serialize as fp32, used for bias of dynamically quantized ops
            groups: number of groups for swap_in_out_for_weights
        """

@@ -405,7 +405,7 @@ def define_tensor( # noqa: C901
            convert_to_nhwc,
            swap_in_out_for_weights,
            quant_params,
-           fp32_static_weights,
+           force_fp32,
            groups,
        )

@@ -417,9 +417,7 @@ def define_tensor( # noqa: C901
            check_or_raise(len(dims) == 4, "Converting to nhwc requires 4d tensor")
            dims = [dims[i] for i in PERM_NCHW_TO_NHWC]

-       dtype = self.get_serialized_dtype(
-           quant_params, tensor, fp32_static_weight=fp32_static_weights
-       )
+       dtype = self.get_serialized_dtype(quant_params, tensor, force_fp32=force_fp32)

        tvalue = XNNTensorValue(
            datatype=dtype,
@@ -504,7 +502,7 @@ def get_serialized_buffer_index(
        convert_to_nhwc: bool,
        swap_in_out_for_weights: bool,
        quant_params: Optional[QuantParams],
-       fp32_static_weights: bool = False,
+       force_fp32: bool = False,
        groups: int = 1,
    ) -> int:
        """
@@ -525,7 +523,7 @@ def get_serialized_buffer_index(
                constant data. If used along with convert_to_nhwc, this
                swap will happen before converting to nhwc.
            quant_params: Quantization meta data for this tensor, None if it is not quantize
-           fp32_static_weights: bool to indicate whether tensor is fp32 static weights
+           force_fp32: bool to indicate whether tensor is fp32 static weights
            groups: groups for swap_in_out_for_weights

        Returns:
@@ -554,7 +552,7 @@ def get_serialized_buffer_index(
        # Quantize buffer if static data is indeed quantized
        if quant_params is not None and not quant_params.is_dynamic:
            const_val = quant_params.quantize_tensor(const_val).contiguous()
-       elif const_val.dtype != torch.float16 or fp32_static_weights:
+       elif const_val.dtype != torch.float16 or force_fp32:
            # ensure that the const is fp32
            const_val = const_val.to(dtype=torch.float32).contiguous()

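Note: the following is a minimal standalone sketch, not code from this commit, of the serialization rule the node_visitor.py change puts in place (the helper name serialized_const is invented for illustration; the statically quantized branch is handled separately and is unchanged): an fp16 constant is now kept as fp16 unless the caller passes force_fp32=True.

import torch

def serialized_const(const_val: torch.Tensor, force_fp32: bool = False) -> torch.Tensor:
    # Mirrors the elif branch in get_serialized_buffer_index above: anything that
    # is not fp16, or anything explicitly forced, is upcast to fp32 before
    # serialization; fp16 data is left as fp16.
    if const_val.dtype != torch.float16 or force_fp32:
        const_val = const_val.to(dtype=torch.float32).contiguous()
    return const_val

weight = torch.randn(8, 4).to(torch.float16)
assert serialized_const(weight).dtype == torch.float16                    # fp16 stays fp16
assert serialized_const(weight, force_fp32=True).dtype == torch.float32   # forced upcast
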
Diff for: backends/xnnpack/operators/op_conv2d.py

+3 −3

@@ -82,7 +82,6 @@ def define_node(
        weight_quant_params = QuantParams.from_weights(
            kernel_node, self._exported_program
        )
-       fp32_static_weights = kernel_node.meta["val"].dtype == torch.float16

        if weight_quant_params is not None and weight_quant_params.per_channel:
            if is_transpose:
@@ -102,8 +101,8 @@ def define_node(
            convert_to_nhwc=True,
            swap_in_out_for_weights=is_depthwise_conv or is_transpose,
            quant_params=weight_quant_params,
-           fp32_static_weights=fp32_static_weights,
            groups=groups if is_transpose else 1,
+           force_fp32=True,
        )
        kwargs["filter_id"] = vals_to_ids[get_input_node(node, 1)]

@@ -127,13 +126,14 @@ def define_node(
            bias_quant_params = QuantParams.from_bias(
                bias_node, weight_quant_params, input_quant_params
            )
+
            self.define_tensor(
                get_input_node(node, 2),
                xnn_graph,
                vals_to_ids,
                convert_to_nhwc=False,
                quant_params=bias_quant_params,
-               fp32_static_weights=fp32_static_weights,
+               force_fp32=True,
            )
            kwargs["bias_id"] = vals_to_ids[get_input_node(node, 2)]

Diff for: backends/xnnpack/operators/op_linear.py

+7 −2

@@ -59,7 +59,6 @@ def define_node(
            xnn_graph,
            vals_to_ids,
            quant_params=weight_quant_params,
-           fp32_static_weights=True,
        )
        filter_id = vals_to_ids[weight_node]

@@ -69,12 +68,18 @@ def define_node(
            bias_quant_params = QuantParams.from_bias(
                bias_node, weight_quant_params, input_quant_params
            )
+           # For dynamic quantization, there are no kernels with fp16 bias
+           # So we need to force the fp16 bias to fp32
+           force_fp32 = False
+           if input_quant_params is not None and input_quant_params.is_dynamic:
+               force_fp32 = True
+
            self.define_tensor(
                get_input_node(node, 2),
                xnn_graph,
                vals_to_ids,
                quant_params=bias_quant_params,
-               fp32_static_weights=True,
+               force_fp32=force_fp32,
            )
            bias_id = vals_to_ids[bias_node]
        else:
Diff for: backends/xnnpack/test/ops/test_linear.py

+3 −5

@@ -605,9 +605,7 @@ def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.flo

        if legacy_partitioner:
            tester.to_edge()
-           tester.partition(
-               Partition(DynamicallyQuantizedPartitioner)
-           ).dump_artifact()
+           tester.partition(Partition(DynamicallyQuantizedPartitioner))
            # should have [add]mm node
            if uses_bias:
                tester.check(
@@ -624,7 +622,7 @@ def _test_qd8_linear_per_tensor_unsupported(self, dtype: torch.dtype = torch.flo
        else:
            tester.to_edge_transform_and_lower(
                ToEdgeTransformAndLower([DynamicallyQuantizedPartitioner])
-           ).dump_artifact()
+           )
            # should not have a delegate node
            tester.check_not(
                [
@@ -717,7 +715,7 @@ def test_fp16_linear(self):
                    num_batch_dims=num_batch_dims,
                    uses_bias=use_bias,
                    dtype=torch.float16,
-                   atol=5e-2,  # TODO(T212995726): Investigate right atol for rand[n] inputs
+                   atol=5e-3,  # TODO(T212995726): Investigate right atol for rand[n] inputs
                )

    def test_fp32_linear(self):
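Note: a small illustrative check of what the tightened tolerance corresponds to (the tensors below are placeholders, not the test's actual model outputs): with fp16 weights serialized as fp16, the fp16 linear test now compares the delegate output against the eager reference at atol=5e-3 rather than 5e-2.

import torch

atol = 5e-3  # tightened from 5e-2 in this commit
reference = torch.randn(4, 16).to(torch.float16).float()
delegate_output = reference + 1e-3  # placeholder for the lowered module's output
assert torch.allclose(delegate_output, reference, atol=atol)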
