Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 110ec4b

Browse files
committed
cast to fp32 in amax
1 parent 1dd4573 commit 110ec4b

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

float8_experimental/float8_utils.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@
2323

2424

2525
@torch.no_grad()
26-
def amax_to_scale(amax, float8_dtype, orig_dtype):
27-
scale = torch.empty_like(amax, dtype=torch.float32)
26+
def amax_to_scale(
27+
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
28+
):
29+
assert amax.dtype == torch.float32, "amax must be a float32 tensor"
2830
if float8_dtype == torch.float8_e4m3fn:
2931
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
3032
else: # e5m2
@@ -35,16 +37,15 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):
3537
# to care about this for float32/bfloat16.
3638
if orig_dtype is torch.float16:
3739
res = torch.clamp(res, max=FP16_MAX_POS)
38-
scale.copy_(res)
39-
return scale
40+
return res
4041

4142

4243
@torch.no_grad()
4344
def amax_history_to_scale(
44-
amax_history,
45-
float8_dtype,
46-
orig_dtype,
47-
history_to_scale_fn_type,
45+
amax_history: torch.Tensor,
46+
float8_dtype: torch.dtype,
47+
orig_dtype: torch.dtype,
48+
history_to_scale_fn_type: str,
4849
):
4950
if history_to_scale_fn_type == "max":
5051
amax = torch.max(amax_history)
@@ -87,7 +88,7 @@ def tensor_to_amax(x, distributed_reduction=False):
8788

8889

8990
@torch.no_grad()
90-
def tensor_to_scale(x, float8_dtype):
91+
def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype):
9192
amax = tensor_to_amax(x)
9293
if float8_experimental.config.use_fused_cast and x.is_cuda:
9394
from float8_experimental.fused_kernels.fused_casting_kernels import (

0 commit comments

Comments
 (0)