Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 270ce64

Browse files
committed
[wip] hooks
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 31fba04 commit 270ce64

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

float8_experimental/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,11 @@
1616
# according to their microbatching/pipeline parallel setup.
1717
# Note: this is currently a global flag for simplicity and dynamo performance.
1818
weight_cache_enabled = False
19+
20+
#
21+
# Other
22+
#
23+
24+
# If True, dynamic linear uses hooks for activation casting
25+
dynamic_use_activation_hooks = True
26+
# dynamic_use_activation_hooks = False

float8_experimental/dynamic_linear/dynamic_float8_linear.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from float8_experimental.float8_tensor import Float8Tensor
1313
from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
14+
import float8_experimental.config as config
1415

1516

1617
class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,22 @@ def backward(ctx, gradY):
3839
None,
3940
)
4041

42+
def cast_x_to_float8_e4m3fn_pre_hook(module, args):
43+
"""
44+
Hook to cast the incoming activation to `torch.float8_e4m3fn`
45+
"""
46+
return module.cast_to_float8(args[0])
47+
48+
def cast_dldy_to_float8_e5m2_backward_pre_hook(module, grad_output):
49+
"""
50+
Hook to cast the incoming gradient to `torch.float8_e5m2`
51+
"""
52+
gradY = grad_output[0]
53+
gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
54+
gradY_scaled = gradY * gradY_scale
55+
bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
56+
tensor_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
57+
return (tensor_fp8,)
4158

4259
class Float8DynamicLinear(torch.nn.Linear):
4360
"""
@@ -48,9 +65,16 @@ class Float8DynamicLinear(torch.nn.Linear):
4865
def __init__(self, *args, **kwargs):
4966
super().__init__(*args, **kwargs)
5067
self.add_weight_tag()
68+
self.use_activation_hooks = config.dynamic_use_activation_hooks
5169

5270
def forward(self, x):
53-
x_fp8 = self.cast_to_float8(x)
71+
# cast x to float8_e4m3fn
72+
if self.use_activation_hooks:
73+
x_fp8 = x
74+
else:
75+
x_fp8 = self.cast_to_float8(x)
76+
77+
# cast w to float8_e4m3fn
5478
if getattr(self, "_w_fp8", None) is not None: # FSDP handled the cast
5579
w_fp8 = self._w_fp8
5680
else:
@@ -59,7 +83,10 @@ def forward(self, x):
5983
y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
6084

6185
# Cast gradY to float8_e5m2 during backward
62-
y = self.cast_to_float8e5m2_bw(y)
86+
if self.use_activation_hooks:
87+
pass
88+
else:
89+
y = self.cast_to_float8e5m2_bw(y)
6390

6491
return y
6592

@@ -69,6 +96,7 @@ def add_weight_tag(self):
6996
self.weight._is_fp8_weight = True
7097

7198
def cast_to_float8(self, inpt_tensor):
99+
# TODO rename this function to clarify e4m3
72100
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
73101
return Float8Tensor.to_float8(
74102
inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +120,14 @@ def from_float(cls, mod, emulate: bool = False):
92120
new_mod.bias = mod.bias
93121
new_mod.emulate = emulate
94122
new_mod.add_weight_tag()
123+
124+
new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
125+
if new_mod.use_activation_hooks:
126+
# install the hooks
127+
# TODO(future): figure out why using backward pre-hooks does not
128+
# work here:
129+
# 1. repro code: https://gist.github.com/vkuzo/27a3f6ca48e50ba1134b077f0dba254c
130+
# 2. repro output: https://gist.github.com/vkuzo/728eae9dcc627e130829d122daa982e7
131+
new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
132+
new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_backward_pre_hook)
95133
return new_mod

0 commit comments

Comments
 (0)