 @torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
-    scale = torch.empty_like(amax, dtype=torch.float32)
+def amax_to_scale(
+    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
+):
+    assert amax.dtype == torch.float32, "amax must be a float32 tensor"
     if float8_dtype == torch.float8_e4m3fn:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     else:  # e5m2
@@ -35,16 +37,15 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):
     # to care about this for float32/bfloat16.
     if orig_dtype is torch.float16:
         res = torch.clamp(res, max=FP16_MAX_POS)
-    scale.copy_(res)
-    return scale
+    return res


 @torch.no_grad()
 def amax_history_to_scale(
-    amax_history,
-    float8_dtype,
-    orig_dtype,
-    history_to_scale_fn_type,
+    amax_history: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: str,
 ):
     if history_to_scale_fn_type == "max":
         amax = torch.max(amax_history)
@@ -87,7 +88,7 @@ def tensor_to_amax(x, distributed_reduction=False):


 @torch.no_grad()
-def tensor_to_scale(x, float8_dtype):
+def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype):
     amax = tensor_to_amax(x)
     if float8_experimental.config.use_fused_cast and x.is_cuda:
         from float8_experimental.fused_kernels.fused_casting_kernels import (
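
A minimal usage sketch of the new calling convention (the import path and the concrete constant value are assumptions, not taken from this diff): the caller passes a float32 amax tensor and receives the scale directly, instead of having it copied into a buffer preallocated with torch.empty_like.

    import torch
    from float8_experimental.float8_utils import amax_to_scale  # assumed import path

    # amax must be float32, per the new assert in the diff
    amax = torch.tensor(3.0, dtype=torch.float32)
    scale = amax_to_scale(amax, torch.float8_e4m3fn, torch.bfloat16)
    # for e4m3fn this is E4M3_MAX_POS / clamp(amax, min=EPS),
    # e.g. 448.0 / 3.0 here if E4M3_MAX_POS is 448.0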