# optional: enable torch.compile for improved performance
m = torch.compile(m)

# train/finetune (not shown)
```

## float8 linear with delayed scaling

```python
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_linear import Float8Linear

# create model
m = Model(...)

# convert all `torch.nn.Linear` modules to `Float8Linear`
swap_linear_with_float8_linear(m, Float8Linear)

# optional: use FSDP. Note that workarounds gated with config.enable_amax_init and
# config.enable_pre_and_post_forward are needed for autocast+compile+FSDP+float8 to work
from float8_experimental import config
config.enable_amax_init = False  # only needed for autocast + compile + FSDP + float8 delayed
config.enable_pre_and_post_forward = False  # only needed for autocast + compile + FSDP + float8 delayed
m = FSDP(m, use_orig_params=True)

# optional: enable torch.compile for improved performance
m = torch.compile(m)

# toy training loop
for _ in range(N_ITER):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()

    # specific to float8 with delayed scaling: separate step to sync scales/amaxes
    # in the future, this may move to a context manager
    sync_float8_amax_and_scale_history(m)

    optimizer.step()
```

## multi GPU

```python
from float8_experimental.tp_linear import swap_tp_linear_with_float8_linear

# swaps the fairscale `ColumnParallelLinear` with `Float8ColumnParallelLinear`,
# and the fairscale `RowParallelLinear` with `Float8RowParallelLinear`
swap_tp_linear_with_float8_linear(model)

# if applicable, enable sequence parallel on the right modules
# TODO make the API for this nicer
model.foo.bar.fc1.sequence_parallel = True
model.foo.bar.fc2.sequence_parallel = True

# the rest of the flow is the same as the single GPU flow
```

# high level technical design

## UX

We are using a module swap UX to keep things simple. If the user model has `torch.nn.Linear` modules or their `fairscale` TP/SP equivalents, we can convert them to float8. `F.linear`, `torch.mm`, and `torch.matmul` are not supported at the moment.
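
For illustration, here is a minimal sketch of what a module swap helper does. It assumes the float8 module class exposes a `from_float` constructor (an assumption for this sketch); in practice, use the library's `swap_linear_with_float8_linear`.

```python
import torch.nn as nn

def swap_linears(module: nn.Module, float8_linear_cls) -> nn.Module:
    # recursively replace every nn.Linear with its float8 equivalent;
    # `float8_linear_cls.from_float` is assumed to copy weights/config over
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, float8_linear_cls.from_float(child))
        else:
            swap_linears(child, float8_linear_cls)
    return module
```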

The user is responsible for calling the `sync_float8_amax_and_scale_history` function once per forward/backward pass; this function updates the amax history. If distributed training is enabled, it also syncs amax values across workers. This is a separate model-level function (as opposed to each module owning the syncing of its buffers) to make it easier to optimize performance (for example, reducing all the amaxes once in a single tensor instead of doing N reductions).
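
To illustrate the performance note above, here is a minimal sketch of reducing all amaxes with a single collective instead of N separate ones. The `amax_buffers` list and the helper name are assumptions for this sketch, not the library's API.

```python
from typing import List

import torch
import torch.distributed as dist

def sync_amaxes_once(amax_buffers: List[torch.Tensor]) -> None:
    # stack all per-module amax values into one tensor so a single
    # all_reduce(MAX) replaces N separate reductions
    stacked = torch.stack([buf.detach() for buf in amax_buffers])
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stacked, op=dist.ReduceOp.MAX)
    # copy the (possibly reduced) values back into the module buffers
    for buf, synced in zip(amax_buffers, stacked):
        buf.copy_(synced)
```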

Composability with `DTensor` is on our radar and we plan to look into this after the manual flow works e2e.

A user-facing tensor subclass UX is not being considered at the moment because delayed scaling requires persistent state for activations, and there isn't a clean and sound way to implement this with tensor subclasses.

## single GPU

### separation of concerns

1. `Float8Linear` owns casting X, W and dL/dY to float8 and does all the bookkeeping of the amax, amax_history and scale buffers
2. the user is responsible for applying `Float8Linear` to the right parts of their model with module swaps

### Tensor subclasses

We are using tensor subclasses (`Float8Tensor`) to write modular code which satisfies autograd's restriction that `x.dtype == x.grad.dtype`. The way we achieve this is by ensuring that instances of `Float8Tensor` set their dtype attribute to the original dtype (float32/float16/bfloat16) while the underlying data representation is in float8. If you look in `float8_linear.py` and `te_linear.py`, you will see that we pass instances of `Float8Tensor` around various `torch.autograd.Function` calls, enabling us to have modular code.
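
As a rough illustration of the dtype trick, here is a simplified sketch, not the actual `Float8Tensor` implementation. It uses `torch.Tensor._make_wrapper_subclass` (a private PyTorch API) and assumes the float8 payload was produced as `(orig * scale).to(float8_dtype)`.

```python
import torch

class ToyFloat8Tensor(torch.Tensor):
    """Wrapper subclass that advertises the original dtype to autograd
    while storing float8 bits internally."""

    @staticmethod
    def __new__(cls, data_fp8: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        # the advertised dtype is the original high-precision dtype, which
        # keeps autograd's `x.dtype == x.grad.dtype` check satisfied
        self = torch.Tensor._make_wrapper_subclass(
            cls, data_fp8.shape, dtype=orig_dtype, device=data_fp8.device
        )
        self._data = data_fp8    # float8 payload, e.g. torch.float8_e4m3fn
        self._scale = scale      # scale used when quantizing
        return self

    def to_original_precision(self) -> torch.Tensor:
        # dequantize back to the advertised dtype
        return self._data.to(self.dtype) / self._scale
```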
## multi GPU

### TP/SP

`Float8ColumnParallelLinear` and `Float8RowParallelLinear` are replacements for the non-float8 TP/SP primitives.

### FSDP with fp16 weight all-gather

No change from single GPU code - it just works.

### FSDP with fp8 weight all-gather

FSDP with fp8 weight all-gather is currently under design. The problem can be separated into three parts:

a. separation of concerns between user code and FSDP
b. user code interaction with FSDP
c. FSDP implementation of fp8 all-gather

#### Separation of concerns between user code and FSDP

We have alignment on the separation of concerns that we want:

1. user code is responsible for making the model fp8 aware and adding the right buffers
2. user code is responsible for passing FSDP the information necessary to cast weights to fp8: a way to tell if a weight should be cast to fp8, the weight's scale, and the `Float8Tensor` constructor (sketched below)
3. FSDP is responsible for performing the fp8 cast and providing the unsharded fp8 weight to each worker
4. user code is responsible for syncing amax metadata across workers and calculating scales

This way, FSDP knows as little as possible about user logic - it just gets a list of weights + amax buffers + scales, and does the float8 fused cast + amax calculation. User code does everything else.

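A hypothetical sketch of the per-weight information described in item 2 above; the names and structure here are assumptions for illustration, not an agreed-upon API.

```python
from dataclasses import dataclass
from typing import Callable

import torch

@dataclass
class WeightCastInfo:
    """What user code could hand to FSDP for each parameter."""
    should_cast_to_fp8: bool  # whether this weight participates in fp8 all-gather
    scale: torch.Tensor       # scale computed by user code from synced amax metadata
    # constructor that wraps (weight, scale) into a Float8Tensor-like object
    make_float8: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
```
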
#### User code interaction with FSDP

We expect this to be trivial. First, when initializing FSDP, we will provide the necessary configuration to it as described above. Second, instead of `w_fp8 = cast_to_fp8(w)`, we will just check if `w` is already in fp8.
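
A minimal sketch of that second point; `is_fp8` is a hypothetical helper, and real code would check for the library's `Float8Tensor` subclass rather than the raw dtype.

```python
import torch

def is_fp8(w: torch.Tensor) -> bool:
    # hypothetical check: treat float8 dtypes as "already cast"
    return w.dtype in (torch.float8_e4m3fn, torch.float8_e5m2)

def maybe_cast_weight(w: torch.Tensor, cast_to_fp8) -> torch.Tensor:
    # if FSDP already all-gathered the weight in fp8, use it as-is;
    # otherwise fall back to the local cast (`cast_to_fp8` from the text above)
    return w if is_fp8(w) else cast_to_fp8(w)
```
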
#### FSDP implementation of fp8 all-gather

This is in early design. The current `FlatParameter` design does not work cleanly with heterogeneous dtypes, and heterogeneous dtypes are required for a good UX, since for realistic models not all parameters (norm parameters, biases, etc.) will be in float8.

We are working on a new FSDP implementation that uses per-parameter sharding, which will allow flexible fp8 all-gather. This is being prototyped currently.

# code tips

* `float8_experimental/float8_linear.py` - `Float8Linear` (main user facing entry point for delayed scaling)
* `float8_experimental/float8_dynamic_linear.py` - `Float8DynamicLinear` (main user facing entry point for dynamic scaling)
* `float8_experimental/float8_tensor.py` - `Float8Tensor`, which allows `Float8Linear` to abide by the `x.dtype == x.grad.dtype` restriction
* `float8_experimental/tp_linear.py` - `Float8ColumnParallelLinear` / `Float8RowParallelLinear` (TP/SP versions of float8 linear)

# testing

pytest test/test_compile.py

# benchmark the torch._scaled_mm function on LLaMa 2 70B shapes
./benchmarks/bench_matmul.py

# benchmark fw/bw of `Linear`, `Float8Linear` on LLaMa 2 70B shapes
# make sure to turn on torch.compile to get the best performance