```diff
@@ -11,6 +11,7 @@
 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config
 
 
 @torch._dynamo.allow_in_graph
```
```diff
@@ -39,25 +40,54 @@ def backward(ctx, gradY):
             None,
         )
 
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    gradY = grad_output[0]
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    tensor_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    return (tensor_fp8,)
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
     conversion to fp8 of the input and weight tensors.
     """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         w_fp8 = self.cast_to_float8(self.weight)
 
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
         scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
         return Float8Tensor.to_float8(
             inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
```
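The two module-level functions added above plug into PyTorch's standard hook API: a forward pre-hook receives `(module, args)` and whatever it returns replaces the module's positional inputs (a single returned tensor is wrapped into a 1-tuple), while a full backward pre-hook receives `(module, grad_output)` and may return a tuple that replaces the incoming gradients. A minimal, self-contained sketch of that mechanism, using a toy module and made-up scale factors rather than anything from float8_experimental:

```python
import torch

lin = torch.nn.Linear(8, 8)

def fwd_pre(module, args):
    # A forward pre-hook's return value replaces the module's inputs;
    # a single tensor is wrapped into a 1-tuple by PyTorch.
    return args[0] * 2.0

def bwd_pre(module, grad_output):
    # A full backward pre-hook returns a tuple that replaces grad_output
    # before the module's backward runs.
    return (grad_output[0] * 0.5,)

lin.register_forward_pre_hook(fwd_pre)
lin.register_full_backward_pre_hook(bwd_pre)

x = torch.randn(2, 8, requires_grad=True)
lin(x).sum().backward()
```

This is the same shape as the diff: `cast_x_to_float8_e4m3fn_pre_hook` returns a single casted tensor, and `cast_dldy_to_float8_e5m2_backward_pre_hook` returns a 1-tuple containing the casted gradient.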
```diff
@@ -80,4 +110,9 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_backward_pre_hook)
         return new_mod
```
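Putting it together, enabling the hook path is a config flip followed by module conversion. A rough usage sketch, assuming the file lives at `float8_experimental.float8_dynamic_linear` (the class, `from_float`, and `config.dynamic_use_activation_hooks` come from the diff; the import path is an assumption):

```python
import torch

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear  # assumed module path

# Must be set before from_float() so the hooks get installed on the new module.
config.dynamic_use_activation_hooks = True

m = torch.nn.Linear(16, 32)
m_fp8 = Float8DynamicLinear.from_float(m, emulate=True)  # emulate fp8 matmul in higher precision

x = torch.randn(4, 16, requires_grad=True)
y = m_fp8(x)        # forward pre-hook casts x to float8_e4m3fn
y.sum().backward()  # backward pre-hook casts the incoming gradient to float8_e5m2
```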