This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit f83bf21 (1 parent: d0af81a)

[wip] hooks

Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:

2 files changed: +42 −2 lines

float8_experimental/config.py

5 additions & 0 deletions

@@ -14,3 +14,8 @@
 # this doesn't work with autocast + torch.compile + FSDP. Enabling this
 # option is useful for safety, but not strictly necessary.
 enable_pre_and_post_forward = True
+
+# If True, dynamic linear uses hooks for activation casting
+# TODO(before land): add test coverage for both cases
+dynamic_use_activation_hooks = True
+# dynamic_use_activation_hooks = False
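
The new flag is read when a `Float8DynamicLinear` is constructed (see `__init__` and `from_float` in the companion diff below), so callers have to set it before conversion; flipping it afterwards does not rewire modules that already exist. A minimal sketch of toggling it, assuming `emulate=True` is enough to run the dynamic-linear path without fp8 hardware:

```python
# Sketch only: set the flag, then convert. The module captures the value at
# construction time, so later changes to config do not affect existing modules.
import torch

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

config.dynamic_use_activation_hooks = True  # or False to keep casting inside forward()

fp8_lin = Float8DynamicLinear.from_float(torch.nn.Linear(32, 16), emulate=True)
print(fp8_lin.use_activation_hooks)  # True
```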

float8_experimental/float8_dynamic_linear.py

37 additions & 2 deletions

@@ -11,6 +11,7 @@
 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config
 
 
 @torch._dynamo.allow_in_graph
@@ -39,25 +40,54 @@ def backward(ctx, gradY):
             None,
         )
 
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    gradY = grad_output[0]
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    tensor_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    return (tensor_fp8,)
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
     A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
     conversion to fp8 of the input and weight tensors.
     """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         w_fp8 = self.cast_to_float8(self.weight)
 
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
        return Float8Tensor.to_float8(
            inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -80,4 +110,9 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_backward_pre_hook)
         return new_mod
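
Since the commit is marked `[wip]` and the TODO asks for test coverage of both settings, the following is only an illustrative sketch of how the hook path is wired up, assuming `emulate=True` runs on CPU: the forward pre-hook casts the activation to `float8_e4m3fn` before `forward` runs (so `forward` leaves `x` untouched), and the full-backward pre-hook casts the incoming gradient to `float8_e5m2` before it reaches the module's backward.

```python
# Illustrative sketch, not part of the commit.
import torch

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear

config.dynamic_use_activation_hooks = True

m = Float8DynamicLinear.from_float(torch.nn.Linear(16, 8), emulate=True)

x = torch.randn(4, 16, requires_grad=True)
# The forward pre-hook fires here and returns the float8_e4m3fn-cast input,
# so forward() sees an already-cast tensor and skips its own cast_to_float8 call.
y = m(x)
# The full-backward pre-hook fires here and returns the gradient cast to
# float8_e5m2, replacing the in-graph cast_to_float8e5m2_bw call.
y.sum().backward()
```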
