# float8nocompile

A prototype API for high performance eager mode float8 training that uses handwritten Triton kernels for quantization.

### Usage

Prepare your model for high performance eager mode float8 training with a single conversion function: `convert_to_float8_nocompile_training` ([source](https://github.com/pytorch/ao/blob/32a51eca14257bbaafd3671a5349189e30c65e2b/torchao/prototype/float8nocompile/float8nocompile_linear_utils.py#L24)).

This function replaces `nn.Linear` layers in-place with `Float8NoCompileLinear` layers, which use **dynamic, tensorwise scaling** to perform all matmuls in the linear layer forward and backward passes as FP8 GEMMs.
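
Conceptually, dynamic tensorwise scaling derives a single scale per tensor from its current absolute maximum on every forward and backward pass. The sketch below is a simplified plain-PyTorch illustration of that idea; the `quantize_tensorwise` helper is ours for illustration, not the library's Triton kernel:

```python
import torch

# max representable value of the e4m3 float8 format (448.0)
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_tensorwise(x: torch.Tensor):
    """Simplified dynamic tensorwise float8 quantization (illustrative only)."""
    amax = x.abs().max().float()
    scale = E4M3_MAX / torch.clamp(amax, min=1e-12)  # one scale for the whole tensor
    x_fp8 = (x.float() * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale  # the scale is used to rescale the FP8 GEMM output
```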

**Example**:

```python
from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

# define your model, data loaders, etc.
...

# convert `torch.nn.Linear` modules to `Float8NoCompileLinear` in-place
convert_to_float8_nocompile_training(model)

# training loop
for i in range(num_epochs):
    ...
```
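
For a self-contained smoke test, a minimal end-to-end sketch is shown below. The toy model, shapes, optimizer, and synthetic objective are illustrative assumptions, not part of the library, and float8 GEMMs require float8-capable hardware such as an H100:

```python
import torch
import torch.nn as nn

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

# toy model for illustration; real workloads use much larger transformer layers
model = nn.Sequential(
    nn.Linear(4096, 4096, bias=False),
    nn.ReLU(),
    nn.Linear(4096, 4096, bias=False),
).to(torch.bfloat16).cuda()

# swap nn.Linear modules for Float8NoCompileLinear in-place
convert_to_float8_nocompile_training(model)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(10):
    x = torch.randn(64, 4096, dtype=torch.bfloat16, device="cuda")
    loss = model(x).float().pow(2).mean()  # synthetic objective
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```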

### Performance benchmarks

Performance benchmarking was done via an [experimental integration into torchtitan](https://github.com/pytorch/torchtitan/pull/778).

The results indicate a solid 6-10% tokens/sec speedup with relatively flat memory usage (within +/- 1% of peak memory) compared to the bf16 eager baseline.

All configurations below were benchmarked on 8 H100s with sequence length 4096; "AC" refers to activation checkpointing.

#### No AC

| Configuration             | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|---------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager           | 5339.0     | 53.12            | 0.00%        | 0.00%         |
| float8nocompile prototype | 5871.4     | 52.70            | +9.97%       | -0.79%        |
| float8 + torch.compile    | 6667.6     | 46.64            | +24.88%      | -12.20%       |

---

#### Selective per-layer AC (AC every 2nd layer)

| Configuration             | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|---------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager           | 4882.4     | 40.60            | 0.00%        | 0.00%         |
| float8nocompile prototype | 5302.0     | 40.97            | +8.59%       | +0.91%        |
| float8 + torch.compile    | 6199.6     | 37.38            | +26.98%      | -7.93%        |

---

#### Full AC

| Configuration             | Tokens/sec | Peak memory (GB) | Tokens/sec Δ | Peak memory Δ |
|---------------------------|------------|------------------|--------------|---------------|
| bfloat16, eager           | 4502.0     | 28.07            | 0.00%        | 0.00%         |
| float8nocompile prototype | 4773.4     | 28.07            | +6.03%       | 0.00%         |
| float8 + torch.compile    | 5775.2     | 28.03            | +28.28%      | -0.14%        |

### Numerical accuracy

Numerical accuracy has been verified via unit tests, as well as by manually verifying that the training loss curves maintain fidelity with the loss curves for bf16 eager and for production float8 + torch.compile.
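
As a sketch of what such a parity check can look like (the `sqnr` helper and threshold here are illustrative assumptions, not the project's actual unit tests):

```python
import copy

import torch
import torch.nn as nn

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

def sqnr(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB; higher means closer match
    return 20 * torch.log10(ref.norm() / (ref - test).norm())

torch.manual_seed(0)
ref_model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).to(torch.bfloat16).cuda()
f8_model = copy.deepcopy(ref_model)
convert_to_float8_nocompile_training(f8_model)

x = torch.randn(64, 4096, dtype=torch.bfloat16, device="cuda")
out_sqnr = sqnr(ref_model(x).float(), f8_model(x).float())
assert out_sqnr > 20.0, f"SQNR too low: {out_sqnr}"  # illustrative threshold
```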