
Commit 37c9f0a

[wip] add scaling granularity

Summary:
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: f21151b
Pull Request resolved: #338

1 parent: c520ddd

File tree: 5 files changed (+49, -12 lines)


float8_experimental/config.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -21,6 +21,18 @@ def short_str(self):
             return "dyn"
 
 
+class ScalingGranularity(enum.Enum):
+    """
+    Defines the granularity of scaling strategies for casting to float8
+    """
+
+    # A single scaling factor for the entire tensor
+    TENSORWISE = "tensorwise"
+    # Scaling factors computed along one axis of the tensor, reducing it to
+    # size 1.
+    AXISWISE = "axiswise"
+
+
 @dataclass(frozen=True)
 class CastConfig:
     """
```

float8_experimental/float8_dynamic_utils.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -4,8 +4,11 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional, Tuple, Union
+
 import torch
 
+from float8_experimental.config import ScalingGranularity
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
@@ -52,10 +55,12 @@ def cast_to_float8_e4m3_dynamic(
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
+    granularity: ScalingGranularity = ScalingGranularity.TENSORWISE,
+    dim: Optional[Union[int, Tuple[int]]] = None,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
+    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax, granularity, dim)
     return Float8Tensor.to_float8(
         inpt_tensor,
         scale,
```
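Note that `granularity` defaults to `ScalingGranularity.TENSORWISE` here, so existing callers of `cast_to_float8_e4m3_dynamic` keep the previous single-scale behavior; only callers that opt into `AXISWISE` (and pass a `dim`) get per-axis scales. The new test in `test/test_base.py` below exercises the axiswise path.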

float8_experimental/float8_tensor.py

Lines changed: 0 additions & 6 deletions

```diff
@@ -290,12 +290,6 @@ def __new__(
         linear_mm_config: Optional[LinearMMConfig],
         gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
     ):
-        assert (
-            scale.numel() == 1
-        ), "Scale should contain a single value, but got: {} elements".format(
-            scale.numel()
-        )
-
         self = torch.Tensor._make_wrapper_subclass(
             cls,
             data.size(),
```
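The single-element assert is removed because axiswise scaling produces a `scale` tensor with one element per slice (for example, shape `(1, 32)` for a `(16, 32)` input reduced over `dim=0`), so `Float8Tensor` can no longer require a scalar scale.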

float8_experimental/float8_utils.py

Lines changed: 20 additions & 5 deletions

```diff
@@ -4,9 +4,10 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Iterable, Literal, Tuple, Union
+from typing import Iterable, Literal, Optional, Tuple, Union
 
 import float8_experimental.config as config
+from float8_experimental.config import ScalingGranularity
 
 import torch
 import torch.distributed as dist
@@ -100,8 +101,18 @@ def amax_history_to_scale_stack(
 
 
 @torch.no_grad()
-def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
-    amax = torch.max(torch.abs(x))
+def tensor_to_amax(
+    x: torch.Tensor,
+    reduce_amax: bool = False,
+    granularity: ScalingGranularity = ScalingGranularity.AXISWISE,
+    dim: Optional[Union[int, Tuple[int]]] = None,
+) -> torch.Tensor:
+    if granularity is ScalingGranularity.TENSORWISE:
+        amax = torch.max(torch.abs(x))
+    else:
+        assert granularity is ScalingGranularity.AXISWISE, "unsupported"
+        assert dim is not None, "unsupported"
+        amax = torch.amax(torch.abs(x), dim=dim, keepdim=True)
 
     # If the user asked for distributed reduction, do it.
     # If the user did not ask for it, assume that it will
@@ -114,9 +125,13 @@ def tensor_to_amax(x: torch.Tensor, reduce_amax: bool = False) -> torch.Tensor:
 
 @torch.no_grad()
 def tensor_to_scale(
-    x: torch.Tensor, float8_dtype: torch.dtype, reduce_amax: bool = False
+    x: torch.Tensor,
+    float8_dtype: torch.dtype,
+    reduce_amax: bool = False,
+    granularity: ScalingGranularity = ScalingGranularity.AXISWISE,
+    dim: Optional[Union[int, Tuple[int]]] = None,
 ) -> torch.Tensor:
-    amax = tensor_to_amax(x, reduce_amax=reduce_amax)
+    amax = tensor_to_amax(x, reduce_amax=reduce_amax, granularity=granularity, dim=dim)
     return amax_to_scale(amax, float8_dtype, x.dtype)
 
 
```
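As a rough self-contained illustration of the axiswise path end to end (a sketch only: the real scale conversion lives in the library's `amax_to_scale`, and the clamp epsilon below is an assumed placeholder):

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def axiswise_scale_sketch(x: torch.Tensor, dim: int) -> torch.Tensor:
    # amax reduced along `dim`, kept as size 1 so it broadcasts against x
    amax = torch.amax(torch.abs(x), dim=dim, keepdim=True).float()
    # map each amax to a scale that brings that slice into e4m3 range;
    # the clamp guards against division by zero
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

x = torch.randn(16, 32, dtype=torch.bfloat16)
scale = axiswise_scale_sketch(x, dim=0)  # shape (1, 32): one scale per column
x_scaled = x.float() * scale             # broadcasts across the 16 rows
```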

test/test_base.py

Lines changed: 11 additions & 0 deletions

```diff
@@ -143,6 +143,17 @@ def test_weights_only_load(self):
         buffer.seek(0)
         _ = torch.load(buffer, weights_only=True)
 
+    def test_axiswise_dynamic_cast(self):
+        a = torch.randn(16, 32, dtype=torch.bfloat16)
+        linear_mm_config = LinearMMConfig()
+        a_fp8 = cast_to_float8_e4m3_dynamic(
+            a,
+            linear_mm_config,
+            granularity=ScalingGranularity.AXISWISE,
+            dim=0,
+        )
+        print(a_fp8)
+
 
 class TestFloat8Linear:
     def _test_linear_impl(
```
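With `dim=0`, the 16 rows are reduced away, so the resulting Float8Tensor carries a scale of shape `(1, 32)`, one per column, which is exactly the case the removed `scale.numel() == 1` assert in `float8_tensor.py` would have rejected.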
