This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 3265474

vkuzo authored and facebook-github-bot committed
clean up casting: rename delayed and dynamic casting functions (#350)
Summary:
Pull Request resolved: #350

Renames the delayed and dynamic casting functions to `hp_tensor_to_float8_delayed` and `hp_tensor_to_float8_dynamic` to clarify what they are doing and how they are different from `hp_tensor_and_scale_to_float8`.

Reviewed By: drisspg

Differential Revision: D60310241

fbshipit-source-id: 1f67f9f4a59e6ed153411834a1fe58ef734c0f02
1 parent 4fb2877 commit 3265474
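To make the distinction in the new names concrete, the sketch below contrasts the two renamed helpers, based only on the signatures visible in the float8_scaling_utils.py diff in this commit; constructing `LinearMMConfig()` with default arguments is an assumption, not something this commit shows.

import torch

from float8_experimental.float8_scaling_utils import (
    hp_tensor_to_float8_delayed,
    hp_tensor_to_float8_dynamic,
)
from float8_experimental.float8_tensor import LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype

x = torch.randn(16, 32)
cfg = LinearMMConfig()  # assumption: default construction is acceptable here

# dynamic: the scale is computed from `x` itself, inside the call
x_fp8_dyn = hp_tensor_to_float8_dynamic(x, e4m3_dtype, cfg)

# delayed: the caller supplies a precomputed scale and an amax buffer,
# which the call fills inplace with max(abs(x))
scale = torch.ones(())
amax_buffer = torch.zeros(())
x_fp8_del = hp_tensor_to_float8_delayed(x, scale, e4m3_dtype, amax_buffer, cfg)

# both are thin wrappers over hp_tensor_and_scale_to_float8, which takes a
# tensor and an already-known scale and performs the actual conversion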

File tree

5 files changed: +60 −29 lines

float8_experimental/float8_linear.py (+8 −6)

@@ -18,8 +18,8 @@
 
 from float8_experimental.float8_scaling_utils import (
     _maybe_initialize_amaxes_scales_for_float8_cast,
-    cast_to_float8_delayed,
-    cast_to_float8_dynamic,
+    hp_tensor_to_float8_delayed,
+    hp_tensor_to_float8_dynamic,
     NoopFwToFloat8E5M2BwDelayed,
     NoopFwToFloat8E5M2BwDynamic,
 )
@@ -260,7 +260,7 @@ def cast_input_to_float8(
                 is_amax_initialized,
                 reduce_amax=True,
             )
-            input_fp8 = cast_to_float8_delayed(
+            input_fp8 = hp_tensor_to_float8_delayed(
                 input,
                 self.fp8_scale_input,
                 e4m3_dtype,
@@ -270,7 +270,9 @@ def cast_input_to_float8(
             )
         else:
             assert self.scaling_type_input is ScalingType.DYNAMIC
-            input_fp8 = cast_to_float8_dynamic(input, e4m3_dtype, self.linear_mm_config)
+            input_fp8 = hp_tensor_to_float8_dynamic(
+                input, e4m3_dtype, self.linear_mm_config
+            )
         return input_fp8
 
     def cast_weight_to_float8(
@@ -292,7 +294,7 @@ def cast_weight_to_float8(
                 reduce_amax=False,
             )
 
-            weight_fp8 = cast_to_float8_delayed(
+            weight_fp8 = hp_tensor_to_float8_delayed(
                 weight,
                 self.fp8_scale_weight,
                 e4m3_dtype,
@@ -305,7 +307,7 @@ def cast_weight_to_float8(
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             weight_fp8 = self.weight
         else:
-            weight_fp8 = cast_to_float8_dynamic(
+            weight_fp8 = hp_tensor_to_float8_dynamic(
                 self.weight,
                 e4m3_dtype,
                 self.linear_mm_config,
float8_experimental/float8_scaling_utils.py (+42 −13)

@@ -30,37 +30,66 @@
 )
 
 
-def cast_to_float8_dynamic(
-    inpt_tensor: torch.Tensor,
+def hp_tensor_to_float8_dynamic(
+    hp_tensor: torch.Tensor,
     float8_dtype: torch.dtype,
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
 ) -> Float8Tensor:
-    if tensor_already_casted_to_fp8(inpt_tensor):
-        return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, float8_dtype, reduce_amax)
+    """
+    Given a high precision tensor `hp_tensor`,
+    scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result.
+
+    Args:
+        hp_tensor: the tensor to convert
+        float8_dtype: the float8 dtype to use
+        linear_mm_config: Defines the configuration for the scaled_mm for
+          the 3 fwd/bwd gemms of linear
+        reduce_amax: whether to reduce the max(abs(hp_tensor)) value across distributed ranks
+        gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+          the 3 fwd/bwd gemms of linear
+    """
+    if tensor_already_casted_to_fp8(hp_tensor):
+        return hp_tensor
+    scale = tensor_to_scale(hp_tensor, float8_dtype, reduce_amax)
     return hp_tensor_and_scale_to_float8(
-        inpt_tensor,
+        hp_tensor,
         scale,
         float8_dtype,
         linear_mm_config,
         gemm_input_role,
     )
 
 
-def cast_to_float8_delayed(
-    tensor: torch.Tensor,
-    scale: torch.Tensor,
+def hp_tensor_to_float8_delayed(
+    hp_tensor: torch.Tensor,
+    s: torch.Tensor,
     float8_dtype: torch.dtype,
     amax_buffer: torch.Tensor,
     linear_mm_config: Optional[LinearMMConfig] = None,
     gemm_input_role: Optional[GemmInputRole] = GemmInputRole.INPUT,
-):
-    amax_buffer.fill_(tensor_to_amax(tensor))
+) -> Float8Tensor:
+    """
+    Given a high precision tensor `hp_tensor` and relevant metadata, scales it using
+    delayed scaling and returns a `Float8Tensor` of the result. Specifically:
+    1. calculates max(abs(hp_tensor)) and stores the result in `amax_buffer`, inplace
+    2. scales `hp_tensor` by `s` and returns the result wrapped in Float8Tensor
+
+    Args:
+        hp_tensor: the tensor to convert
+        s: the scale to use to convert the tensor
+        float8_dtype: the float8 dtype to use
+        amax_buffer: the buffer to modify inplace with max(abs(hp_tensor))
+        linear_mm_config: Defines the configuration for the scaled_mm for
+          the 3 fwd/bwd gemms of linear
+        gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
+          the 3 fwd/bwd gemms of linear
+    """
+    amax_buffer.fill_(tensor_to_amax(hp_tensor))
     return hp_tensor_and_scale_to_float8(
-        tensor,
-        scale,
+        hp_tensor,
+        s,
         float8_dtype,
         linear_mm_config,
         gemm_input_role,
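The new docstrings promise two concrete behaviors: the delayed variant fills `amax_buffer` inplace, and the dynamic variant returns an already-cast tensor unchanged. Below is a minimal sketch that exercises both, under the same import and `LinearMMConfig()` assumptions as the sketch near the top of this commit.

import torch

from float8_experimental.float8_scaling_utils import (
    hp_tensor_to_float8_delayed,
    hp_tensor_to_float8_dynamic,
)
from float8_experimental.float8_tensor import LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype

x = torch.randn(8, 8)
cfg = LinearMMConfig()  # assumption, as above

# step 1 of the delayed docstring: amax_buffer is written inplace
amax_buffer = torch.zeros(())
y = hp_tensor_to_float8_delayed(x, torch.ones(()), e4m3_dtype, amax_buffer, cfg)
assert torch.allclose(amax_buffer, x.abs().max())

# the dynamic variant short-circuits on tensors already cast to float8
# (the tensor_already_casted_to_fp8 early return above)
y2 = hp_tensor_to_float8_dynamic(y, e4m3_dtype, cfg)
assert y2 is y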

float8_experimental/float8_tensor_parallel.py (+4 −4)

@@ -2,7 +2,7 @@
 import torch.nn as nn
 from float8_experimental.config import ScalingType
 from float8_experimental.float8_scaling_utils import (
-    cast_to_float8_dynamic,
+    hp_tensor_to_float8_dynamic,
     NoopFwToFloat8E5M2BwDynamic,
 )
 from float8_experimental.float8_tensor import GemmInputRole
@@ -46,7 +46,7 @@ def _prepare_input_fn(
             input_tensor, device_mesh, input_layouts, run_check=False
         )
 
-        input_tensor = cast_to_float8_dynamic(
+        input_tensor = hp_tensor_to_float8_dynamic(
             input_tensor,
             e4m3_dtype,
             mod.linear_mm_config,
@@ -100,7 +100,7 @@ def _prepare_input_fn(
             input_tensor, device_mesh, input_layouts, run_check=False
         )
 
-        input_tensor = cast_to_float8_dynamic(
+        input_tensor = hp_tensor_to_float8_dynamic(
             input_tensor,
             e4m3_dtype,
             mod.linear_mm_config,
@@ -199,7 +199,7 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
                 input, mesh, (input_layout,), run_check=False
             )
 
-            dt_inp = cast_to_float8_dynamic(
+            dt_inp = hp_tensor_to_float8_dynamic(
                 dt_inp,
                 e4m3_dtype,
                 self.linear_mm_config,

float8_experimental/fsdp_utils.py (+4 −4)

@@ -11,8 +11,8 @@
 import torch.nn as nn
 import torch.utils._pytree as pytree
 from float8_experimental.float8_scaling_utils import (
-    cast_to_float8_delayed,
-    cast_to_float8_dynamic,
+    hp_tensor_to_float8_delayed,
+    hp_tensor_to_float8_dynamic,
 )
 
 from float8_experimental.float8_tensor import (
@@ -175,7 +175,7 @@ def fsdp_pre_all_gather(self, mesh):
                 GemmInputRole.WEIGHT,
             )
         else:
-            float8_tensor = cast_to_float8_dynamic(
+            float8_tensor = hp_tensor_to_float8_dynamic(
                 self._tensor,
                 e4m3_dtype,
                 self._linear_mm_config,
@@ -355,7 +355,7 @@ def fsdp_pre_all_gather(self, mesh):
             )
             self.is_amax_initialized = True
 
-        float8_tensor = cast_to_float8_delayed(
+        float8_tensor = hp_tensor_to_float8_delayed(
             self._tensor,
             self._scale_buffer,
             e4m3_dtype,

test/test_compile.py (+2 −2)

@@ -20,7 +20,7 @@
     get_float8_layers,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_scaling_utils import cast_to_float8_delayed
+from float8_experimental.float8_scaling_utils import hp_tensor_to_float8_delayed
 from float8_experimental.float8_tensor import LinearMMConfig
 from float8_experimental.float8_utils import e4m3_dtype
 
@@ -179,7 +179,7 @@ def __init__(self, graph_break: bool):
         self.graph_break = graph_break
 
     def forward(self, x):
-        x_fp8 = cast_to_float8_delayed(
+        x_fp8 = hp_tensor_to_float8_delayed(
             x,
             self.fp8_scale_x,
             e4m3_dtype,
