 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config


 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,22 @@ def backward(ctx, gradY):
             None,
         )

+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    gradY = grad_output[0]
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    tensor_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    return (tensor_fp8,)

 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -48,9 +65,16 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.add_weight_tag()
+        self.use_activation_hooks = config.dynamic_use_activation_hooks

     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
             w_fp8 = self._w_fp8
         else:
@@ -59,7 +83,10 @@ def forward(self, x):
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)

         return y

@@ -69,6 +96,7 @@ def add_weight_tag(self):
         self.weight._is_fp8_weight = True

     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +120,14 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         new_mod.add_weight_tag()
+
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            # TODO(future): figure out why using backward pre-hooks does not
+            # work here:
+            # 1. repro code: https://gist.github.com/vkuzo/27a3f6ca48e50ba1134b077f0dba254c
+            # 2. repro output: https://gist.github.com/vkuzo/728eae9dcc627e130829d122daa982e7
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_backward_pre_hook)
         return new_mod
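
For reference, a minimal usage sketch of the hook path added above. This is a sketch, not part of the PR: it assumes the class lives in float8_experimental.float8_dynamic_linear and that dynamic_use_activation_hooks is a plain module-level boolean in float8_experimental.config read at conversion time; the shapes and emulate=True are illustrative only.

import torch

import float8_experimental.config as config
# assumed module path for the class shown in this diff
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

# assumption: the flag is a module-level boolean consumed by from_float()
config.dynamic_use_activation_hooks = True

# from_float() installs the forward pre-hook (cast x to float8_e4m3fn) and
# the full backward pre-hook (cast dL/dY to float8_e5m2) when the flag is set
m = torch.nn.Linear(16, 32)
m_fp8 = Float8DynamicLinear.from_float(m, emulate=True)

x = torch.randn(4, 16, requires_grad=True)
y = m_fp8(x)
y.sum().backward()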
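
The hook contracts the diff relies on can also be seen in isolation on a plain nn.Linear: a forward pre-hook that returns a value replaces the module's positional inputs, and a full backward pre-hook that returns a tuple replaces grad_output before the module's backward runs (register_full_backward_pre_hook requires PyTorch 2.0+). The illustrative sketch below uses ordinary fp32 scaling instead of Float8Tensor purely to show that mechanism; the hook names are made up.

import torch
import torch.nn as nn

def scale_input_pre_hook(module, args):
    # forward pre-hook: the returned tensor replaces the positional input
    # (a non-tuple return value is wrapped into a 1-tuple by nn.Module)
    return args[0] * 0.5

def scale_grad_backward_pre_hook(module, grad_output):
    # full backward pre-hook: the returned tuple replaces grad_output
    return (grad_output[0] * 2.0,)

lin = nn.Linear(8, 8)
lin.register_forward_pre_hook(scale_input_pre_hook)
lin.register_full_backward_pre_hook(scale_grad_backward_pre_hook)

x = torch.randn(2, 8, requires_grad=True)
lin(x).sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])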