 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config
 
 
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,24 @@ def backward(ctx, gradY):
             None,
         )
 
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+def cast_dldy_to_float8_e5m2_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    gradY = grad_output[0]
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    gradY_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    # TODO fix: the next op in the backward does not see this, it sees grad_output[0]
+    return (gradY_fp8,)
+
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -48,9 +67,16 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.add_weight_tag()
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
             w_fp8 = self._w_fp8
         else:
@@ -59,7 +85,10 @@ def forward(self, x):
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
@@ -69,6 +98,7 @@ def add_weight_tag(self):
         self.weight._is_fp8_weight = True
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +122,10 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         new_mod.add_weight_tag()
+
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_pre_hook)
         return new_mod
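
Usage sketch (not part of the commit): a minimal example of how the hook-based path would be enabled, assuming this file is float8_experimental/float8_dynamic_linear.py and that config.dynamic_use_activation_hooks defaults to False.

# Minimal sketch, assuming the module path and config default noted above.
import torch

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# Flip the config knob before converting, so from_float() installs the
# forward / full-backward pre-hooks instead of casting inline in forward().
config.dynamic_use_activation_hooks = True

mod = torch.nn.Linear(16, 32)
fp8_mod = Float8DynamicLinear.from_float(mod, emulate=True)

x = torch.randn(4, 16, requires_grad=True)
y = fp8_mod(x)      # forward pre-hook casts x to float8_e4m3fn
y.sum().backward()  # full backward pre-hook runs here; per the TODO above,
                    # the e5m2 cast is not yet seen by the next op in backward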