Commit 17b9ce3

unbreak float8 static quant tutorial (#1709)
* Update [ghstack-poisoned]
1 parent 413689d commit 17b9ce3

File tree

2 files changed: +3 -1 lines changed


torchao/dtypes/floatx/float8_layout.py (+1)

@@ -253,6 +253,7 @@ def _linear_fp8_act_fp8_weight_impl(
 ):
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
     scaled_mm_config = weight_tensor._layout.mm_config
+    assert scaled_mm_config is not None
     out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
 
     # Weight tensor preprocessing
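For context, the new assert guards the config before its fields are read to drive the scaled matmul. Below is a minimal sketch of that pattern; the helper name fp8_linear_sketch, the out_dtype choice, and the exact wiring inside _linear_fp8_act_fp8_weight_impl are illustrative assumptions, not copied from the file:

import torch

def fp8_linear_sketch(a_fp8, a_scale, b_fp8, b_scale, mm_config):
    # Fail fast: a None config would otherwise surface later as a
    # confusing AttributeError when its fields are accessed.
    assert mm_config is not None, "Float8Layout.mm_config must be set"
    # torch._scaled_mm multiplies the fp8 operands and applies the
    # per-tensor scales, accumulating in higher precision.
    # (b_fp8 is typically expected in column-major layout.)
    return torch._scaled_mm(
        a_fp8,
        b_fp8,
        scale_a=a_scale,
        scale_b=b_scale,
        out_dtype=torch.bfloat16,
        use_fast_accum=mm_config.use_fast_accum,
    )

Asserting at the top turns a silent misconfiguration into an immediate, attributable failure, which is exactly the contract the tutorial fix below now satisfies.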

tutorials/calibration_flow/static_quant.py (+2 -1)

@@ -163,12 +163,13 @@ def __init__(
                 weight, weight_scale, weight_zero_point, block_size, self.target_dtype
             )
         elif self.target_dtype == torch.float8_e4m3fn:
+            mm_config = Float8MMConfig(use_fast_accum=True)
             self.qweight = to_affine_quantized_floatx_static(
                 weight,
                 weight_scale,
                 block_size,
                 target_dtype,
-                Float8Layout(mm_config=None),
+                Float8Layout(mm_config=mm_config),
             )
         else:
             raise ValueError(f"Unsupported target dtype {self.target_dtype}")
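Taken together, the two changes mean the tutorial now constructs a real Float8MMConfig instead of passing None, which the kernel-side assert above would reject. A hedged sketch of the resulting tutorial-side pattern; the import paths, sample shapes, and the per-tensor scale computation are assumptions based on the diff, not copied from the tutorial:

import torch
from torchao.dtypes import Float8Layout, to_affine_quantized_floatx_static
from torchao.float8.inference import Float8MMConfig

# The fix: build a concrete config rather than passing mm_config=None.
mm_config = Float8MMConfig(use_fast_accum=True)

weight = torch.randn(64, 64, dtype=torch.bfloat16)
# Per-tensor scale: one scale covers the whole tensor, so block_size
# equals the full weight shape.
weight_scale = weight.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
block_size = weight.shape

qweight = to_affine_quantized_floatx_static(
    weight,
    weight_scale,
    block_size,
    torch.float8_e4m3fn,
    Float8Layout(mm_config=mm_config),
)

use_fast_accum=True opts into the faster accumulation mode of the underlying scaled matmul, a common choice for inference.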
