This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit f4812ee

vkuzo authored and facebook-github-bot committed
support float8 weight caching for gradient accumulation/PP (#164)
Summary:

In cases where the optimizer update does not happen after every forward, such as microbatching/pipeline parallelism, we can save the casted weight to trade some time for memory. For now I'm just testing out performance + accuracy. We can improve on the API in future PRs. The current code is torch.compile friendly, which is nice.

In terms of accuracy this should be no change; I will validate this further if we want to land this.

For performance, on drisspg's LLaMa 7B pretrain script, with bsz == 128 and micro_bsz == 1:

1. baseline bf16 + compile: 2.38 it/s
2. delayed scaling + compile: 2.80 it/s (1.18x over baseline)
3. delayed scaling + compile + this PR: 3.04 it/s (1.28x over baseline)

Pull Request resolved: #164

Test Plan:

```
pytest test/test_base.py -s -k test_weight_caching
```

Reviewed By: drisspg

Differential Revision: D52356785

Pulled By: vkuzo

fbshipit-source-id: e0173666a6c7639246dfde636734900b9fc1657e
1 parent b099049 commit f4812ee
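
For readers less familiar with the setup this targets, here is a minimal, generic gradient-accumulation sketch (plain PyTorch, not this repo's API; the module, sizes, and `n_microbatch` values are illustrative): the optimizer only steps once per accumulation window, so the weights are constant across the micro-batch forwards in between, which is what makes reusing a cached float8 cast of the weight safe.

```python
import torch
import torch.nn as nn

# Illustrative gradient-accumulation loop. Because opt.step() runs once per
# n_microbatch iterations, the weights do not change within an accumulation
# window, so a float8 cast of the weight computed on the first micro-batch
# remains valid for the remaining micro-batches of that window.
model = nn.Linear(1024, 1024)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
n_microbatch, n_iter = 4, 8

for idx in range(n_iter):
    x = torch.randn(8, 1024)
    loss = model(x).sum() / n_microbatch
    loss.backward()                  # gradients accumulate across micro-batches
    if (idx + 1) % n_microbatch == 0:
        opt.step()                   # weights change only here ...
        opt.zero_grad()              # ... so a cached weight cast must be refreshed now
```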

File tree: 6 files changed, +152 −16 lines


README.md

Lines changed: 22 additions & 0 deletions
@@ -64,6 +64,28 @@ model.foo.bar.fc2.sequence_parallel = True
 # the rest of the flow is the same as the single GPU flow
 ```

+## weight caching (very experimental)
+
+```python
+import float8_experimental.config as config
+
+m = Model(...)
+# before converting to `Float8Linear`, turn on weight cache buffer allocation
+config.allocate_float8_weight_cache_buffers = True
+
+# in the training loop, manually control the global weight caching setting
+for idx in N_ITER:
+    ...
+    if idx % n_microbatch == 0:
+        # if we are in the first pass of a new microbatch, repopulate the cache
+        config.weight_cache_enabled = False
+    elif idx % n_microbatch == 1:
+        # if we are in the second pass of a new microbatch, use cached weight
+        # this persists until `idx % n_microbatch == 0` again
+        config.weight_cache_enabled = True
+    ...
+```
+
 # high level technical design

 ## UX

float8_experimental/config.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+#
+# Weight caching.
+#
+
+# If True, allocates buffers for float8 weight cache
+allocate_float8_weight_cache_buffers = False
+
+# A global flag for controlling the weight cache, off by default. Intended
+# usage is for users to modify this from their training loop directly
+# according to their microbatching/pipeline parallel setup.
+# Note: this is currently a global flag for simplicity and dynamo performance.
+weight_cache_enabled = False

float8_experimental/float8_linear.py

Lines changed: 41 additions & 2 deletions
@@ -16,9 +16,14 @@

 from typing import Optional

+import float8_experimental.config as config
+
 import torch

-from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_tensor import (
+    calculate_amax_and_cast_to_float8,
+    Float8Tensor,
+)

 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -172,6 +177,15 @@ def __init__(self, *args, **kwargs):
         # will access the scale when it has ensured that it is on GPU.
         self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)

+        if config.allocate_float8_weight_cache_buffers:
+            # this is a buffer to get `to(dtype)` for free
+            # TODO(future): hide this from serialization
+            # TODO(future): force this to stay in float8_e4m3fn
+            self.register_buffer(
+                "cached_fp8_weight",
+                torch.empty(self.weight.shape, dtype=torch.float8_e4m3fn),
+            )
+
     def register_always_float32_buffer(
         self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
     ) -> None:
@@ -228,8 +242,33 @@ def cast_w_to_float8(
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
+
+        if config.weight_cache_enabled:
+            assert config.allocate_float8_weight_cache_buffers, (
+                "float8 weight cache buffer must be allocated using "
+                + "`allocate_float8_weight_cache_buffers` to use the weight cache"
+            )
+            w_bits_fp8 = self.cached_fp8_weight
+        else:
+            # manual calculation of fp8 bits:
+            # 1. calculate the bits without Float8Tensor, without grad
+            # 2. store the bits here
+            # 3. create Float8Tensor from the bits calculated in 2
+            # motivation: this will take care of saving the bits without
+            # interacting with tensor subclasses, as w_fp8._data is not
+            # currently traceable by dynamo
+            w_bits_fp8 = calculate_amax_and_cast_to_float8(
+                self.weight, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w
+            )
+            if config.allocate_float8_weight_cache_buffers:
+                self.cached_fp8_weight.copy_(w_bits_fp8)
         w_fp8 = Float8Tensor.to_float8(
-            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w, self.emulate
+            w,
+            self.fp8_scale_w,
+            torch.float8_e4m3fn,
+            self.fp8_amax_w,
+            self.emulate,
+            cached_casted_weight=w_bits_fp8,
         )
         return w_fp8
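
As a rough, back-of-the-envelope sketch (not part of the PR): `cached_fp8_weight` stores one byte per element in `float8_e4m3fn`, so each `Float8Linear` pays about `weight.numel()` extra bytes of GPU memory in exchange for skipping the amax computation and scaled cast on cached forwards. The shape below is illustrative.

```python
import torch

# Approximate per-layer overhead of the weight cache buffer.
# float8_e4m3fn stores one byte per element.
weight = torch.empty(4096, 4096, dtype=torch.bfloat16)  # e.g. one 4096x4096 linear weight
fp8_bytes_per_elem = torch.empty(0, dtype=torch.float8_e4m3fn).element_size()  # == 1

cache_mib = weight.numel() * fp8_bytes_per_elem / 2**20
weight_mib = weight.numel() * weight.element_size() / 2**20
print(f"cache buffer: {cache_mib:.1f} MiB, bf16 weight: {weight_mib:.1f} MiB")
# cache buffer: 16.0 MiB, bf16 weight: 32.0 MiB
```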

float8_experimental/float8_linear_utils.py

Lines changed: 3 additions & 0 deletions
@@ -156,6 +156,9 @@ def sync_float8_amax_and_scale_history(

     for idx in range(len(fp8_layers)):
         child = fp8_layers[idx]
+        # TODO(future): enable skipping weight related syncing if weight cache
+        # is on
+
         #
         # 1. in distributed contexts, syncs amax values across workers
         #

float8_experimental/float8_tensor.py

Lines changed: 35 additions & 14 deletions
@@ -12,6 +12,16 @@
 aten = torch.ops.aten


+@torch.no_grad()
+def calculate_amax_and_cast_to_float8(tensor, scale, float8_dtype, amax_buffer):
+    if amax_buffer is not None:
+        amax_buffer.fill_(tensor_to_amax(tensor))
+
+    tensor_scaled = tensor * scale
+    bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+    return bits_fp8
+
+
 class ToFloat8ConstrFunc(torch.autograd.Function):
     """
     A differentiable conversion to fp8
@@ -25,24 +35,23 @@ def forward(
         float8_dtype=torch.float8_e4m3fn,
         amax_buffer=None,
         emulate: bool = False,
+        cached_casted_weight=None,
     ):
-        # In TransformerEngine, the casts to float8 are fused with calculating
-        # the new amax value. In this codebase, the eager mode code for those
-        # two things is colocated in this function. We expect PT2.0 to fuse it
-        # for us.
-        if amax_buffer is not None:
-            amax_buffer.fill_(tensor_to_amax(tensor))
-
-        tensor_scaled = tensor * scale
-        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
+        if cached_casted_weight is not None:
+            return Float8Tensor(
+                cached_casted_weight, scale, tensor.dtype, emulate=emulate
+            )
+        bits_fp8 = calculate_amax_and_cast_to_float8(
+            tensor, scale, float8_dtype, amax_buffer
+        )
         return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate)

     @staticmethod
     def backward(ctx, g):
         if isinstance(g, Float8Tensor):
-            return g.to_original_precision(), None, None, None, None
+            return g.to_original_precision(), None, None, None, None, None
         else:
-            return g, None, None, None, None
+            return g, None, None, None, None, None


 class FromFloat8ConstrFunc(torch.autograd.Function):
@@ -122,7 +131,7 @@ def __tensor_flatten__(self):
         return ["_data", "_scale"], ctx

     @staticmethod
-    def __tensor_unflatten__(inner_tensors: Dict, metadata):
+    def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
         assert len(inner_tensors) == 2
         return Float8Tensor(
             inner_tensors["_data"],
@@ -136,7 +145,14 @@ def to_original_precision(self):

     @staticmethod
     @torch._dynamo.allow_in_graph
-    def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False):
+    def to_float8(
+        tensor,
+        scale,
+        float8_dtype,
+        amax_buffer=None,
+        emulate: bool = False,
+        cached_casted_weight=None,
+    ):
         """Converts a higher precision tensor to float8 in a differentiable way.

         Args:
@@ -149,7 +165,12 @@ def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = Fal
             Float8Tensor: a float8 tensor
         """
         return ToFloat8ConstrFunc.apply(
-            tensor, scale, float8_dtype, amax_buffer, emulate
+            tensor,
+            scale,
+            float8_dtype,
+            amax_buffer,
+            emulate,
+            cached_casted_weight,
         )

     @classmethod
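
`tensor_to_amax` and `to_fp8_saturated` are imported from `float8_experimental.float8_utils` and are not shown in this diff. As a rough sketch of what they do (my paraphrase under stated assumptions, not the exact implementation), the first takes the max absolute value of the tensor and the second clamps the scaled tensor to the float8 format's finite range before casting, so outliers saturate instead of overflowing:

```python
import torch

E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def tensor_to_amax_sketch(x: torch.Tensor) -> torch.Tensor:
    # largest magnitude in the tensor; feeds the amax history used for scaling
    return torch.max(torch.abs(x))

def to_fp8_saturated_sketch(x_scaled: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # clamp to the representable range so large values saturate rather than
    # becoming inf/nan after the cast
    max_pos = torch.finfo(float8_dtype).max
    return x_scaled.clamp(min=-max_pos, max=max_pos).to(float8_dtype)

w = torch.randn(16, 32)
scale = E4M3_MAX_POS / torch.clamp(tensor_to_amax_sketch(w), min=1e-12)
w_bits_fp8 = to_fp8_saturated_sketch(w * scale, torch.float8_e4m3fn)
```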

test/test_base.py

Lines changed: 33 additions & 0 deletions
@@ -10,6 +10,9 @@
 import warnings
 from enum import Enum

+import float8_experimental.config as config
+import float8_experimental.float8_linear as float8_linear
+
 import pytest

 import torch
@@ -231,6 +234,36 @@ def test_type_cast(self, linear_type: LinearType, linear_dtype: torch.dtype):
             y.dtype == torch.bfloat16
         ), f"y.dtype is {y.dtype}, expected {torch.bfloat16}"

+    @pytest.mark.parametrize("use_compile", [False, True])
+    def test_weight_caching(self, use_compile):
+        M, K, N = 16, 32, 64
+        dtype = torch.bfloat16
+        config.allocate_float8_weight_cache_buffers = True
+
+        x = torch.randn(M, K, device="cuda", dtype=dtype)
+        m_ref = nn.Linear(K, N, bias=True, device="cuda", dtype=dtype)
+        m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate=False)
+
+        if use_compile:
+            m = torch.compile(m)
+
+        config.weight_cache_enabled = False
+
+        y1 = m(x)
+        y1.sum().backward()
+        grad1 = m.weight.grad.clone().detach()
+
+        config.weight_cache_enabled = True
+        sync_float8_amax_and_scale_history(m)
+
+        y2 = m(x)
+        y2.sum().backward()
+        grad2 = m.weight.grad.clone().detach()
+
+        torch.testing.assert_close(grad2, grad1 * 2)
+
+        config.allocate_float8_weight_cache_buffers = False
+

 class TestScaledMM:
     @unittest.skipIf(
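
A note on the assertion at the end of the new test: `m.weight.grad` is not zeroed between the two backward calls, so the second backward accumulates into the first gradient. If the cached fp8 weight reproduces the uncached forward exactly, the accumulated gradient is exactly twice the first one, which is what `torch.testing.assert_close(grad2, grad1 * 2)` verifies.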
