This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit dd0c596

vkuzo authored and facebook-github-bot committed
remove weight caching (#181)
Summary:
Removes most of #164. This isn't useful in the short term since it doesn't compose with FSDP + compile, and the memory overhead is high. We can bring it back later if needed.

Pull Request resolved: #181

Test Plan:
```
./test/test_everything.sh
```

Reviewed By: drisspg

Differential Revision: D52648603

Pulled By: vkuzo

fbshipit-source-id: f956337264fd28fa0bc50d151c316cde7c3d28de
1 parent 8ed0eb7 commit dd0c596

File tree (6 files changed, +9 −136 lines):
- README.md
- float8_experimental/config.py
- float8_experimental/float8_linear.py
- float8_experimental/float8_linear_utils.py
- float8_experimental/float8_tensor.py
- test/test_base.py


README.md

Lines changed: 0 additions & 22 deletions
@@ -64,28 +64,6 @@ model.foo.bar.fc2.sequence_parallel = True
 # the rest of the flow is the same as the single GPU flow
 ```
 
-## weight caching (very experimental)
-
-```python
-import float8_experimental.config as config
-
-m = Model(...)
-# before converting to `Float8Linear`, turn on weight cache buffer allocation
-config.allocate_float8_weight_cache_buffers = True
-
-# in the training loop, manually control the global weight caching setting
-for idx in N_ITER:
-    ...
-    if idx % n_microbatch == 0:
-        # if we are in the first pass of a new microbatch, repopulate the cache
-        config.weight_cache_enabled = False
-    elif idx % n_microbatch == 1:
-        # if we are in the second pass of a new microbatch, use cached weight
-        # this persists until `idx % n_microbatch == 0` again
-        config.weight_cache_enabled = True
-    ...
-```
-
 # high level technical design
 
 ## UX

float8_experimental/config.py

Lines changed: 0 additions & 17 deletions
@@ -4,23 +4,6 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-#
-# Weight caching.
-#
-
-# If True, allocates buffers for float8 weight cache
-allocate_float8_weight_cache_buffers = False
-
-# A global flag for controlling the weight cache, off by default. Intended
-# usage is for users to modify this from their training loop directly
-# according to their microbatching/pipeline parallel setup.
-# Note: this is currently a global flag for simplicity and dynamo performance.
-weight_cache_enabled = False
-
-#
-# Other
-#
-
 # If True, on the first iteration of Float8Linear the amaxes will be
 # initialized with the incoming data. As of 2023-12-30, this doesn't work
 # with autocast + torch.compile + FSDP. Enabling this option is nice for

float8_experimental/float8_linear.py

Lines changed: 1 addition & 33 deletions
@@ -20,10 +20,7 @@
 
 import torch
 
-from float8_experimental.float8_tensor import (
-    calculate_amax_and_cast_to_float8,
-    Float8Tensor,
-)
+from float8_experimental.float8_tensor import Float8Tensor
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -175,15 +172,6 @@ def __init__(self, *args, **kwargs):
         # and torch.compile, this option can disable them
         self.enable_pre_and_post_forward = config.enable_pre_and_post_forward
 
-        if config.allocate_float8_weight_cache_buffers:
-            # this is a buffer to get `to(dtype)` for free
-            # TODO(future): hide this from serialization
-            # TODO(future): force this to stay in float8_e4m3fn
-            self.register_buffer(
-                "cached_fp8_weight",
-                torch.empty(self.weight.shape, dtype=torch.float8_e4m3fn),
-            )
-
     def register_always_float32_buffer(
         self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
     ) -> None:
@@ -240,32 +228,12 @@ def cast_w_to_float8(
             is_amax_initialized,
         )
 
-        if config.weight_cache_enabled:
-            assert config.allocate_float8_weight_cache_buffers, (
-                "float8 weight cache buffer must be allocated using "
-                + "`allocate_float8_weight_cache_buffers` to use the weight cache"
-            )
-            w_bits_fp8 = self.cached_fp8_weight
-        else:
-            # manual calculation of fp8 bits:
-            # 1. calculate the bits without Float8Tensor, without grad
-            # 2. store the bits here
-            # 3. create Float8Tensor from the bits calculated in 2
-            # motivation: this will take care of saving the bits without
-            # interacting with tensor subclasses, as w_fp8._data is not
-            # currently traceable by dynamo
-            w_bits_fp8 = calculate_amax_and_cast_to_float8(
-                self.weight, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w
-            )
-            if config.allocate_float8_weight_cache_buffers:
-                self.cached_fp8_weight.copy_(w_bits_fp8)
         w_fp8 = Float8Tensor.to_float8(
             w,
             self.fp8_scale_w,
             torch.float8_e4m3fn,
             self.fp8_amax_w,
             self.emulate,
-            cached_casted_weight=w_bits_fp8,
         )
         return w_fp8
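For orientation, with the cache branch gone the weight cast in `cast_w_to_float8` reduces to the single unconditional call shown in the context lines above. A minimal sketch of the remaining path, assuming the `Float8Linear` buffers (`fp8_scale_w`, `fp8_amax_w`) and `emulate` flag defined elsewhere in this file; the surrounding scale bookkeeping is omitted:

```python
# Sketch only: buffer names and the `emulate` flag are taken from this diff.
def cast_w_to_float8(self, w: torch.Tensor, is_amax_initialized: bool) -> "Float8Tensor":
    # after this change, the cast is one direct call into the tensor subclass
    w_fp8 = Float8Tensor.to_float8(
        w,
        self.fp8_scale_w,
        torch.float8_e4m3fn,
        self.fp8_amax_w,
        self.emulate,
    )
    return w_fp8
```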

float8_experimental/float8_linear_utils.py

Lines changed: 0 additions & 2 deletions
@@ -156,8 +156,6 @@ def sync_float8_amax_and_scale_history(
 
     for idx in range(len(fp8_layers)):
         child = fp8_layers[idx]
-        # TODO(future): enable skipping weight related syncing if weight cache
-        # is on
 
         #
         # 1. in distributed contexts, syncs amax values across workers

float8_experimental/float8_tensor.py

Lines changed: 8 additions & 29 deletions
@@ -12,16 +12,6 @@
 aten = torch.ops.aten
 
 
-@torch.no_grad()
-def calculate_amax_and_cast_to_float8(tensor, scale, float8_dtype, amax_buffer):
-    if amax_buffer is not None:
-        amax_buffer.fill_(tensor_to_amax(tensor))
-
-    tensor_scaled = tensor * scale
-    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
-    return bits_fp8
-
-
 @torch._dynamo.allow_in_graph
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
@@ -36,23 +26,20 @@ def forward(
         float8_dtype=torch.float8_e4m3fn,
         amax_buffer=None,
         emulate: bool = False,
-        cached_casted_weight=None,
     ):
-        if cached_casted_weight is not None:
-            return Float8Tensor(
-                cached_casted_weight, scale, tensor.dtype, emulate=emulate
-            )
-        bits_fp8 = calculate_amax_and_cast_to_float8(
-            tensor, scale, float8_dtype, amax_buffer
-        )
+        if amax_buffer is not None:
+            amax_buffer.fill_(tensor_to_amax(tensor))
+
+        tensor_scaled = tensor * scale
+        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
         return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)
 
     @staticmethod
     def backward(ctx, g):
         if isinstance(g, Float8Tensor):
-            return g.to_original_precision(), None, None, None, None, None
+            return g.to_original_precision(), None, None, None, None
         else:
-            return g, None, None, None, None, None
+            return g, None, None, None, None
 
 
 @torch._dynamo.allow_in_graph
@@ -147,14 +134,7 @@ def to_original_precision(self):
 
     @staticmethod
     @torch._dynamo.allow_in_graph
-    def to_float8(
-        tensor,
-        scale,
-        float8_dtype,
-        amax_buffer=None,
-        emulate: bool = False,
-        cached_casted_weight=None,
-    ):
+    def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
         """Converts a higher precision tensor to float8 in a differentiable way.
 
         Args:
@@ -172,7 +152,6 @@ def to_float8(
             float8_dtype,
             amax_buffer,
             emulate,
-            cached_casted_weight,
         )
 
     @classmethod
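With `cached_casted_weight` removed, the public cast entry point is back to five parameters. A small usage sketch under the new signature; the input shape and the scale/amax values below are made up for illustration:

```python
import torch

from float8_experimental.float8_tensor import Float8Tensor

# hypothetical inputs, for illustration only
x = torch.randn(16, 32, dtype=torch.bfloat16)
scale = torch.tensor(1.0)
amax = torch.tensor(0.0)

# new signature: to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate=False)
x_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn, amax_buffer=amax, emulate=True)

# amax_buffer is filled in-place with the observed amax of `x`;
# the high precision value can be recovered via the stored scale
x_hp = x_fp8.to_original_precision()
```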

test/test_base.py

Lines changed: 0 additions & 33 deletions
@@ -10,9 +10,6 @@
 import warnings
 from enum import Enum
 
-import float8_experimental.config as config
-import float8_experimental.float8_linear as float8_linear
-
 import pytest
 
 import torch
@@ -229,36 +226,6 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
             y.dtype == torch.bfloat16
         ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"
 
-    @pytest.mark.parametrize("use_compile", [False, True])
-    def test_weight_caching(self, use_compile):
-        M, K, N = 16, 32, 64
-        dtype = torch.bfloat16
-        config.allocate_float8_weight_cache_buffers = True
-
-        x = torch.randn(M, K, device="cuda", dtype=dtype)
-        m_ref = nn.Linear(K, N, bias=True, device="cuda", dtype=dtype)
-        m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate=False)
-
-        if use_compile:
-            m = torch.compile(m)
-
-        config.weight_cache_enabled = False
-
-        y1 = m(x)
-        y1.sum().backward()
-        grad1 = m.weight.grad.clone().detach()
-
-        config.weight_cache_enabled = True
-        sync_float8_amax_and_scale_history(m)
-
-        y2 = m(x)
-        y2.sum().backward()
-        grad2 = m.weight.grad.clone().detach()
-
-        torch.testing.assert_close(grad2, grad1 * 2)
-
-        config.allocate_float8_weight_cache_buffers = False
-
 
 class TestScaledMM:
     @unittest.skipIf(
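With the weight-caching test gone, the remaining single-GPU flow no longer toggles any config flags between iterations. A minimal sketch of that flow, using only calls that appear elsewhere in this diff (`Float8Linear.from_float`, `sync_float8_amax_and_scale_history`); the shapes are made up for illustration:

```python
import copy

import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

M, K, N = 16, 32, 64
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
m_ref = nn.Linear(K, N, bias=True, device="cuda", dtype=torch.bfloat16)
m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate=False)

for _ in range(2):
    y = m(x)
    y.sum().backward()
    # amax/scale history syncing is still needed each iteration;
    # there is no weight cache flag to flip anymore
    sync_float8_amax_and_scale_history(m)
```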
