This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit d4cf2ad

vkuzo authored and facebook-github-bot committed
make dynamic scaling default in Float8Linear (#300)
Summary:

Pull Request resolved: #300

1. Makes dynamic scaling the default in Float8Linear, for an easier migration of callsites which currently use Float8DynamicLinear. Fixes tests as needed.
2. Updates the README to reference Float8Linear for dynamic scaling.

Reviewed By: drisspg

Differential Revision: D59305790

fbshipit-source-id: 30d3813946239e0e958e0f7ed446082b578b0607
1 parent 4fb0ada commit d4cf2ad
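
For context on the migration this commit enables, here is a minimal sketch (not part of the commit itself). It uses a toy `nn.Sequential` as a stand-in for a real model (the `Model(...)` placeholder in the README); with the new defaults, swapping in `Float8Linear` with no extra arguments gives dynamic scaling for `x`, `w` and `dL_dY`, which is what `Float8DynamicLinear` callsites previously got.

```python
# Hedged migration sketch: the toy model below stands in for any module tree
# containing torch.nn.Linear layers.
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

m = nn.Sequential(
    nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
)

# before this commit:
#   from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
#   swap_linear_with_float8_linear(m, Float8DynamicLinear)
# after this commit, Float8Linear defaults to dynamic scaling for x, w and dL_dY:
swap_linear_with_float8_linear(m, Float8Linear)
```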

File tree (5 files changed: +54 -19 lines)

README.md
float8_experimental/float8_linear_utils.py
test/test_compile.py
test/test_fsdp.py
test/test_fsdp_compile.py


README.md

Lines changed: 20 additions & 11 deletions
@@ -27,21 +27,23 @@ pip install -e ".[dev]"
 
 # User API
 
-We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details.
+We provide two per-tensor scaling strategies: dynamic and delayed. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`x`), weights (`w`) and gradients (`dL_dY`).
 
-## float8 linear with dynamic scaling
+## float8 linear with dynamic scaling for `x`, `w` and `dL_dY`
+
+This is the most accurate recipe as every tensor is scaled dynamically.
 
 ```python
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear
 
 # create model
 m = Model(...)
 
-# convert all `torch.nn.Linear` modules to `Float8DynamicLinear`
-swap_linear_with_float8_linear(m, Float8DynamicLinear)
+# convert all `torch.nn.Linear` modules to `Float8Linear`
+swap_linear_with_float8_linear(m, Float8Linear)
 
 # optional: use FSDP
 model = FSDP(model, use_orig_params=True)
@@ -54,18 +56,27 @@ m = torch.compile(m)
 
 ## float8 linear with delayed scaling
 
+This is theoretically the most performant recipe as it minimizes memory reads.
+
 ```python
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 
 # create model
 m = Model(...)
 
-# convert all `torch.nn.Linear` modules to `Float8Linear`
-swap_linear_with_float8_linear(m, Float8Linear)
+# convert all `torch.nn.Linear` modules to `Float8Linear`, specifying scaling
+# type
+swap_linear_with_float8_linear(
+    m,
+    Float8Linear,
+    scaling_type_x=TensorScalingType.DELAYED,
+    scaling_type_w=TensorScalingType.DELAYED,
+    scaling_type_dL_dY=TensorScalingType.DELAYED,
+)
 
 # optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
 # config.enable_pre_and_post_forward are needed for autocast + compile + FSDP + float8 to work
@@ -93,9 +104,7 @@ for _ in range(N_ITER):
 # 🧭 Code Organization
 
 * `float8_experimental/float8_linear.py`
-  - `Float8Linear` (main user facing entry point for delayed scaling)
-* `float8_experimental/float8_dynamic_linear.py`
-  - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
+  - `Float8Linear` (main user facing entry point for Float8Linear)
 * `float8_experimental/float8_tensor.py`
   - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
   - `ScaledMMConfig` defines the semantics for matmul in the forward and backwards pass

float8_experimental/float8_linear_utils.py

Lines changed: 3 additions & 3 deletions
@@ -191,9 +191,9 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
-    scaling_type_x: TensorScalingType = TensorScalingType.DELAYED,
-    scaling_type_w: TensorScalingType = TensorScalingType.DELAYED,
-    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
+    scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
+    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
 ) -> Optional[nn.Module]:
     """
     Swaps `torch.nn.Linear` in `module` with `Float8Linear` or `Float8DynamicLinear`.
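
Since the defaults above now point to `TensorScalingType.DYNAMIC`, callers that relied on the old delayed-scaling defaults must opt in explicitly. A hedged sketch of that opt-in is below, mirroring the pattern the updated tests use; `m` is assumed to be an existing module containing `torch.nn.Linear` layers.

```python
# Hedged sketch: opting back into delayed scaling now that dynamic is the default.
# `m` is assumed to be an existing nn.Module containing torch.nn.Linear layers.
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

swap_linear_with_float8_linear(
    m,
    Float8Linear,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)

# Calling swap_linear_with_float8_linear(m, Float8Linear) with no scaling_type_*
# arguments now yields dynamic scaling for all three tensors, matching the new
# defaults shown in the diff above.
```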

test/test_compile.py

Lines changed: 14 additions & 2 deletions
@@ -299,7 +299,13 @@ def test_sync_amax_func():
     module = torch.nn.Sequential(
         nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
     )
-    float8_mod = swap_linear_with_float8_linear(module, Float8Linear)
+    float8_mod = swap_linear_with_float8_linear(
+        module,
+        Float8Linear,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts)
     compiled_swap_func(float8_mod)
     assert cnts.frame_count == 1, "Compiled graph should have 1 frame!"
@@ -329,7 +335,13 @@ def test_sync_amax_func_cuda_graph_success():
     my_module = nn.Sequential(
         nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
     ).to("cuda")
-    swap_linear_with_float8_linear(my_module, Float8Linear)
+    swap_linear_with_float8_linear(
+        my_module,
+        Float8Linear,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     inpt = torch.randn(
         16, 16, device="cuda", dtype=torch.float32, requires_grad=True
     )

test/test_fsdp.py

Lines changed: 8 additions & 1 deletion
@@ -23,6 +23,8 @@
 import torch.nn as nn
 from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
+    linear_requires_sync,
+    LinearType,
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
 )
@@ -130,7 +132,12 @@ def forward_backward(model, optim, is_fp8, i):
     optim.zero_grad()
     y_local = model(ref_input_local[i])
     y_local.backward(ref_grad_local[i])
-    if is_fp8:
+    if is_fp8 and linear_requires_sync(
+        LinearType.DELAYED,
+        TensorScalingType.DYNAMIC,
+        scaling_type_w,
+        TensorScalingType.DYNAMIC,
+    ):
         sync_float8_func(model)
     optim.step()
     return y_local

test/test_fsdp_compile.py

Lines changed: 9 additions & 2 deletions
@@ -18,7 +18,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 from float8_experimental import config
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
     sync_float8_amax_and_scale_history,
@@ -49,7 +49,14 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
         nn.Linear(K, N, dtype=base_dtype),
         nn.ReLU(),
     )
-    swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
+    swap_linear_with_float8_linear(
+        m,
+        Float8Linear,
+        emulate=emulate,
+        scaling_type_x=TensorScalingType.DELAYED,
+        scaling_type_w=TensorScalingType.DELAYED,
+        scaling_type_dL_dY=TensorScalingType.DELAYED,
+    )
     return m

5562
