diff --git a/torchao/dtypes/floatx/float8_layout.py b/torchao/dtypes/floatx/float8_layout.py
index 5a7e1924b3..656ebb61ae 100644
--- a/torchao/dtypes/floatx/float8_layout.py
+++ b/torchao/dtypes/floatx/float8_layout.py
@@ -253,6 +253,7 @@ def _linear_fp8_act_fp8_weight_impl(
 ):
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
     scaled_mm_config = weight_tensor._layout.mm_config
+    assert scaled_mm_config is not None
     out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)
 
     # Weight tensor preprocessing
diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py
index 4b7dfe405f..fd24a71189 100644
--- a/tutorials/calibration_flow/static_quant.py
+++ b/tutorials/calibration_flow/static_quant.py
@@ -163,12 +163,13 @@ def __init__(
                 weight, weight_scale, weight_zero_point, block_size, self.target_dtype
             )
         elif self.target_dtype == torch.float8_e4m3fn:
+            mm_config = Float8MMConfig(use_fast_accum=True)
             self.qweight = to_affine_quantized_floatx_static(
                 weight,
                 weight_scale,
                 block_size,
                 target_dtype,
-                Float8Layout(mm_config=None),
+                Float8Layout(mm_config=mm_config),
             )
         else:
             raise ValueError(f"Unsupported target dtype {self.target_dtype}")
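
Below is a minimal usage sketch of the changed tutorial path, showing why a non-None mm_config now matters: _linear_fp8_act_fp8_weight_impl asserts that the weight's layout carries a Float8MMConfig before dispatching to _scaled_mm. This is illustrative only, not part of the patch; the import paths for Float8MMConfig and to_affine_quantized_floatx_static are assumptions and may differ across torchao versions, and a real flow would derive the scale from calibration rather than from the weight itself.

    import torch
    from torchao.dtypes import Float8Layout, to_affine_quantized_floatx_static
    from torchao.float8.inference import Float8MMConfig  # assumed import path

    weight = torch.randn(128, 64, dtype=torch.bfloat16)
    # Per-tensor static scale; a calibration flow would compute this from data.
    weight_scale = weight.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
    block_size = weight.shape  # one block spanning the whole tensor

    # mm_config must not be None: the FP8 matmul impl now asserts it is set
    # before calling torch._scaled_mm.
    mm_config = Float8MMConfig(use_fast_accum=True)
    qweight = to_affine_quantized_floatx_static(
        weight,
        weight_scale,
        block_size,
        torch.float8_e4m3fn,
        Float8Layout(mm_config=mm_config),
    )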