
Commit d99785c

Update float8nocompile readme (#1693)
1 parent 999b16d · commit d99785c

2 files changed: +71, -2 lines
@@ -1,3 +1,72 @@

Removed: the previous "# Work in progress" heading and the line "A prototype version of Float8Linear which is performant without `torch.compile`."

# float8nocompile

A prototype API for high-performance eager mode float8 training that uses handwritten Triton kernels for quantization.

### Usage

Prepare your model for high-performance eager mode float8 training with a single conversion function: `convert_to_float8_nocompile_training` ([source](https://github.com/pytorch/ao/blob/32a51eca14257bbaafd3671a5349189e30c65e2b/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py#L24)).

This function replaces `nn.Linear` layers with `Float8NoCompileLinear` layers in place. These layers use **dynamic, tensorwise scaling** to perform all matmuls in the linear layer's forward and backward passes as FP8 GEMMs.
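
For intuition, the sketch below illustrates what dynamic, tensorwise scaling means in isolation: a single scale is recomputed from the tensor's current absolute maximum on every cast and used to map values into the representable range of `torch.float8_e4m3fn`. This is a simplified plain-PyTorch illustration of the idea, not the prototype's handwritten Triton kernel, and the helper name is invented for this example.

```python
import torch

# largest magnitude representable in the float8 e4m3 format (448.0)
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def tensorwise_dynamic_cast(x: torch.Tensor):
    """Illustrative only: cast `x` to float8 using one dynamically computed scale."""
    # "dynamic": the scale is derived from the live values of this tensor
    amax = x.abs().max().to(torch.float32).clamp(min=1e-12)
    # "tensorwise": a single scalar scale covers the whole tensor
    scale = FP8_E4M3_MAX / amax
    x_fp8 = (x.to(torch.float32) * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    # an FP8 GEMM consumes x_fp8 together with 1/scale to recover the original magnitude
    return x_fp8, scale.reciprocal()


x_fp8, inv_scale = tensorwise_dynamic_cast(torch.randn(64, 64))
```
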
**Example**:

```python
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

# define your model, data loaders, etc.
...

# convert specified `torch.nn.Linear` modules to `Float8NoCompileLinear`
convert_to_float8_nocompile_training(model)

# training loop
for i in range(num_epochs):
    ...
```
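
To make the skeleton above concrete, here is a minimal end-to-end sketch. The toy model, synthetic data, and hyperparameters are illustrative assumptions rather than part of the prototype, and running it requires a float8-capable CUDA GPU (e.g., H100) plus a torchao build that includes the prototype; linear dimensions are chosen as multiples of 16, which FP8 GEMMs typically require.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

# toy model (illustrative); dims are multiples of 16 for FP8 GEMM alignment
model = nn.Sequential(
    nn.Linear(1024, 4096),
    nn.ReLU(),
    nn.Linear(4096, 1024),
).to(torch.bfloat16).to("cuda")

# swap the nn.Linear modules for the no-compile float8 variant, in place
convert_to_float8_nocompile_training(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# toy training loop over synthetic data
for step in range(10):
    x = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda")
    target = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda")
    loss = F.mse_loss(model(x), target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(f"step {step}: loss {loss.item():.4f}")
```
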

### Performance benchmarks

Performance benchmarking was done via an [experimental integration into torchtitan](https://github.com/pytorch/torchtitan/pull/778).

The results indicate a solid 6-10% tokens/sec speedup with relatively flat memory usage (within ±1% of peak memory) compared to the bf16 eager baseline.

# Performance Comparison of Different Configurations on 8 H100s

## No AC (seq len 4096) - 8 H100s

| Configuration              | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|----------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager            | 5339.0     | 53.12            | 0%           | 0.00%         |
| float8nocompile prototype  | 5871.4     | 52.7             | 9.97%        | -0.79%        |
| float8 + torch.compile     | 6667.6     | 46.64            | 24.88%       | -12.20%       |

---

## Selective per layer AC (AC every 2nd layer, seq len 4096) - 8 H100s

| Configuration              | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|----------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager            | 4882.4     | 40.6             | 0%           | 0.00%         |
| float8nocompile prototype  | 5302.0     | 40.97            | 8.59%        | 0.91%         |
| float8 + torch.compile     | 6199.6     | 37.38            | 26.98%       | -7.93%        |

---

## Full AC (seq len 4096) - 8 H100s

| Configuration              | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|----------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager            | 4502.0     | 28.07            | 0%           | 0.00%         |
| float8nocompile prototype  | 4773.4     | 28.07            | 6.03%        | 0.00%         |
| float8 + torch.compile     | 5775.2     | 28.03            | 28.28%       | -0.14%        |

## Numerical accuracy

Numerical accuracy has been verified via unit tests, as well as by manually confirming that the training loss curves maintain fidelity with the loss curves for bf16 eager and for production float8 + torch.compile:

![loss curves](float8nocompile_loss_curves.png "Loss curves")
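
For a quick sanity check of your own, the sketch below compares a converted copy of a model against its bf16 eager original on the same input and reports an SQNR-style error metric. The model shape, input, and metric choice are illustrative assumptions, not the repository's actual unit tests, and it needs the same float8-capable GPU as training.

```python
import copy

import torch
import torch.nn as nn

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

torch.manual_seed(0)

# bf16 eager reference and an identical copy converted to float8
ref_model = nn.Sequential(
    nn.Linear(1024, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1024),
).to(torch.bfloat16).to("cuda")
fp8_model = copy.deepcopy(ref_model)
convert_to_float8_nocompile_training(fp8_model)

x = torch.randn(64, 1024, dtype=torch.bfloat16, device="cuda")
out_ref = ref_model(x).float()
out_fp8 = fp8_model(x).float()

# signal-to-quantization-noise ratio in dB; higher means closer to the bf16 reference
mse = (out_ref - out_fp8).pow(2).mean()
sqnr = 10 * torch.log10(out_ref.pow(2).mean() / mse)
print(f"SQNR vs bf16 eager: {sqnr.item():.1f} dB")
```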
