Commit da04a9e

update float8nocompile readme
1 parent 32a51ec

1 file changed (+63, -2 lines)
# float8nocompile

A prototype API for high-performance, eager-mode float8 training via handwritten Triton kernels for quantization.

### Usage

Prepare your model for high-performance, eager-mode float8 training with a single conversion function: `convert_to_float8_nocompile_training`.
This function replaces `nn.Linear` layers in-place with `Float8NoCompileLinear` layers, which use **dynamic, tensorwise scaling**
to perform all matmuls in the linear layer forward and backward passes as FP8 GEMMs.

**Example**:

```python
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

# define your model, data loaders, etc.
...

# convert `torch.nn.Linear` modules to `Float8NoCompileLinear` in-place
convert_to_float8_nocompile_training(model)

# training loop
for i in range(num_epochs):
    ...
```
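
Under the hood, the handwritten Triton kernels implement the **dynamic, tensorwise scaling** described above: each tensor is scaled by a single factor derived from its current absolute maximum before being cast to float8. As a rough illustration only (a plain-PyTorch sketch, not the prototype's actual kernels or API; the function name is hypothetical):

```python
import torch

def scale_to_float8_tensorwise(x: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # Dynamic scaling: derive one tensorwise scale from the current absolute maximum.
    fp8_max = torch.finfo(float8_dtype).max
    amax = x.abs().max().to(torch.float32).clamp(min=1e-12)
    scale = fp8_max / amax
    # Scale, clamp to the representable float8 range, and cast.
    x_fp8 = (x.to(torch.float32) * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return x_fp8, scale
```

The prototype fuses this quantization step into handwritten Triton kernels, which is what keeps it performant in eager mode without `torch.compile`.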

### Performance benchmarks

Performance benchmarking was done via an [experimental integration into torchtitan](https://github.com/pytorch/torchtitan/pull/778).

The results indicate a 6-10% tokens/sec speedup with relatively flat peak memory (within about ±1%) compared to the bf16 eager baseline.
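
The Δ columns in the tables below are relative to the bf16 eager baseline for the same configuration; for example, for the no-AC run:

```python
# Tokens/sec Δ of the float8nocompile prototype vs. the bf16 eager baseline (no AC)
baseline_tps, prototype_tps = 5339.0, 5871.4
print(f"{prototype_tps / baseline_tps - 1.0:.2%}")  # 9.97%
```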

Performance comparison of different configurations on 8 H100s:

#### No AC (seq len 4096) - 8 H100s

| Configuration             | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|---------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager           | 5339.0     | 53.12            | 0%           | 0.00%         |
| float8nocompile prototype | 5871.4     | 52.7             | 9.97%        | -0.79%        |
| float8 + torch.compile    | 6667.6     | 46.64            | 24.88%       | -12.20%       |

---

#### Selective per layer AC (AC every 2nd layer, seq len 4096) - 8 H100s

| Configuration             | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|---------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager           | 4882.4     | 40.6             | 0%           | 0.00%         |
| float8nocompile prototype | 5302.0     | 40.97            | 8.59%        | 0.91%         |
| float8 + torch.compile    | 6199.6     | 37.38            | 26.98%       | -7.93%        |

---

#### Full AC (seq len 4096) - 8 H100s

| Configuration             | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|---------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager           | 4502.0     | 28.07            | 0%           | 0.00%         |
| float8nocompile prototype | 4773.4     | 28.07            | 6.03%        | 0.00%         |
| float8 + torch.compile    | 5775.2     | 28.03            | 28.28%       | -0.14%        |
