
Commit 32a51ec

Support power of 2 scaling factors in float8 training and use e4m3 everywhere (#1670)
1 parent bae41d1 commit 32a51ec

File tree: 7 files changed, +145 −21 lines

test/float8/test_base.py

+5 −1

@@ -164,7 +164,10 @@ def test_transpose(self):
 
     @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
     @pytest.mark.parametrize("axiswise_dim", [0, -1])
-    def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
+    @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
+    def test_axiswise_dynamic_cast(
+        self, shape, axiswise_dim, round_scales_to_power_of_2
+    ):
         a = torch.randn(*shape, dtype=torch.bfloat16)
         linear_mm_config = LinearMMConfig()
         a_fp8 = hp_tensor_to_float8_dynamic(
@@ -173,6 +176,7 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
             linear_mm_config,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=axiswise_dim,
+            round_scales_to_power_of_2=round_scales_to_power_of_2,
         )
         a_dq = a_fp8.to_original_precision()
         sqnr = compute_error(a, a_dq)
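
As a quick illustration of what this test exercises, here is a standalone sketch (not part of the commit) that casts a tensor with axiswise scaling and compares round-trip SQNR with the new flag on and off. The import locations for hp_tensor_to_float8_dynamic and LinearMMConfig are inferred from the diffs in this commit, and the SQNR formula is computed inline rather than via the repo's test helper.

# Minimal sketch under the assumptions above; requires a CUDA device with fp8 support.
import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig

a = torch.randn(4, 8, 16, dtype=torch.bfloat16, device="cuda")

for round_scales in (False, True):
    a_fp8 = hp_tensor_to_float8_dynamic(
        a,
        torch.float8_e4m3fn,  # e4m3 everywhere, per this commit
        LinearMMConfig(),
        scaling_granularity=ScalingGranularity.AXISWISE,
        axiswise_dim=-1,
        round_scales_to_power_of_2=round_scales,
    )
    a_dq = a_fp8.to_original_precision()
    # SQNR in dB: 20 * log10(||a|| / ||a - a_dq||)
    sqnr = 20 * torch.log10(
        torch.linalg.norm(a.float()) / torch.linalg.norm(a.float() - a_dq.float())
    )
    print(f"round_scales_to_power_of_2={round_scales}: SQNR ~= {sqnr.item():.1f} dB")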

test/float8/test_compile.py

+14 −6

@@ -45,11 +45,7 @@
     hp_tensor_to_float8_delayed,
     hp_tensor_to_float8_dynamic,
 )
-from torchao.float8.float8_tensor import (
-    GemmInputRole,
-    LinearMMConfig,
-    ScaledMMConfig,
-)
+from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig, ScaledMMConfig
 from torchao.float8.float8_utils import config_has_stateful_scaling
 from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
@@ -420,13 +416,23 @@ def test_sync_amax_func_cuda_graph_success():
         torch.float16,
     ],
 )
-def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
+@pytest.mark.parametrize(
+    "round_scales_to_power_of_2",
+    [
+        True,
+        False,
+    ],
+)
+def test_dynamic_scale_numeric_parity(
+    dtype: torch.dtype, round_scales_to_power_of_2: bool
+):
     scaling_type_weight = ScalingType.DYNAMIC
     torch.manual_seed(42)
     hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
     hp_tensor2 = hp_tensor1.detach().clone()
     float8_config = Float8LinearConfig(
         cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        round_scales_to_power_of_2=round_scales_to_power_of_2,
     )
     linear_mm_config = LinearMMConfig(
         # output
@@ -456,13 +462,15 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
+        round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
         hp_tensor2,
         e4m3_dtype,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
+        round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
     )
     assert torch.equal(float8_eager._scale, float8_compile._scale)
     assert torch.equal(float8_eager._data, float8_compile._data)

test/float8/test_float8_utils.py

+65 (new file)

@@ -0,0 +1,65 @@
+import unittest
+
+import pytest
+import torch
+
+from torchao.float8.float8_utils import _round_scale_down_to_power_of_2
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
+if not TORCH_VERSION_AT_LEAST_2_5:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+# source for notable single-precision cases:
+# https://en.wikipedia.org/wiki/Single-precision_floating-point_format
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        # ("test_case_name", input, expected result)
+        ("one", 1.0, 1.0),
+        ("inf", float("inf"), float("inf")),
+        ("nan", float("nan"), float("nan")),
+        ("smallest positive subnormal number", 2**-126 * 2**-23, 2**-126 * 2**-23),
+        ("largest normal number", 2**127 * (2 - 2**-23), float("inf")),
+        ("smallest positive normal number", 2**-126, 2**-126),
+        ("largest number less than one", 1.0 - 2**-24, 0.5),
+        ("smallest number larger than one", 1.0 + 2**-23, 1.0),
+        # TODO(danielvegamyhre): debug why creating a tensor with largest
+        # subnormal value in CI env for pytorch 2.5.1 truncates the value to 0.
+        # ("largest subnormal number", [2**-126 * (1 - 2**-23), 1.1754943508222875e-38]),
+    ],
+)
+def test_round_scale_down_to_power_of_2_valid_inputs(
+    test_case: dict,
+):
+    test_case_name, input, expected_result = test_case
+    input_tensor, expected_tensor = (
+        torch.tensor(input, dtype=torch.float32).cuda(),
+        torch.tensor(expected_result, dtype=torch.float32).cuda(),
+    )
+    result = _round_scale_down_to_power_of_2(input_tensor)
+
+    assert (
+        torch.equal(result, expected_tensor)
+        or (result.isnan() and expected_tensor.isnan())
+    ), f"test: {test_case_name}, input: {input_tensor}, expected {expected_tensor}, but got {result}"
+
+
+@pytest.mark.parametrize(
+    "invalid_dtype",
+    [
+        torch.bfloat16,
+        torch.float16,
+        torch.float64,
+        torch.int8,
+        torch.uint8,
+        torch.int32,
+        torch.uint32,
+        torch.int64,
+    ],
+)
+def test_non_float32_input(invalid_dtype: torch.dtype):
+    non_float32_tensor = torch.tensor([3.0], dtype=invalid_dtype)
+    with pytest.raises(AssertionError, match="scale must be float32 tensor"):
+        _round_scale_down_to_power_of_2(non_float32_tensor)

torchao/float8/config.py

+18 −3

@@ -234,6 +234,13 @@ class Float8LinearConfig:
     # tests so that the warning does not spam the CI stdout.
     force_recompute_fp8_weight_in_bwd: bool = False
 
+    # If this option is enabled, the scaling factor used for float8 quantization
+    # will be rounded down to the nearest power of 2. This has been shown to help
+    # reduce quantization error by avoiding rounding errors when multiplying/dividing
+    # by the scaling factor, as well as ensuring large values are quantized to the
+    # same value in the forward pass as the backward passes.
+    round_scales_to_power_of_2: bool = False
+
     def __post_init__(self):
         # Populate the additional cast overrides, if the user did not specify them
         # Note: this hacks around the frozen-ness of this dataclass
@@ -338,14 +345,22 @@ def recipe_name_to_linear_config(
 
     elif recipe_name is Float8LinearRecipeName.ALL_AXISWISE:
         # dynamic axiswise scaling with the CUTLASS rowwise kernel
-        cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-        cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+        cc_i = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+        )
+        cc_w = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+        )
+        cc_go = CastConfig(
+            scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
+        )
 
         return Float8LinearConfig(
            cast_config_input=cc_i,
            cast_config_weight=cc_w,
            cast_config_grad_output=cc_go,
+            # enable power of 2 scaling factors by default for row-wise scaling
+            round_scales_to_power_of_2=True,
         )
 
     elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:
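
Usage note (not part of the diff): with this change, power-of-2 scale rounding can be requested directly on Float8LinearConfig, or picked up by default through the ALL_AXISWISE recipe. The sketch below assumes the conventional torchao.float8 entry point convert_to_float8_training and the import paths shown; treat it as illustrative rather than a definitive API reference.

import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training  # assumed entry point
from torchao.float8.config import (
    Float8LinearConfig,
    Float8LinearRecipeName,
    recipe_name_to_linear_config,
)

# Option 1: opt in explicitly on any config.
config = Float8LinearConfig(round_scales_to_power_of_2=True)

# Option 2: the rowwise (ALL_AXISWISE) recipe now enables it by default,
# and uses e4m3 for the input, weight, and grad_output casts (see diff above).
rowwise_config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_AXISWISE)
assert rowwise_config.round_scales_to_power_of_2

m = nn.Sequential(nn.Linear(256, 256, bias=False)).to("cuda", torch.bfloat16)
convert_to_float8_training(m, config=rowwise_config)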

torchao/float8/float8_linear.py

+6

@@ -96,6 +96,7 @@ def forward(
             axiswise_dim=get_maybe_axiswise_dim(
                 -1, c.cast_config_input.scaling_granularity
             ),
+            round_scales_to_power_of_2=c.round_scales_to_power_of_2,
         )
 
         if tensor_already_casted_to_fp8(weight_hp_t):
@@ -112,6 +113,7 @@ def forward(
             axiswise_dim=get_maybe_axiswise_dim(
                 0, c.cast_config_weight.scaling_granularity
             ),
+            round_scales_to_power_of_2=c.round_scales_to_power_of_2,
         )
 
         # the reshapes are needed in order to make the shapes compatible with
@@ -151,6 +153,7 @@ def backward(ctx, grad_output):
             axiswise_dim=get_maybe_axiswise_dim(
                 -1, c.cast_config_grad_output.scaling_granularity
             ),
+            round_scales_to_power_of_2=c.round_scales_to_power_of_2,
         )
 
         if tensor_already_casted_to_fp8(weight_hp_t):
@@ -181,6 +184,7 @@ def backward(ctx, grad_output):
             axiswise_dim=get_maybe_axiswise_dim(
                 -1, c.cast_config_weight_for_grad_input.scaling_granularity
             ),
+            round_scales_to_power_of_2=c.round_scales_to_power_of_2,
         )
 
         grad_input = torch.mm(
@@ -216,6 +220,7 @@ def backward(ctx, grad_output):
             axiswise_dim=get_maybe_axiswise_dim(
                 0, c.cast_config_grad_output_for_grad_weight.scaling_granularity
             ),
+            round_scales_to_power_of_2=c.round_scales_to_power_of_2,
         )
 
         if tensor_already_casted_to_fp8(input_hp_reshaped):
@@ -233,6 +238,7 @@ def backward(ctx, grad_output):
             axiswise_dim=get_maybe_axiswise_dim(
                 0, c.cast_config_input_for_grad_weight.scaling_granularity
             ),
+            round_scales_to_power_of_2=c.round_scales_to_power_of_2,
         )
 
         grad_weight = torch.mm(

torchao/float8/float8_scaling_utils.py

+4

@@ -27,6 +27,7 @@
 )
 
 
+# TODO(danielvegamyhre): refactor to accept Float8LinearConfig directly
 def hp_tensor_to_float8_dynamic(
     hp_tensor: torch.Tensor,
     float8_dtype: torch.dtype,
@@ -36,6 +37,7 @@ def hp_tensor_to_float8_dynamic(
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    round_scales_to_power_of_2: bool = False,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -51,6 +53,7 @@ def hp_tensor_to_float8_dynamic(
             the 3 fwd/bwd gemms of linear
         scaling_granularity: Defines the scaling granularity
         axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+        round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
     """
     scale = tensor_to_scale(
         hp_tensor,
@@ -59,6 +62,7 @@ def hp_tensor_to_float8_dynamic(
         device_mesh,
         scaling_granularity,
         axiswise_dim,
+        round_scales_to_power_of_2,
     )
     return hp_tensor_and_scale_to_float8(
         hp_tensor,
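
A brief sketch (same import-path assumptions as in the earlier example) of how the new argument propagates: when round_scales_to_power_of_2=True, the dynamically computed scale stored on the resulting Float8Tensor is an exact power of two.

import torch

from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
from torchao.float8.float8_tensor import LinearMMConfig

x = torch.randn(16, 16, dtype=torch.bfloat16, device="cuda")
x_fp8 = hp_tensor_to_float8_dynamic(
    x,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    round_scales_to_power_of_2=True,
)
# log2 of the scale should be an integer, i.e. the scale is 2**k for some integer k
log2_scale = torch.log2(x_fp8._scale)
assert torch.equal(log2_scale, torch.floor(log2_scale))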

torchao/float8/float8_utils.py

+33 −11

@@ -10,11 +10,7 @@
 import torch.distributed as dist
 from torch.distributed._functional_collectives import AsyncCollectiveTensor, all_reduce
 
-from torchao.float8.config import (
-    Float8LinearConfig,
-    ScalingGranularity,
-    ScalingType,
-)
+from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -33,21 +29,28 @@
 
 
 @torch.no_grad()
-def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype):
+def amax_to_scale(
+    amax: torch.Tensor,
+    float8_dtype: torch.dtype,
+    round_scales_to_power_of_2: bool = False,
+):
     """Converts the amax value of a tensor to the fp8 scale.
     Args:
         amax: The amax value of the tensor.
         float8_dtype: The float8 dtype.
+        round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
     """
     # torch.compile and eager show different numerics for 1.0 / float32,
     # upcast to float64 to ensure same numeric between compile and eager
     amax = amax.to(torch.float64)
     if float8_dtype in FP8_TYPES:
         res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
+        res = res.to(torch.float32)
     else:
         raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
-
-    return res.to(torch.float32)
+    if round_scales_to_power_of_2:
+        res = _round_scale_down_to_power_of_2(res)
+    return res
 
 
 @torch.no_grad()
@@ -119,21 +122,35 @@ def tensor_to_amax(
 
 @torch.no_grad()
 def tensor_to_scale(
-    x: torch.Tensor,
+    hp_tensor: torch.Tensor,
     float8_dtype: torch.dtype,
     reduce_amax: bool = False,
     device_mesh=None,
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
     axiswise_dim: Optional[int] = None,
+    round_scales_to_power_of_2: bool = False,
 ) -> torch.Tensor:
+    """
+    Compute scaling factor for the given high precision tensor.
+
+    Args:
+        hp_tensor: high precision tensor
+        float8_dtype: the float8 dtype to use
+        reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
+        scaling_granularity: Defines the scaling granularity
+        axiswise_dim: if axiswise granularity is used, defines the dim to scale across
+        round_scales_to_power_of_2: if true, round scaling factor down to the nearest power of 2.
+    """
     amax = tensor_to_amax(
-        x,
+        hp_tensor,
         reduce_amax,
         device_mesh,
         scaling_granularity,
         axiswise_dim,
     )
-    return amax_to_scale(amax, float8_dtype)
+    return amax_to_scale(
+        amax, float8_dtype, round_scales_to_power_of_2=round_scales_to_power_of_2
+    )
 
 
 def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
@@ -266,3 +283,8 @@ def config_has_stateful_scaling(config: Float8LinearConfig) -> bool:
         or config.cast_config_weight.scaling_type != ScalingType.DYNAMIC
         or config.cast_config_grad_output.scaling_type != ScalingType.DYNAMIC
     )
+
+
+def _round_scale_down_to_power_of_2(scale: torch.Tensor):
+    assert scale.dtype == torch.float32, "scale must be float32 tensor"
+    return torch.exp2(torch.floor(torch.log2(scale)))
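
For reference, a minimal sketch of the helper's behavior on a few notable float32 values (mirroring the new test file above); it runs on CPU and only assumes the helper is importable as shown in the diff. Multiplying or dividing by an exact power of two changes only the exponent of a binary float, which is why rounding the scale this way avoids mantissa rounding error and keeps forward and backward quantization of large values consistent, as the config comment in this commit describes.

import torch

from torchao.float8.float8_utils import _round_scale_down_to_power_of_2

examples = {
    1.0: 1.0,            # already a power of two: unchanged
    1.0 - 2**-24: 0.5,   # largest float32 below one: floor(log2(x)) = -1, so 2**-1
    3.0: 2.0,            # always rounds down, never up
    2**-126: 2**-126,    # smallest positive normal number: unchanged
}
for value, expected in examples.items():
    scale = torch.tensor(value, dtype=torch.float32)
    rounded = _round_scale_down_to_power_of_2(scale)
    assert rounded.item() == expected, (value, rounded.item(), expected)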
