This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit e87f005

[wip] add axiswise granularity to Float8Tensor
Summary: This PR adds the axiswise scaling granularity to `Float8Tensor` and ensures that basic ops like transpose and `torch._scaled_mm` work as expected. A future PR will add integration with `Float8Linear`.
Test Plan: TODO
Reviewers:
Subscribers:
Tasks:
Tags:
[ghstack-poisoned]
1 parent 3b786be commit e87f005

File tree: 7 files changed (+151, -16 lines)

float8_experimental/config.py

Lines changed: 12 additions & 0 deletions
@@ -21,6 +21,18 @@ def short_str(self):
         return "dyn"
 
 
+class ScalingGranularity(enum.Enum):
+    """
+    Defines the granularity of scaling strategies for casting to float8
+    """
+
+    # A single scaling factor for the entire tensor
+    TENSORWISE = "tensorwise"
+    # Scaling factors computed along one axis of the tensor, reducing it to
+    # size 1.
+    AXISWISE = "axiswise"
+
+
 @dataclass(frozen=True)
 class CastConfig:
     """

float8_experimental/float8_ops.py

Lines changed: 38 additions & 2 deletions
@@ -19,6 +19,15 @@
 FLOAT8_OPS_TABLE: Dict[Any, Any] = {}
 
 
+def _assert_tensorwise_scale(aten_op, scale):
+    assert (
+        # TODO(future PR): figure out why tensorwise scaling can have
+        # both rank 0 and rank 1
+        len(scale.shape)
+        in (0, 1)
+    ), f"{aten_op} with axiswise scaling is not supported yet"
+
+
 def implements(aten_ops):
     """Register aten ops to the float8 op table"""
 
@@ -34,16 +43,15 @@ def decorator(func):
     [
         aten.view.default,
         aten._unsafe_view.default,
-        aten.t.default,
         aten.as_strided.default,
         aten.clone.default,
         aten.detach.default,
         aten.slice.Tensor,
-        aten.transpose.int,
         aten.fill_.Scalar,
     ]
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
     return Float8Tensor(
         new_data,
@@ -54,8 +62,27 @@ def float8_desugar_op(aten_op, args, kwargs=None):
     )
 
 
+@implements(
+    [
+        aten.t.default,
+        aten.transpose.int,
+    ]
+)
+def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
+    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
+    new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)
+    return Float8Tensor(
+        new_data,
+        new_scale,
+        args[0]._orig_dtype,
+        args[0]._linear_mm_config,
+        args[0]._gemm_input_role,
+    )
+
+
 @implements([aten.split.Tensor])
 def float8_split(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)
 
     def make_float8(data):
@@ -101,6 +128,7 @@ def float8_cat(aten_op, args, kwargs=None):
         assert (
             chunk._gemm_input_role is gemm_input_role
         ), "Expecting all chunks to have the same gemm_input_role as a result of a split"
+        _assert_tensorwise_scale(aten_op, chunk._scale)
         chunk_data.append(chunk._data.view(torch.uint8))
 
     new_data = aten_op(chunk_data, *args[1:], **kwargs)
@@ -117,6 +145,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None):
     "addmm" -> out
     "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
 
     def unwrap(x):
         if isinstance(x, Float8Tensor):
@@ -229,6 +258,7 @@ def float8_addmm(aten_op, args, kwargs=None):
 
 @implements([aten.is_same_size.default])
 def float8_is_same_size(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     return args[0].shape == args[1].shape
 
 
@@ -238,6 +268,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     when the input is a Float8Tensor, presenting as a fp32
     tensor.
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     assert isinstance(args[0], Float8Tensor)
     assert (
         len(kwargs) == 1 and "dtype" in kwargs
@@ -265,6 +296,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
     """
     override funcol with FP8 handling
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     fp8_input = args[0]
     assert isinstance(
         fp8_input, Float8Tensor
@@ -284,6 +316,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
 
 @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
 def wait_tensor_fp8(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     fp8_input = args[0]
     assert isinstance(fp8_input, Float8Tensor)
 
@@ -304,6 +337,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
     fp8_values = args[2]
     assert isinstance(fp8_self, Float8Tensor)
     assert isinstance(fp8_values, Float8Tensor)
+    _assert_tensorwise_scale(fp8_self, args[0]._scale)
     assert fp8_self._scale == fp8_values._scale
     assert fp8_self.dtype == fp8_values.dtype
     assert fp8_self._orig_dtype == fp8_values._orig_dtype
@@ -334,8 +368,10 @@ def copy_fp8(aten_op, args, kwargs=None):
 
     if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
         src_hp = src.to_original_precision()
+        _assert_tensorwise_scale(aten_op, src._scale)
        return aten_op(self, src_hp, *args[2:], **kwargs)
     elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        _assert_tensorwise_scale(aten_op, src._scale)
         assert (
             self._orig_dtype == src._orig_dtype
         ), "Expecting both Float8Tensors to be of the same dtype"

float8_experimental/float8_python_api.py

Lines changed: 8 additions & 0 deletions
@@ -38,6 +38,14 @@ def addmm_float8_unwrapped(
     """
     a_inverse_scale = a_scale.reciprocal()
     b_inverse_scale = b_scale.reciprocal()
+
+    # TODO: should we change torch._scaled_mm?
+    # torch._scaled_mm expects rowwise scaled scales to be of rank 1, not rank
+    # 2. Translate to this format.
+    # TODO: audit if we need to make this more generic for various shapes.
+    a_inverse_scale = a_inverse_scale.squeeze()
+    b_inverse_scale = b_inverse_scale.squeeze()
+
     if output_dtype == torch.float32 and bias is not None:
         # Bias is not supported by _scaled_mm when output is fp32
         output = torch._scaled_mm(
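
The new `squeeze()` calls only change the rank of the scales, not their contents. A shape-only sketch of the translation the TODO comment describes (illustrative, not the library's code):

import torch

a = torch.randn(16, 32)
a_scale = torch.amax(torch.abs(a), dim=1, keepdim=True)  # rank 2, shape [16, 1], broadcastable against `a`

a_inverse_scale = a_scale.reciprocal()
print(a_inverse_scale.shape)            # torch.Size([16, 1])
print(a_inverse_scale.squeeze().shape)  # torch.Size([16]), the rank-1 layout a rowwise-scaled matmul expects

# A tensorwise scale (rank 0, or a single-element rank-1 tensor) still has one
# element after squeeze(), so the same code path covers both granularities.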

float8_experimental/float8_scaling_utils.py

Lines changed: 13 additions & 1 deletion
@@ -12,6 +12,8 @@
 
 import torch
 
+from float8_experimental.config import ScalingGranularity
+
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -36,6 +38,8 @@ def hp_tensor_to_float8_dynamic(
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+    axiswise_dim: Optional[int] = None,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -49,10 +53,18 @@ def hp_tensor_to_float8_dynamic(
         reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
         gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
           the 3 fwd/bwd gemms of linear
+        scaling_granularity: Defines the scaling granularity
+        axiswise_dim: if axiswise granularity is used, defines the dim to scale across
     """
     if tensor_already_casted_to_fp8(hp_tensor):
         return hp_tensor
-    scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
+    scale = tensor_to_scale(
+        hp_tensor,
+        float8_dtype,
+        reduce_amax,
+        scaling_granularity,
+        axiswise_dim,
+    )
     return hp_tensor_and_scale_to_float8(
         hp_tensor,
         scale,
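
A hedged usage sketch of the extended entry point, mirroring the new tests further down; the `LinearMMConfig` import path and the use of `torch.float8_e4m3fn` in place of the test file's `e4m3_dtype` helper are assumptions, not taken from this commit:

import torch
from float8_experimental.config import ScalingGranularity
from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic
from float8_experimental.float8_tensor import LinearMMConfig  # import path assumed

w = torch.randn(64, 32, dtype=torch.bfloat16)
w_fp8 = hp_tensor_to_float8_dynamic(
    w,
    torch.float8_e4m3fn,  # assumed stand-in for the test file's e4m3_dtype helper
    LinearMMConfig(),
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=0,  # dim to keep: one scale per row of `w`
)
# per the Float8Tensor docstring change in this commit, w_fp8._scale should have shape [64, 1]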

float8_experimental/float8_tensor.py

Lines changed: 6 additions & 7 deletions
@@ -250,7 +250,12 @@ class Float8Tensor(torch.Tensor):
     * `_data`: the underlying e4m3 or e5m2 data
     * `_scale`: the scale used to scale the original fp32 tensor. We multiply
       by scale to go from fp32 range to fp8 range, and divide by scale to go
-      from fp8 range to fp32 range.
+      from fp8 range to fp32 range. Scale is guaranteed to have a shape compatible
+      with `_data`. For example:
+      - if scaling is tensorwise, `_scale` is a scalar tensor
+      - if scaling is axiswise and _data.shape is [3, 5], `_scale` could have
+        shape [1, 5] or [5, 1]. The dim of the non-one entry defines the scaling
+        axis.
     * `_orig_dtype`: the original dtype of the tensor used to create this
       tensor.
     * `_emulate`: if true using fp32 emulation for the matmuls, helpful
@@ -279,12 +284,6 @@ def __new__(
         linear_mm_config: Optional[LinearMMConfig],
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
-        assert (
-            scale.numel() == 1
-        ), "Scale should contain a single value, but got: {} elements".format(
-            scale.numel()
-        )
-
         self = torch.Tensor._make_wrapper_subclass(
             cls,
             data.size(),

float8_experimental/float8_utils.py

Lines changed: 25 additions & 5 deletions
@@ -4,12 +4,13 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Iterable, Literal, Tuple, Union
+from typing import Iterable, Literal, Optional, Tuple, Union
 
 import float8_experimental.config as config
 
 import torch
 import torch.distributed as dist
+from float8_experimental.config import ScalingGranularity
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -100,8 +101,23 @@ def amax_history_to_scale_stack(
 
 
 @torch.no_grad()
-def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
-    amax = torch.max(torch.abs(x))
+def tensor_to_amax(
+    x: torch.Tensor,
+    reduce_amax: bool = False,
+    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+    axiswise_dim: Optional[int] = None,
+) -> torch.Tensor:
+    if scaling_granularity is ScalingGranularity.TENSORWISE:
+        amax = torch.max(torch.abs(x))
+    else:
+        assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
+        assert axiswise_dim is not None, "unsupported"
+
+        # convert from axiswise_dim (dim to keep) to
+        # dim as the input to the `torch.amax` function (tuple of dims to reduce)
+        dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim)
+
+        amax = torch.amax(torch.abs(x), dim=dim_to_reduce, keepdim=True)
 
     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
@@ -114,9 +130,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
 
 @torch.no_grad()
 def tensor_to_scale(
-    x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    reduce_amax: bool = False,
+    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+    axiswise_dim: Optional[int] = None,
 ) -> torch.Tensor:
-    amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+    amax = tensor_to_amax(x, reduce_amax, scaling_granularity, axiswise_dim)
     return amax_to_scale(amax, float8_dtype, x.dtype)
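
`axiswise_dim` names the dim to keep, while `torch.amax` takes the dims to reduce; the conversion inside `tensor_to_amax` can be checked in isolation with plain PyTorch (illustrative shapes):

import torch

x = torch.randn(16, 32, 8)
axiswise_dim = 1  # dim to keep

# reduce over every dim except the kept one
dim_to_reduce = tuple(d for d in range(len(x.shape)) if d != axiswise_dim)  # (0, 2)
amax = torch.amax(torch.abs(x), dim=dim_to_reduce, keepdim=True)

print(amax.shape)  # torch.Size([1, 32, 1])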

test/test_base.py

Lines changed: 49 additions & 1 deletion
@@ -16,14 +16,20 @@
 import torch
 import torch.nn as nn
 
-from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
+from float8_experimental.config import (
+    CastConfig,
+    Float8LinearConfig,
+    ScalingGranularity,
+    ScalingType,
+)
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
+from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -143,6 +149,48 @@ def test_weights_only_load(self):
         buffer.seek(0)
         _ = torch.load(buffer, weights_only=True)
 
+    def test_axiswise_dynamic_cast(self):
+        a = torch.randn(16, 32, dtype=torch.bfloat16)
+        linear_mm_config = LinearMMConfig()
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=0,
+        )
+        # print(a_fp8)
+        # print(a_fp8.to_original_precision())
+        # print(a_fp8.t())
+        b = a_fp8.t()
+        # TODO check numerical accuracy
+
+    def test_axiswise_gemm(self):
+        a = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
+        b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
+
+        linear_mm_config = LinearMMConfig()
+
+        a_fp8 = hp_tensor_to_float8_dynamic(
+            a,
+            e4m3_dtype,
+            linear_mm_config,
+            gemm_input_role=GemmInputRole.INPUT,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=0,
+        )
+        b_fp8 = hp_tensor_to_float8_dynamic(
+            b,
+            e4m3_dtype,
+            linear_mm_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+            scaling_granularity=ScalingGranularity.AXISWISE,
+            axiswise_dim=0,
+        )
+        c = torch.mm(a_fp8, b_fp8.t())
+        print(c)
+        # TODO check numerical accuracy
+
 
 class TestFloat8Linear:
     def _test_linear_impl(
