This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 3fe1055

[wip] hooks
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 31fba04 commit 3fe1055

4 files changed: +78 −2 lines changed


float8_experimental/config.py

Lines changed: 7 additions & 0 deletions
@@ -16,3 +16,10 @@
 # according to their microbatching/pipeline parallel setup.
 # Note: this is currently a global flag for simplicity and dynamo performance.
 weight_cache_enabled = False
+
+#
+# Other
+#
+
+# If True, dynamic linear uses hooks for activation casting
+dynamic_use_activation_hooks = True
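For orientation, a minimal sketch of how this global flag is meant to be consumed, assuming the import path matches the file layout in this diff; module sizes and emulate=True are illustrative, not part of the commit:

import torch

import float8_experimental.config as config
from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear

# The flag is read when a Float8DynamicLinear is constructed or converted
# via from_float, so flip it before building the module.
config.dynamic_use_activation_hooks = True
hooked = Float8DynamicLinear.from_float(torch.nn.Linear(16, 32), emulate=True)

config.dynamic_use_activation_hooks = False
unhooked = Float8DynamicLinear.from_float(torch.nn.Linear(16, 32), emulate=True)

assert hooked.use_activation_hooks and not unhooked.use_activation_hooks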

float8_experimental/dynamic_linear/dynamic_float8_linear.py

Lines changed: 38 additions & 2 deletions
@@ -11,6 +11,7 @@
 
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import tensor_to_scale, to_fp8_saturated
+import float8_experimental.config as config
 
 
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -38,6 +39,24 @@ def backward(ctx, gradY):
             None,
         )
 
+def cast_x_to_float8_e4m3fn_pre_hook(module, args):
+    """
+    Hook to cast the incoming activation to `torch.float8_e4m3fn`
+    """
+    return module.cast_to_float8(args[0])
+
+def cast_dldy_to_float8_e5m2_pre_hook(module, grad_output):
+    """
+    Hook to cast the incoming gradient to `torch.float8_e5m2`
+    """
+    gradY = grad_output[0]
+    gradY_scale = tensor_to_scale(gradY, torch.float8_e5m2)
+    gradY_scaled = gradY * gradY_scale
+    bits_fp8 = to_fp8_saturated(gradY_scaled, torch.float8_e5m2)
+    gradY_fp8 = Float8Tensor(bits_fp8, gradY_scale, gradY.dtype, emulate=module.emulate)
+    # TODO fix: the next op in the backward does not see this, it sees grad_output[0]
+    return (gradY_fp8,)
+
 
 class Float8DynamicLinear(torch.nn.Linear):
     """
@@ -48,9 +67,16 @@ class Float8DynamicLinear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.add_weight_tag()
+        self.use_activation_hooks = config.dynamic_use_activation_hooks
 
     def forward(self, x):
-        x_fp8 = self.cast_to_float8(x)
+        # cast x to float8_e4m3fn
+        if self.use_activation_hooks:
+            x_fp8 = x
+        else:
+            x_fp8 = self.cast_to_float8(x)
+
+        # cast w to float8_e4m3fn
         if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
             w_fp8 = self._w_fp8
         else:
@@ -59,7 +85,10 @@ def forward(self, x):
         y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
 
         # Cast gradY to float8_e5m2 during backward
-        y = self.cast_to_float8e5m2_bw(y)
+        if self.use_activation_hooks:
+            pass
+        else:
+            y = self.cast_to_float8e5m2_bw(y)
 
         return y
 
@@ -69,6 +98,7 @@ def add_weight_tag(self):
         self.weight._is_fp8_weight = True
 
     def cast_to_float8(self, inpt_tensor):
+        # TODO rename this function to clarify e4m3
        scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
        return Float8Tensor.to_float8(
            inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
@@ -92,4 +122,10 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
         new_mod.add_weight_tag()
+
+        new_mod.use_activation_hooks = config.dynamic_use_activation_hooks
+        if new_mod.use_activation_hooks:
+            # install the hooks
+            new_mod.register_forward_pre_hook(cast_x_to_float8_e4m3fn_pre_hook)
+            new_mod.register_full_backward_pre_hook(cast_dldy_to_float8_e5m2_pre_hook)
         return new_mod
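A hedged usage sketch of the hook-based path added above. Only the forward side is exercised here, since the backward cast is still marked TODO in cast_dldy_to_float8_e5m2_pre_hook; shapes, bias=False, and emulate=True are illustrative choices, not taken from this commit:

import torch

import float8_experimental.config as config
from float8_experimental.dynamic_linear.dynamic_float8_linear import Float8DynamicLinear

config.dynamic_use_activation_hooks = True

# from_float installs cast_x_to_float8_e4m3fn_pre_hook as a forward pre-hook
# and cast_dldy_to_float8_e5m2_pre_hook as a full backward pre-hook.
# bias=False keeps the matmul on the aten.mm path shown in float8_ops.py below.
m = Float8DynamicLinear.from_float(torch.nn.Linear(8, 8, bias=False), emulate=True)

x = torch.randn(4, 8)
# The forward pre-hook casts x to float8_e4m3fn before forward() runs, so
# forward() skips its own activation cast when use_activation_hooks is set.
y = m(x)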

float8_experimental/float8_ops.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ def preprocess_addmm(a: Float8Tensor, b: Float8Tensor):
 
 @implements([aten.mm.default])
 def float8_mm(aten_op, args, kwargs=None):
+    print('float8_mm', args)
     assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
     a = args[0]
     b = args[1]

test/test_bw_hook.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import torch
+import torch.nn as nn
+
+class TestAutogradFunction(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, tensor):
+        tensor = tensor + 1.0
+        return tensor
+
+    @staticmethod
+    def backward(ctx, gradY):
+        # prints a tensor filled with 0.123, as expected
+        print('gradY', gradY)
+        gradY = gradY + 1.0
+        return gradY
+
+class M(nn.Module):
+    def forward(self, x):
+        return TestAutogradFunction.apply(x)
+
+m = M()
+
+def bw_pre_hook(module, go):
+    new_go = torch.empty_like(go[0]).fill_(0.123)
+    return (new_go,)
+
+m.register_full_backward_pre_hook(bw_pre_hook)
+
+x = torch.randn(2, 2).requires_grad_()
+y = m(x)
+y.sum().backward()
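Note: in this standalone repro, the tensor returned by bw_pre_hook is exactly what TestAutogradFunction.backward receives (hence the 0.123 print noted in the comment), whereas the TODO in cast_dldy_to_float8_e5m2_pre_hook above records that the equivalent replacement gradient is not yet seen by the next op in Float8DynamicLinear's backward.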
