
Commit ad8061b

add support for power of 2 scaling in float8 training
1 parent 8afd10e commit ad8061b

2 files changed (+58, -1 lines)

torchao/float8/config.py (+17 lines)
@@ -146,6 +146,20 @@ class Float8GemmConfig:
     use_fast_accum: bool = False
 
 
+@dataclass(frozen=True)
+class Float8ScalingFactorConfig:
+    """
+    Configuration for the scaling factor used for float8 quantization.
+    """
+
+    # If this option is enabled, the scaling factor used for float8 quantization
+    # will be rounded down to the nearest power of 2. This has been shown to help
+    # reduce quantization error by avoiding rounding errors when multiplying/dividing
+    # by the scaling factor, as well as ensuring large values are quantized to the
+    # same value in the forward pass as the backward pass.
+    power_of_2_scale: bool = False
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -234,6 +248,9 @@ class Float8LinearConfig:
     # tests so that the warning does not spam the CI stdout.
     force_recompute_fp8_weight_in_bwd: bool = False
 
+    # Configuration for calculating the scaling factor used in float8 quantization.
+    scaling_factor_config: Float8ScalingFactorConfig = None
+
     def __post_init__(self):
         # Populate the additional cast overrides, if the user did not specify them
         # Note: this hacks around the frozen-ness of this dataclass
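Not part of the diff, but for orientation: a minimal sketch of how the new option is surfaced through the config, using only the names added in this commit (any other Float8LinearConfig settings a given recipe may need, e.g. row-wise cast configs, are omitted here):

from torchao.float8.config import Float8LinearConfig, Float8ScalingFactorConfig

# Explicitly opt in to power-of-2 scaling factors. The flag defaults to False on
# Float8ScalingFactorConfig, and scaling_factor_config is None on Float8LinearConfig
# unless the caller sets it.
config = Float8LinearConfig(
    scaling_factor_config=Float8ScalingFactorConfig(power_of_2_scale=True),
)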

torchao/float8/float8_scaling_utils.py (+41, -1 lines)
@@ -12,7 +12,7 @@
 
 import torch
 
-from torchao.float8.config import ScalingGranularity
+from torchao.float8.config import Float8ScalingFactorConfig, ScalingGranularity
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_tensor import (
     Float8Tensor,
@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    scaling_factor_config: Float8ScalingFactorConfig = None,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -51,6 +52,10 @@ def hp_tensor_to_float8_dynamic(
           the 3 fwd/bwd gemms of linear
         scaling_granularity: Defines the scaling granularity
         axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+        scaling_factor_config: optional configuration used to calculate the scaling factor.
+            * for row-wise scaling granularity, a power of 2 scaling factor will be used by default,
+              but can be disabled via this config.
+            * for all other scaling granularities, power of 2 scaling factors are not used by default.
     """
     scale = tensor_to_scale(
         hp_tensor,
@@ -60,6 +65,11 @@ def hp_tensor_to_float8_dynamic(
         scaling_granularity,
         axiswise_dim,
     )
+
+    if _use_power_of_2_scale(scaling_granularity, scaling_factor_config):
+        # this rounds the scaling factor down to the nearest power of 2.
+        scale = torch.exp2(torch.floor(torch.log2(scale)))
+
     return hp_tensor_and_scale_to_float8(
         hp_tensor,
         scale,
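As an aside (not part of the commit), a tiny standalone sketch of what torch.exp2(torch.floor(torch.log2(scale))) does, and why a power-of-2 scale helps: scaling by a power of 2 only shifts the floating-point exponent, so dividing by the scale and multiplying back round-trips exactly, barring overflow/underflow.

import torch

# round each scale down to the nearest power of 2
scale = torch.tensor([0.00073, 3.7, 100.0])
pow2_scale = torch.exp2(torch.floor(torch.log2(scale)))
# -> 2**-11, 2**1, 2**6 (i.e. 0.00048828125, 2.0, 64.0)

# multiplying/dividing by a power of 2 only changes the exponent bits,
# so the round trip is exact for values that stay within range
x = torch.randn(8, dtype=torch.bfloat16)
s = torch.tensor(2.0 ** -11, dtype=torch.bfloat16)
assert torch.equal((x / s) * s, x)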
@@ -70,6 +80,36 @@ def hp_tensor_to_float8_dynamic(
     )
 
 
+def _use_power_of_2_scale(
+    scaling_granularity: ScalingGranularity,
+    scaling_factor_config: Float8ScalingFactorConfig = None,
+) -> bool:
+    """
+    Returns a boolean indicating whether the scaling factor should be rounded
+    down to the nearest power of 2.
+
+    Returns True in these cases:
+    1. The caller has explicitly enabled it in the scaling factor config.
+    2. Row-wise scaling is used and the caller has not explicitly disabled
+       it in the scaling factor config.
+
+    Otherwise, returns False.
+    """
+    power_of_2_scale_enabled = (
+        scaling_factor_config is not None
+        and scaling_factor_config.power_of_2_scale is True
+    )
+    power_of_2_scale_explicitly_disabled = (
+        scaling_factor_config is not None
+        and scaling_factor_config.power_of_2_scale is False
+    )
+    use_power_of_2_scale = power_of_2_scale_enabled or (
+        scaling_granularity == ScalingGranularity.AXISWISE
+        and not power_of_2_scale_explicitly_disabled
+    )
+    return use_power_of_2_scale
+
+
 def hp_tensor_to_float8_delayed(
     hp_tensor: torch.Tensor,
     s: torch.Tensor,
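The effect of the helper above can be summarized as a few checks. All names below come from this commit; _use_power_of_2_scale is a private helper and is called here only for illustration:

from torchao.float8.config import Float8ScalingFactorConfig, ScalingGranularity
from torchao.float8.float8_scaling_utils import _use_power_of_2_scale

# on by default for row-wise (axiswise) scaling
assert _use_power_of_2_scale(ScalingGranularity.AXISWISE, None)
# but can be explicitly disabled
assert not _use_power_of_2_scale(
    ScalingGranularity.AXISWISE, Float8ScalingFactorConfig(power_of_2_scale=False)
)
# off by default for tensor-wise scaling
assert not _use_power_of_2_scale(ScalingGranularity.TENSORWISE, None)
# unless explicitly enabled
assert _use_power_of_2_scale(
    ScalingGranularity.TENSORWISE, Float8ScalingFactorConfig(power_of_2_scale=True)
)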
