This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit ce48fbc

[wip] add axiswise granularity to Float8Tensor
Summary:

This PR adds the axiswise scaling granularity to `Float8Tensor` and ensures that basic ops like transpose and `torch._scaled_mm` work as expected. A future PR will add integration with `Float8Linear`.

Test Plan: TODO

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 163452e
Pull Request resolved: #352
1 parent 3265474 commit ce48fbc


6 files changed: +235 −21 lines changed


float8_experimental/config.py

+12 −0

@@ -21,6 +21,18 @@ def short_str(self):
         return "dyn"
 
 
+class ScalingGranularity(enum.Enum):
+    """
+    Defines the granularity of scaling strategies for casting to float8
+    """
+
+    # A single scaling factor for the entire tensor
+    TENSORWISE = "tensorwise"
+    # Scaling factors computed along one axis of the tensor, reducing it to
+    # size 1.
+    AXISWISE = "axiswise"
+
+
 @dataclass(frozen=True)
 class CastConfig:
     """

float8_experimental/float8_ops.py

+72 −3

@@ -19,6 +19,15 @@
 FLOAT8_OPS_TABLE: Dict[Any, Any] = {}
 
 
+def _assert_tensorwise_scale(aten_op, scale):
+    assert (
+        # TODO(future PR): figure out why tensorwise scaling can have
+        # both rank 0 and rank 1
+        len(scale.shape)
+        in (0, 1)
+    ), f"{aten_op} with axiswise scaling is not supported yet"
+
+
 def implements(aten_ops):
     """Register aten ops to the float8 op table"""
 
@@ -32,18 +41,16 @@ def decorator(func):
 
 @implements(
     [
-        aten.view.default,
         aten._unsafe_view.default,
-        aten.t.default,
         aten.as_strided.default,
         aten.clone.default,
         aten.detach.default,
         aten.slice.Tensor,
-        aten.transpose.int,
         aten.fill_.Scalar,
     ]
 )
 def float8_desugar_op(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     new_data = aten_op(args[0]._data, *args[1:], **kwargs)
     return Float8Tensor(
         new_data,
@@ -54,8 +61,61 @@ def float8_desugar_op(aten_op, args, kwargs=None):
     )
 
 
+@implements(
+    [
+        aten.t.default,
+        aten.transpose.int,
+    ]
+)
+def float8_desugar_data_and_scale(aten_op, args, kwargs=None):
+    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
+    new_scale = aten_op(args[0]._scale, *args[1:], **kwargs)
+    return Float8Tensor(
+        new_data,
+        new_scale,
+        args[0]._orig_dtype,
+        args[0]._linear_mm_config,
+        args[0]._gemm_input_role,
+    )
+
+
+@implements([aten.view.default])
+def float8_view(aten_op, args, kwargs=None):
+    if len(args[0]._scale.shape) < 2:
+        # tensorwise scaling
+        return float8_desugar_op(aten_op, args, kwargs)
+
+    t, new_shape = args[0], args[1]
+    # for now, only support reshaping to [-1, dim] or [dim, -1]
+    if len(new_shape) == 2:
+        if new_shape == [t.shape[0], -1] and t._scale.shape[0] == 1:
+            new_data = aten_op(t._data, new_shape, **kwargs)
+            new_scale = aten_op(t._scale, [1, -1], **kwargs)
+            return Float8Tensor(
+                new_data,
+                new_scale,
+                t._orig_dtype,
+                t._linear_mm_config,
+                t._gemm_input_role,
+            )
+        elif new_shape == [-1, t.shape[-1]] and t._scale.shape[-1] == 1:
+            new_data = aten_op(t._data, new_shape, **kwargs)
+            new_scale = aten_op(t._scale, [-1, 1], **kwargs)
+            return Float8Tensor(
+                new_data,
+                new_scale,
+                t._orig_dtype,
+                t._linear_mm_config,
+                t._gemm_input_role,
+            )
+    raise AssertionError(
+        f"{aten_op} with axiswise scaling and t.shape {t.shape} t._scale.shape {t._scale.shape} new_shape {new_shape} is not supported yet."
+    )
+
+
 @implements([aten.split.Tensor])
 def float8_split(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     new_data_tensors = aten_op(args[0]._data, *args[1:], **kwargs)
 
     def make_float8(data):
@@ -101,6 +161,7 @@ def float8_cat(aten_op, args, kwargs=None):
         assert (
             chunk._gemm_input_role is gemm_input_role
         ), "Expecting all chunks to have the same gemm_input_role as a result of a split"
+        _assert_tensorwise_scale(aten_op, chunk._scale)
         chunk_data.append(chunk._data.view(torch.uint8))
 
     new_data = aten_op(chunk_data, *args[1:], **kwargs)
@@ -117,6 +178,7 @@ def float8_cast_up_op(aten_op, args, kwargs=None):
     "addmm" -> out
     "hp_gradBias" <-"sum" <- "identity" <- gradOut <- "hp_gradOut"
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
 
     def unwrap(x):
         if isinstance(x, Float8Tensor):
@@ -229,6 +291,7 @@ def float8_addmm(aten_op, args, kwargs=None):
 
 @implements([aten.is_same_size.default])
 def float8_is_same_size(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     return args[0].shape == args[1].shape
 
 
@@ -238,6 +301,7 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     when the input is a Float8Tensor, presenting as a fp32
     tensor.
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     assert isinstance(args[0], Float8Tensor)
     assert (
         len(kwargs) == 1 and "dtype" in kwargs
@@ -265,6 +329,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
     """
     override funcol with FP8 handling
     """
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
    fp8_input = args[0]
     assert isinstance(
         fp8_input, Float8Tensor
@@ -284,6 +349,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
 
 @implements([c10d_functional.wait_tensor.default, _c10d_functional.wait_tensor.default])
 def wait_tensor_fp8(aten_op, args, kwargs=None):
+    _assert_tensorwise_scale(aten_op, args[0]._scale)
     fp8_input = args[0]
     assert isinstance(fp8_input, Float8Tensor)
 
@@ -304,6 +370,7 @@ def index_put_fp8(aten_op, args, kwargs=None):
     fp8_values = args[2]
     assert isinstance(fp8_self, Float8Tensor)
     assert isinstance(fp8_values, Float8Tensor)
+    _assert_tensorwise_scale(fp8_self, args[0]._scale)
     assert fp8_self._scale == fp8_values._scale
     assert fp8_self.dtype == fp8_values.dtype
     assert fp8_self._orig_dtype == fp8_values._orig_dtype
@@ -334,8 +401,10 @@ def copy_fp8(aten_op, args, kwargs=None):
 
     if not isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
         src_hp = src.to_original_precision()
+        _assert_tensorwise_scale(aten_op, src._scale)
        return aten_op(self, src_hp, *args[2:], **kwargs)
     elif isinstance(self, Float8Tensor) and isinstance(src, Float8Tensor):
+        _assert_tensorwise_scale(aten_op, src._scale)
         assert (
             self._orig_dtype == src._orig_dtype
         ), "Expecting both Float8Tensors to be of the same dtype"

float8_experimental/float8_scaling_utils.py

+13 −1

@@ -12,6 +12,8 @@
 
 import torch
 
+from float8_experimental.config import ScalingGranularity
+
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -36,6 +38,8 @@ def hp_tensor_to_float8_dynamic(
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+    axiswise_dim: Optional[int] = None,
 ) -> Float8Tensor:
     """
     Given a high precision tensor `hp_tensor`,
@@ -49,10 +53,18 @@ def hp_tensor_to_float8_dynamic(
         reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
         gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
           the 3 fwd/bwd gemms of linear
+        scaling_granularity: Defines the scaling granularity
+        axiswise_dim: if axiswise granularity is used, defines the dim to scale across
     """
     if tensor_already_casted_to_fp8(hp_tensor):
         return hp_tensor
-    scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
+    scale = tensor_to_scale(
+        hp_tensor,
+        float8_dtype,
+        reduce_amax,
+        scaling_granularity,
+        axiswise_dim,
+    )
     return hp_tensor_and_scale_to_float8(
         hp_tensor,
         scale,
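
A hypothetical call against the new keyword arguments, as a sketch rather than a test from this commit. It assumes LinearMMConfig is importable from float8_experimental.float8_tensor and constructible with defaults, and that the PyTorch build provides torch.float8_e4m3fn:

import torch

from float8_experimental.config import ScalingGranularity
from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_dynamic
from float8_experimental.float8_tensor import GemmInputRole, LinearMMConfig

x = torch.randn(16, 32)

# one scale per row: reduce over the last dim
x_fp8 = hp_tensor_to_float8_dynamic(
    x,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    gemm_input_role=GemmInputRole.INPUT,
    scaling_granularity=ScalingGranularity.AXISWISE,
    axiswise_dim=-1,
)
print(x_fp8._scale.shape)  # expected: torch.Size([16, 1])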

float8_experimental/float8_tensor.py

+6 −7

@@ -250,7 +250,12 @@ class Float8Tensor(torch.Tensor):
     * `_data`: the underlying e4m3 or e5m2 data
     * `_scale`: the scale used to scale the original fp32 tensor. We multiply
       by scale to go from fp32 range to fp8 range, and divide by scale to go
-      from fp8 range to fp32 range.
+      from fp8 range to fp32 range. Scale is guaranteed to have a shape compatible
+      with `_data`. For example:
+      - if scaling is tensorwise, `_scale` is a scalar tensor
+      - if scaling is axiswise and `_data.shape` is [3, 5], `_scale` could have
+        shape [1, 5] or [3, 1]. The dim of the size-one entry defines the scaling
+        axis.
     * `_orig_dtype`: the original dtype of the tensor used to create this
       tensor.
     * `_emulate`: if true using fp32 emulation for the matmuls, helpful
@@ -279,12 +284,6 @@ def __new__(
         linear_mm_config: Optional[LinearMMConfig],
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
-        assert (
-            scale.numel() == 1
-        ), "Scale should contain a single value, but got: {} elements".format(
-            scale.numel()
-        )
-
         self = torch.Tensor._make_wrapper_subclass(
             cls,
             data.size(),
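
The multiply-to-fp8-range / divide-back convention from the docstring, illustrated with plain tensors and an axiswise (broadcastable) scale. This is an illustration, not Float8Tensor itself; 448.0 is used as the e4m3 max representable value and the fp8 dtype assumes a recent PyTorch build:

import torch

hp = torch.randn(3, 5)
scale = 448.0 / torch.amax(torch.abs(hp), dim=1, keepdim=True)  # shape [3, 1]

data_fp8 = (hp * scale).to(torch.float8_e4m3fn)    # multiply by scale -> fp8 range
hp_roundtrip = data_fp8.to(torch.float32) / scale  # divide by scale -> fp32 range

print(scale.shape)                       # torch.Size([3, 1])
print((hp - hp_roundtrip).abs().max())   # small quantization error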

float8_experimental/float8_utils.py

+20 −5

@@ -4,12 +4,13 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Iterable, Literal, Tuple, Union
+from typing import Iterable, Literal, Optional, Tuple, Union
 
 import float8_experimental.config as config
 
 import torch
 import torch.distributed as dist
+from float8_experimental.config import ScalingGranularity
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -100,8 +101,18 @@ def amax_history_to_scale_stack(
 
 
 @torch.no_grad()
-def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
-    amax = torch.max(torch.abs(x))
+def tensor_to_amax(
+    x: torch.Tensor,
+    reduce_amax: bool = False,
+    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+    axiswise_dim: Optional[int] = None,
+) -> torch.Tensor:
+    if scaling_granularity is ScalingGranularity.TENSORWISE:
+        amax = torch.max(torch.abs(x))
+    else:
+        assert scaling_granularity is ScalingGranularity.AXISWISE, "unsupported"
+        assert axiswise_dim is not None, "unsupported"
+        amax = torch.amax(torch.abs(x), dim=axiswise_dim, keepdim=True)
 
     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
@@ -114,9 +125,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
 
 @torch.no_grad()
 def tensor_to_scale(
-    x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    reduce_amax: bool = False,
+    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+    axiswise_dim: Optional[int] = None,
 ) -> torch.Tensor:
-    amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+    amax = tensor_to_amax(x, reduce_amax, scaling_granularity, axiswise_dim)
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
 