This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 4fb2877

vkuzo authored and facebook-github-bot committed
clean up casting: cast_to_float8_e4m3_dynamic -> cast_to_float8_dynamic (#349)
Summary:

Pull Request resolved: #349

Moves the dtype from the function name to an argument, to match the delayed scaling version.

Reviewed By: drisspg

Differential Revision: D60310239

fbshipit-source-id: d266f8d9a17ed3170176c058e9960541a1d3946b
1 parent 6ac2f82 commit 4fb2877
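
The effect on call sites, shown as a before/after fragment drawn from the float8_linear.py hunk below (not a standalone script; `input` and `self.linear_mm_config` are whatever the caller, such as Float8Linear, already has in scope):

# Before: the target float8 dtype was baked into the function name.
input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)

# After: the dtype is an explicit argument, matching the delayed scaling
# version per the summary above.
input_fp8 = cast_to_float8_dynamic(input, e4m3_dtype, self.linear_mm_config)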

File tree

5 files changed, +19 -14 lines changed

float8_experimental/float8_linear.py

+4 -3

@@ -19,7 +19,7 @@
 from float8_experimental.float8_scaling_utils import (
     _maybe_initialize_amaxes_scales_for_float8_cast,
     cast_to_float8_delayed,
-    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_dynamic,
     NoopFwToFloat8E5M2BwDelayed,
     NoopFwToFloat8E5M2BwDynamic,
 )
@@ -270,7 +270,7 @@ def cast_input_to_float8(
             )
         else:
             assert self.scaling_type_input is ScalingType.DYNAMIC
-            input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
+            input_fp8 = cast_to_float8_dynamic(input, e4m3_dtype, self.linear_mm_config)
         return input_fp8

     def cast_weight_to_float8(
@@ -305,8 +305,9 @@ def cast_weight_to_float8(
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             weight_fp8 = self.weight
         else:
-            weight_fp8 = cast_to_float8_e4m3_dynamic(
+            weight_fp8 = cast_to_float8_dynamic(
                 self.weight,
+                e4m3_dtype,
                 self.linear_mm_config,
                 gemm_input_role=GemmInputRole.WEIGHT,
             )

float8_experimental/float8_scaling_utils.py

+4 -4

@@ -30,25 +30,25 @@
 )


-def cast_to_float8_e4m3_dynamic(
+def cast_to_float8_dynamic(
     inpt_tensor: torch.Tensor,
+    float8_dtype: torch.dtype,
     linear_mm_config: LinearMMConfig,
     reduce_amax: bool = False,
     gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
 ) -> Float8Tensor:
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
-    scale = tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
+    scale = tensor_to_scale(inpt_tensor, float8_dtype, reduce_amax)
     return hp_tensor_and_scale_to_float8(
         inpt_tensor,
         scale,
-        e4m3_dtype,
+        float8_dtype,
         linear_mm_config,
         gemm_input_role,
     )


-# TODO(future PR): align name with cast_to_float8_e4m3_dynamic
 def cast_to_float8_delayed(
     tensor: torch.Tensor,
     scale: torch.Tensor,
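
For reference, a minimal usage sketch of the renamed helper. The wrapper name cast_activation_for_matmul is hypothetical, and the import location of LinearMMConfig (float8_experimental.float8_tensor) is an assumption not confirmed by this commit; e4m3_dtype (from float8_experimental.float8_utils) and the cast_to_float8_dynamic signature come from the diffs on this page.

import torch

from float8_experimental.float8_scaling_utils import cast_to_float8_dynamic
from float8_experimental.float8_tensor import Float8Tensor, LinearMMConfig
from float8_experimental.float8_utils import e4m3_dtype


def cast_activation_for_matmul(
    x: torch.Tensor, mm_config: LinearMMConfig
) -> Float8Tensor:
    # The float8 dtype is now an explicit argument rather than being baked
    # into the function name, so one helper covers every float8 flavor.
    return cast_to_float8_dynamic(x, e4m3_dtype, mm_config)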

float8_experimental/float8_tensor_parallel.py

+8 -4

@@ -2,10 +2,11 @@
 import torch.nn as nn
 from float8_experimental.config import ScalingType
 from float8_experimental.float8_scaling_utils import (
-    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_dynamic,
     NoopFwToFloat8E5M2BwDynamic,
 )
 from float8_experimental.float8_tensor import GemmInputRole
+from float8_experimental.float8_utils import e4m3_dtype
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.parallel import (
@@ -45,8 +46,9 @@ def _prepare_input_fn(
             input_tensor, device_mesh, input_layouts, run_check=False
         )

-        input_tensor = cast_to_float8_e4m3_dynamic(
+        input_tensor = cast_to_float8_dynamic(
             input_tensor,
+            e4m3_dtype,
             mod.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
@@ -98,8 +100,9 @@ def _prepare_input_fn(
             input_tensor, device_mesh, input_layouts, run_check=False
         )

-        input_tensor = cast_to_float8_e4m3_dynamic(
+        input_tensor = cast_to_float8_dynamic(
             input_tensor,
+            e4m3_dtype,
             mod.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)
@@ -196,8 +199,9 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
             input, mesh, (input_layout,), run_check=False
         )

-        dt_inp = cast_to_float8_e4m3_dynamic(
+        dt_inp = cast_to_float8_dynamic(
             dt_inp,
+            e4m3_dtype,
             self.linear_mm_config,
             gemm_input_role=GemmInputRole.INPUT,
         )  # DTensor(Float8Tensor)

float8_experimental/fsdp_utils.py

+3 -2

@@ -12,7 +12,7 @@
 import torch.utils._pytree as pytree
 from float8_experimental.float8_scaling_utils import (
     cast_to_float8_delayed,
-    cast_to_float8_e4m3_dynamic,
+    cast_to_float8_dynamic,
 )

 from float8_experimental.float8_tensor import (
@@ -175,8 +175,9 @@ def fsdp_pre_all_gather(self, mesh):
                 GemmInputRole.WEIGHT,
             )
         else:
-            float8_tensor = cast_to_float8_e4m3_dynamic(
+            float8_tensor = cast_to_float8_dynamic(
                 self._tensor,
+                e4m3_dtype,
                 self._linear_mm_config,
                 reduce_amax=True,
                 gemm_input_role=GemmInputRole.WEIGHT,

test/test_base.py

-1

@@ -24,7 +24,6 @@
     sync_float8_amax_and_scale_history,
 )
 from float8_experimental.float8_python_api import addmm_float8_unwrapped
-from float8_experimental.float8_scaling_utils import cast_to_float8_e4m3_dynamic
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
