Commit f2433b1

support power of 2 scaling factors in float8 training
1 parent 8afd10e · commit f2433b1

File tree

5 files changed: +44, -18 lines


test/float8/test_base.py (+9, -7)

@@ -15,9 +15,9 @@
 import torch.nn as nn
 
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -26,13 +26,13 @@
 
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
+    e5m2_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
+    recipe_name_to_linear_config,
     ScalingGranularity,
     ScalingType,
-    e4m3_dtype,
-    e5m2_dtype,
-    recipe_name_to_linear_config,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -48,15 +48,15 @@
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
-    hp_tensor_and_scale_to_float8,
 )
 from torchao.float8.float8_utils import (
-    FP8_TYPES,
     compute_error,
     config_has_stateful_scaling,
     fp8_tensor_statistics,
+    FP8_TYPES,
     tensor_to_scale,
 )
 from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
@@ -164,7 +164,8 @@ def test_transpose(self):
 
     @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
     @pytest.mark.parametrize("axiswise_dim", [0, -1])
-    def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
+    @pytest.mark.parametrize("power_of_2_scale", [True, False])
+    def test_axiswise_dynamic_cast(self, shape, axiswise_dim, power_of_2_scale):
         a = torch.randn(*shape, dtype=torch.bfloat16)
         linear_mm_config = LinearMMConfig()
         a_fp8 = hp_tensor_to_float8_dynamic(
@@ -173,6 +174,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
             linear_mm_config,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=axiswise_dim,
+            power_of_2_scale=power_of_2_scale,
         )
         a_dq = a_fp8.to_original_precision()
         sqnr = compute_error(a, a_dq)
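
Note: the new power_of_2_scale parametrization above exercises the dynamic axiswise cast end to end. A minimal standalone sketch of the same round trip, not part of the commit: import paths follow the test imports above and the commit's file layout (hp_tensor_to_float8_dynamic lives in torchao/float8/float8_scaling_utils.py), and the e4m3 dtype argument is an assumption, since that line of the call falls outside the hunk.

import torch

from torchao.float8.config import ScalingGranularity, e4m3_dtype
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig
from torchao.float8.float8_utils import compute_error

a = torch.randn(4, 8, 16, dtype=torch.bfloat16)

# dynamic axiswise cast, with each scale rounded down to a power of 2
a_fp8 = hp_tensor_to_float8_dynamic(
    a,
    e4m3_dtype,  # assumed: the dtype argument is not shown in the hunk
    LinearMMConfig(),
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
    power_of_2_scale=True,
)

# round-trip back to the original precision and measure SQNR, as the test does
a_dq = a_fp8.to_original_precision()
print(compute_error(a, a_dq))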

test/float8/test_compile.py (+15, -11)

@@ -13,9 +13,9 @@
 import pytest
 
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
     is_sm_at_least_89,
     is_sm_at_least_90,
+    TORCH_VERSION_AT_LEAST_2_5,
 )
 
 if not TORCH_VERSION_AT_LEAST_2_5:
@@ -29,11 +29,11 @@
 from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
 from torchao.float8.config import (
     CastConfig,
+    e4m3_dtype,
     Float8LinearConfig,
     Float8LinearRecipeName,
-    ScalingType,
-    e4m3_dtype,
     recipe_name_to_linear_config,
+    ScalingType,
 )
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
@@ -45,11 +45,7 @@
     hp_tensor_to_float8_delayed,
     hp_tensor_to_float8_dynamic,
 )
-from torchao.float8.float8_tensor import (
-    GemmInputRole,
-    LinearMMConfig,
-    ScaledMMConfig,
-)
+from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
 from torchao.float8.float8_utils import config_has_stateful_scaling
 from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -420,7 +416,14 @@ def test_sync_amax_func_cuda_graph_success():
         torch.float16,
     ],
 )
-def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
+@pytest.mark.parametrize(
+    "power_of_2_scale",
+    [
+        True,
+        False,
+    ],
+)
+def test_dynamic_scale_numeric_parity(dtype: torch.dtype, power_of_2_scale: bool):
     scaling_type_weight = ScalingType.DYNAMIC
     torch.manual_seed(42)
     hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
@@ -456,13 +459,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
+        power_of_2_scale=power_of_2_scale,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
         hp_tensor2,
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
+        power_of_2_scale=power_of_2_scale,
     )
     assert torch.equal(float8_eager._scale, float8_compile._scale)
     assert torch.equal(float8_eager._data, float8_compile._data)
@@ -474,8 +479,7 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
 )
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
 def test_delayed_scaling_pattern_replacement(dtype: torch.dtype):
-    from torch._inductor import config as inductor_config
-    from torch._inductor import metrics
+    from torch._inductor import config as inductor_config, metrics
 
     inductor_config.loop_ordering_after_fusion = True
 
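
Note: the parity test above boils down to checking that torch.compile produces bit-identical scales and float8 data when power of 2 rounding is enabled. A condensed sketch of that check, not part of the commit (CUDA device assumed, as in the test; names follow the test file):

import torch

from torchao.float8.config import e4m3_dtype
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig

torch.manual_seed(42)
hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
hp_tensor2 = hp_tensor1.detach().clone()
linear_mm_config = LinearMMConfig()

# eager cast with power of 2 scale rounding
float8_eager = hp_tensor_to_float8_dynamic(
    hp_tensor1,
    e4m3_dtype,
    linear_mm_config,
    gemm_input_role=GemmInputRole.WEIGHT,
    power_of_2_scale=True,
)

# compiled cast over an identical copy of the input
torch._dynamo.reset()
float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
    hp_tensor2,
    e4m3_dtype,
    linear_mm_config,
    gemm_input_role=GemmInputRole.WEIGHT,
    power_of_2_scale=True,
)

# eager and compile should agree bitwise on both the scale and the data
assert torch.equal(float8_eager._scale, float8_compile._scale)
assert torch.equal(float8_eager._data, float8_compile._data)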
torchao/float8/config.py (+9)

@@ -234,6 +234,13 @@ class Float8LinearConfig:
     # tests so that the warning does not spam the CI stdout.
     force_recompute_fp8_weight_in_bwd: bool = False
 
+    # If this option is enabled, the scaling factor used for float8 quantization
+    # will be rounded down to the nearest power of 2. This has been shown to help
+    # reduce quantization error by avoiding rounding errors when multiplying/dividing
+    # by the scaling factor, as well as ensuring large values are quantized to the
+    # same value in the forward pass as the backward passes.
+    power_of_2_scale: bool = False
+
     def __post_init__(self):
         # Populate the additional cast overrides, if the user did not specify them
         # Note: this hacks around the frozen-ness of this dataclass
@@ -336,6 +343,8 @@ def recipe_name_to_linear_config(
             cast_config_input=cc_i,
             cast_config_weight=cc_w,
             cast_config_grad_output=cc_go,
+            # enable power of 2 scaling factors by default for row-wise scaling
+            power_of_2_scale=True,
         )
 
     elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
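
Note: with this change there are two ways to get power of 2 scales: set the field explicitly on a Float8LinearConfig, or use the row-wise scaling recipe, which now enables it by default. A short sketch, not part of the commit; the recipe name ALL_AXISWISE is an assumption here, since the hunk above does not show which recipe branch the new lines belong to. The resulting config is then consumed by Float8Linear, as shown in the next file.

from torchao.float8.config import (
    Float8LinearConfig,
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# explicit opt-in on a hand-built config (the field defaults to False)
config = Float8LinearConfig(power_of_2_scale=True)

# the row-wise (axiswise) recipe turns the flag on by default
# (recipe name assumed to be ALL_AXISWISE)
rowwise_config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
assert rowwise_config.power_of_2_scale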

torchao/float8/float8_linear.py (+6)

@@ -96,6 +96,7 @@ def forward(
                 axiswise_dim=get_maybe_axiswise_dim(
                     -1, c.cast_config_input.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         if tensor_already_casted_to_fp8(weight_hp_t):
@@ -112,6 +113,7 @@ def forward(
                 axiswise_dim=get_maybe_axiswise_dim(
                     0, c.cast_config_weight.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         # the reshapes are needed in order to make the shapes compatible with
@@ -151,6 +153,7 @@ def backward(ctx, grad_output):
                 axiswise_dim=get_maybe_axiswise_dim(
                     -1, c.cast_config_grad_output.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         if tensor_already_casted_to_fp8(weight_hp_t):
@@ -181,6 +184,7 @@ def backward(ctx, grad_output):
                 axiswise_dim=get_maybe_axiswise_dim(
                     -1, c.cast_config_weight_for_grad_input.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         grad_input = torch.mm(
@@ -216,6 +220,7 @@ def backward(ctx, grad_output):
                 axiswise_dim=get_maybe_axiswise_dim(
                     0, c.cast_config_grad_output_for_grad_weight.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         if tensor_already_casted_to_fp8(input_hp_reshaped):
@@ -233,6 +238,7 @@ def backward(ctx, grad_output):
                 axiswise_dim=get_maybe_axiswise_dim(
                     0, c.cast_config_input_for_grad_weight.scaling_granularity
                 ),
+                power_of_2_scale=c.power_of_2_scale,
             )
 
         grad_weight = torch.mm(
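
Note: every cast site above reads the same c.power_of_2_scale from the layer's Float8LinearConfig, so enabling the flag once applies the rounding to the input, weight, and grad_output casts in both forward and backward. A hedged end-to-end sketch, not part of the commit, assuming the usual torchao.float8.convert_to_float8_training entry point and a CUDA GPU with float8 support:

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearConfig

# one flag on the config covers all six cast sites touched above
config = Float8LinearConfig(power_of_2_scale=True)

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 32))
model = model.to(device="cuda", dtype=torch.bfloat16)
convert_to_float8_training(model, config=config)

# forward and backward both go through the power of 2 scaled casts
x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()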

torchao/float8/float8_scaling_utils.py (+5)

@@ -36,6 +36,7 @@ def hp_tensor_to_float8_dynamic(
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    power_of_2_scale: bool = False,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -51,6 +52,7 @@ def hp_tensor_to_float8_dynamic(
             the 3 fwd/bwd gemms of linear
         scaling_granularity: Defines the scaling granularity
         axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+        power_of_2_scale: if true, round scaling factor down to the nearest power of 2.
     """
     scale = tensor_to_scale(
         hp_tensor,
@@ -60,6 +62,9 @@ def hp_tensor_to_float8_dynamic(
         scaling_granularity,
         axiswise_dim,
     )
+    if power_of_2_scale:
+        # rounds down to the nearest power of 2.
+        scale = torch.exp2(torch.floor(torch.log2(scale)))
     return hp_tensor_and_scale_to_float8(
         hp_tensor,
         scale,
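
Note: the new branch above is the whole mechanism: take the floor of the base 2 log of the scale, then exponentiate back. A self-contained snippet, not part of the commit, showing the effect on a few sample scale values (plain PyTorch, no torchao imports needed):

import torch

def round_scale_down_to_power_of_2(scale: torch.Tensor) -> torch.Tensor:
    # same expression as in hp_tensor_to_float8_dynamic above
    return torch.exp2(torch.floor(torch.log2(scale)))

scales = torch.tensor([0.3, 1.0, 1.75, 300.0])
print(round_scale_down_to_power_of_2(scales))
# prints roughly: tensor([  0.2500,   1.0000,   1.0000, 256.0000])

Rounding down rather than to the nearest power of 2 means the scale can only shrink, so the adjustment cannot push scaled values past the float8 maximum; and a power of 2 scale makes multiplying and dividing by it exact in floating point, which is the rounding error the config comment above refers to.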
