diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index c7f32cd3fa..d29bafb8b1 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -146,6 +146,20 @@ class Float8GemmConfig:
     use_fast_accum: bool = False
 
 
+@dataclass(frozen=True)
+class Float8ScalingFactorConfig:
+    """
+    Configuration for the scaling factor used in float8 quantization.
+    """
+
+    # If this option is enabled, the scaling factor used for float8 quantization
+    # will be rounded down to the nearest power of 2. This has been shown to help
+    # reduce quantization error by avoiding rounding errors when multiplying/dividing
+    # by the scaling factor, as well as ensuring large values are quantized to the
+    # same value in the forward pass as in the backward pass.
+    power_of_2_scale: bool = False
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -234,6 +248,9 @@ class Float8LinearConfig:
     # tests so that the warning does not spam the CI stdout.
     force_recompute_fp8_weight_in_bwd: bool = False
 
+    # configuration used for calculating the scaling factor for float8 quantization.
+    scaling_factor_config: Float8ScalingFactorConfig = None
+
     def __post_init__(self):
         # Populate the additional cast overrides, if the user did not specify them
         # Note: this hacks around the frozen-ness of this dataclass
@@ -258,7 +275,9 @@ def __post_init__(self):
 
         # float8 all-gather only supports tensorwise, in the future may support blockwise
         if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE:
-            assert not self.enable_fsdp_float8_all_gather, f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            assert (
+                not self.enable_fsdp_float8_all_gather
+            ), f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
 
         # save some characters in the compatibility checks below
         cc_i = self.cast_config_input
diff --git a/torchao/float8/float8_scaling_utils.py b/torchao/float8/float8_scaling_utils.py
index 0c27e4f3fc..8a1b3af387 100644
--- a/torchao/float8/float8_scaling_utils.py
+++ b/torchao/float8/float8_scaling_utils.py
@@ -12,7 +12,7 @@
 
 import torch
 
-from torchao.float8.config import ScalingGranularity
+from torchao.float8.config import Float8ScalingFactorConfig, ScalingGranularity
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_tensor import (
     Float8Tensor,
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    scaling_factor_config: Float8ScalingFactorConfig = None,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -51,6 +52,10 @@ def hp_tensor_to_float8_dynamic(
           the 3 fwd/bwd gemms of linear
         scaling_granularity: Defines the scaling granularity
         axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+        scaling_factor_config: optional configuration used to calculate the scaling factor.
+            * for row-wise scaling granularity, a power of 2 scaling factor is used by default,
+              but can be disabled via this config.
+            * for all other scaling granularities, power of 2 scaling factors are not used by default.
     """
     scale = tensor_to_scale(
         hp_tensor,
@@ -60,6 +65,11 @@ def hp_tensor_to_float8_dynamic(
         scaling_granularity,
         axiswise_dim,
     )
+
+    if _use_power_of_2_scale(scaling_granularity, scaling_factor_config):
+        # round the scaling factor down to the nearest power of 2
+        scale = torch.exp2(torch.floor(torch.log2(scale)))
+
     return hp_tensor_and_scale_to_float8(
         hp_tensor,
         scale,
@@ -70,6 +80,36 @@ def hp_tensor_to_float8_dynamic(
     )
 
 
+def _use_power_of_2_scale(
+    scaling_granularity: ScalingGranularity,
+    scaling_factor_config: Float8ScalingFactorConfig = None,
+) -> bool:
+    """
+    Returns a boolean indicating whether the scaling factor should be rounded
+    down to the nearest power of 2.
+
+    Returns True in these cases:
+    1. The caller has explicitly enabled it in the scaling factor config.
+    2. Row-wise (axiswise) scaling is used and the caller has not explicitly
+       disabled it in the scaling factor config.
+
+    Otherwise, returns False.
+    """
+    power_of_2_scale_enabled = (
+        scaling_factor_config is not None
+        and scaling_factor_config.power_of_2_scale is True
+    )
+    power_of_2_scale_explicitly_disabled = (
+        scaling_factor_config is not None
+        and scaling_factor_config.power_of_2_scale is False
+    )
+    use_power_of_2_scale = power_of_2_scale_enabled or (
+        scaling_granularity == ScalingGranularity.AXISWISE
+        and not power_of_2_scale_explicitly_disabled
+    )
+    return use_power_of_2_scale
+
+
 def hp_tensor_to_float8_delayed(
     hp_tensor: torch.Tensor,
     s: torch.Tensor,