
Commit 52d27a1

add axiswise granularity to Float8Tensor (#919)
Summary: This is a copy-paste of pytorch-labs/float8_experimental#352 which never landed.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent e7331ab commit 52d27a1
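
For context, axiswise granularity means computing one scale per slice along a chosen dimension of the tensor, instead of a single scale for the whole tensor. The standalone Python sketch below only illustrates that idea; it is not the torchao implementation, and the axiswise_scale helper name and the clamp epsilon are made up for the example.

# Illustrative sketch only (not torchao code): axiswise float8 scaling.
import torch

def axiswise_scale(x, axiswise_dim, float8_dtype=torch.float8_e4m3fn):
    # one scale per slice along axiswise_dim; keepdim=True so the scale
    # broadcasts back against the original tensor
    amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True).to(torch.float64)
    return torch.finfo(float8_dtype).max / amax.clamp(min=1e-12)

x = torch.randn(8, 16, dtype=torch.bfloat16)
s = axiswise_scale(x, axiswise_dim=-1)                      # scale shape: [8, 1]
x_fp8 = (x.to(torch.float64) * s).to(torch.float8_e4m3fn)   # quantize
x_dq = (x_fp8.to(torch.float64) / s).to(torch.bfloat16)     # approximate round trip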

7 files changed: +278 additions, -34 deletions

test/float8/test_base.py

Lines changed: 111 additions & 5 deletions
@@ -22,14 +22,20 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
-from torchao.float8.config import CastConfig, Float8LinearConfig, ScalingType
+from torchao.float8.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingGranularity,
+    ScalingType,
+)
 from torchao.float8.float8_linear import Float8Linear
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
 )
 from torchao.float8.float8_python_api import addmm_float8_unwrapped
+from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -51,14 +57,15 @@
 
 
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
     return True
 
 
-class TestFloat8Tensor(unittest.TestCase):
+class TestFloat8Tensor:
     def test_preserves_dtype(self) -> None:
         # hp means high precision, lp means low precision
         hp_dtypes = (torch.float32, torch.float16, torch.bfloat16)
@@ -68,7 +75,7 @@ def test_preserves_dtype(self) -> None:
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
             x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
-            self.assertTrue(x3_hp.dtype == hp_dtype)
+            assert x3_hp.dtype == hp_dtype
 
     def test_differentiable_casts(self) -> None:
         lp_dtypes = (e4m3_dtype, e5m2_dtype)
@@ -103,7 +110,7 @@ def test_index_put(self):
         fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
         fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)
 
-        with self.assertRaises(AssertionError):
+        with pytest.raises(AssertionError):
             b[index] = fp8_a
             fp8_b[index] = a
             fp8_b_bad[index] = fp8_a
@@ -117,7 +124,7 @@ def test_copy_(self):
         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
         torch.testing.assert_close(b, fp8_a.to_original_precision())
-        with self.assertRaises(RuntimeError):
+        with pytest.raises(RuntimeError):
             fp8_a.copy_(b)  # Should fail
 
         fp8_b = Float8Tensor(
@@ -129,6 +136,105 @@ def test_copy_(self):
         fp8_b.copy_(fp8_a)
         torch.testing.assert_close(fp8_a._data, fp8_b._data)
 
+    @pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
+    @pytest.mark.parametrize("axiswise_dim", [0, -1])
+    def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
+        a = torch.randn(*shape, dtype=torch.bfloat16)
+        linear_mm_config = LinearMMConfig()
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=axiswise_dim,
+        )
+        a_dq = a_fp8.to_original_precision()
+        sqnr = compute_error(a, a_dq)
+        assert sqnr >= 25.0
+
+    def test_axiswise_reshape(self):
+        a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
+        linear_mm_config = LinearMMConfig()
+
+        # if we scale across dim0, we can only reshape to [3, -1]
+        a_fp8_d0 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=0,
+        )
+        assert list(a_fp8_d0._data.shape) == [3, 5, 7]
+        assert list(a_fp8_d0._scale.shape) == [1, 5, 7]
+
+        a_fp8_d0_r = a_fp8_d0.reshape(3, -1)
+        assert list(a_fp8_d0_r.shape) == [3, 5 * 7]
+        assert list(a_fp8_d0_r._scale.shape) == [1, 5 * 7]
+        # verify numerics did not change
+        assert torch.allclose(
+            a_fp8_d0.to_original_precision(),
+            a_fp8_d0_r.to_original_precision().reshape(3, 5, 7),
+            atol=0,
+            rtol=0,
+        )
+        with pytest.raises(RuntimeError):
+            a_fp8_d0_r2 = a_fp8_d0.reshape(-1, 7)
+
+        # if we scale across dim2, we can only reshape to [-1, 7]
+        a_fp8_d2 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,
+        )
+        assert list(a_fp8_d2._data.shape) == [3, 5, 7]
+        assert list(a_fp8_d2._scale.shape) == [3, 5, 1]
+
+        a_fp8_d2_r = a_fp8_d2.reshape(-1, 7)
+        assert list(a_fp8_d2_r.shape) == [3 * 5, 7]
+        assert list(a_fp8_d2_r._scale.shape) == [3 * 5, 1]
+        # verify numerics did not change
+        assert torch.allclose(
+            a_fp8_d2.to_original_precision(),
+            a_fp8_d2_r.to_original_precision().reshape(3, 5, 7),
+            atol=0,
+            rtol=0,
+        )
+        with pytest.raises(RuntimeError):
+            a_fp8_d2_r2 = a_fp8_d2.reshape(3, -1)
+
+    @pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
+    def test_axiswise_gemm(self, a_shape):
+        a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
+        b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
+
+        linear_mm_config = LinearMMConfig()
+
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            gemm_input_role=GemmInputRole.INPUT,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,
+        )
+        a_fp8 = a_fp8.reshape(-1, a_shape[-1])
+        b_fp8 = hp_tensor_to_float8_dynamic(
+            b,
+            e4m3_dtype,
+            linear_mm_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=-1,  # will be transposed
+        )
+        c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
+        a = a.reshape(-1, a_shape[-1])
+        c_ref = torch.mm(a, b.t())
+        sqnr = compute_error(c_ref, c_fp8_compute)
+        assert sqnr >= 25.0
 
 
 class TestFloat8Linear:
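
The reshape constraints checked in test_axiswise_reshape above follow from scale broadcasting: once one dimension has been reduced to size 1 in the scale, only reshapes that keep the scaled dimension intact still admit a scale of matching shape. A minimal sketch of that shape reasoning with plain tensors (the variable names are illustrative, and this is not the Float8Tensor reshape implementation):

import torch

data = torch.randn(3, 5, 7)
# scaling across dim 0 reduces that dim to size 1 in the scale
scale = torch.amax(data.abs(), dim=0, keepdim=True)   # shape [1, 5, 7]

# merging only the non-scaled dims keeps data and scale consistent
data_r = data.reshape(3, -1)     # [3, 35]
scale_r = scale.reshape(1, -1)   # [1, 35], still broadcasts against data_r

# merging the scaled dim with another dim would mix elements quantized with
# different scales, so no reshaped scale exists and the op has to raise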

torchao/float8/config.py

Lines changed: 14 additions & 0 deletions
@@ -26,6 +26,18 @@ def short_str(self):
         return "sta"
 
 
+class ScalingGranularity(enum.Enum):
+    """
+    Defines the granularity of scaling strategies for casting to float8
+    """
+
+    # A single scaling factor for the entire tensor
+    TENSORWISE = "tensorwise"
+    # Scaling factors computed along one axis of the tensor, reducing it to
+    # size 1.
+    AXISWISE = "axiswise"
+
+
 @dataclass(frozen=True)
 class CastConfig:
     """
@@ -146,6 +158,8 @@ class Float8LinearConfig:
     # save the fp8_weight_transpose for backward, which is an un-sharded weight and costs a high memory utilization.
     # The longer-term solution is to let compile decide how to partition the graph with optimal computation and memory savings.
     # For now, we use the checkpointing api to force the recomputation of fp8 weight in backward.
+    # TODO(future PR): either enable by default or have a warning and set up the
+    # tests so that the warning does not spam the CI stdout.
 
     force_recompute_fp8_weight_in_bwd: bool = False
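
As an aside on how such an enum is typically consumed, a hypothetical caller could branch on the granularity roughly as follows; the amax_for_granularity helper and its logic are illustrative only and are not part of this commit:

import torch
from torchao.float8.config import ScalingGranularity

def amax_for_granularity(x, granularity, axiswise_dim=-1):
    # illustrative only: show what each granularity reduces over
    if granularity is ScalingGranularity.TENSORWISE:
        return torch.amax(torch.abs(x))  # scalar, one scale for the whole tensor
    elif granularity is ScalingGranularity.AXISWISE:
        # reduce one axis to size 1, one scale per remaining slice
        return torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True)
    raise ValueError(f"unsupported granularity {granularity}")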

torchao/float8/float8_linear.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import dataclasses
1111
import enum
12-
import logging
1312

1413
from typing import Optional
1514

@@ -50,8 +49,6 @@
5049
WeightWithStaticFloat8CastTensor,
5150
)
5251

53-
logger = logging.getLogger(__name__)
54-
5552

5653
# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files
5754
@torch._dynamo.allow_in_graph
@@ -191,15 +188,6 @@ def __init__(self, *args, **kwargs):
191188
# would be initialized in every iteration.
192189
self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward
193190

194-
# See the comments in config.py for more details of this option.
195-
if (
196-
self.config.enable_pre_and_post_forward
197-
and not self.config.force_recompute_fp8_weight_in_bwd
198-
):
199-
logger.warning(
200-
"When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
201-
)
202-
203191
def create_buffers(self):
204192
# Default values for history buffers, see above TODO
205193
history_len = self.config.delayed_scaling_config.history_len
