torchao/float8/README.md: 53 additions & 2 deletions
@@ -17,9 +17,9 @@ throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs.
We provide three per-tensor scaling strategies: dynamic, delayed and static. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`).
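To make the scaling strategies concrete, here is a small plain-Python sketch of the per-tensor scaling math (an illustration under the FP8 paper's formulation, not the torchao implementation; the function name `dynamic_scale` is hypothetical): dynamic scaling recomputes the scale from the live tensor's absolute maximum on every call, mapping it onto the representable float8 e4m3 range, while delayed scaling would instead reuse an amax gathered from earlier iterations.

```python
# Sketch of per-tensor dynamic scaling math (plain Python, not the
# torchao implementation; `dynamic_scale` is a hypothetical helper).
E4M3_MAX = 448.0  # largest normal value representable in float8 e4m3

def dynamic_scale(tensor):
    """Return the multiplier that maps the tensor's amax onto the fp8 range."""
    amax = max(abs(v) for v in tensor)
    return E4M3_MAX / amax

vals = [0.5, -2.0, 1.25]
scale = dynamic_scale(vals)          # 448.0 / 2.0 = 224.0
scaled = [v * scale for v in vals]   # now spans the fp8 range
assert max(abs(v) for v in scaled) == E4M3_MAX
```

Delayed scaling trades a little accuracy for performance by amortizing the amax computation; static scaling fixes the scale ahead of time from calibration data.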
-## float8 linear with dynamic scaling for `input`, `weight` and `grad_output`
+## float8 linear with dynamic tensorwise scaling
-This is the most accurate recipe as every tensor is scaled dynamically.
+This is the default recipe, with a good balance of performance and accuracy.
```python
import torch
@@ -63,6 +63,57 @@ for _ in range(10):
optimizer.step()
```
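To make the tensorwise-vs-rowwise distinction that follows concrete, here is a small plain-Python sketch (an illustration, not the torchao implementation; the helper names `tensorwise_scale` and `rowwise_scales` are hypothetical): tensorwise scaling derives one scale from the whole tensor's amax, while rowwise scaling derives one scale per row, so a single outlier row no longer compresses the precision available to the others.

```python
# Sketch contrasting tensorwise and rowwise scale computation
# (plain Python, not the torchao implementation; helper names are hypothetical).
E4M3_MAX = 448.0  # largest normal value representable in float8 e4m3

def tensorwise_scale(matrix):
    """One scale for the whole tensor, derived from the global amax."""
    amax = max(abs(v) for row in matrix for v in row)
    return E4M3_MAX / amax

def rowwise_scales(matrix):
    """One scale per row, each derived from that row's own amax."""
    return [E4M3_MAX / max(abs(v) for v in row) for row in matrix]

m = [
    [0.01, 0.02],    # small-magnitude row
    [100.0, -50.0],  # row containing a large outlier
]
print(tensorwise_scale(m))  # 4.48: the outlier row sets the scale for everything
print(rowwise_scales(m))    # [22400.0, 4.48]: the small row keeps a much larger scale
```

This is the sense in which rowwise scaling is "more granular": each row is quantized against its own dynamic range rather than the tensor's worst case.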
+## float8 linear with rowwise scaling
+
+This is a more accurate recipe compared to tensorwise, with more granular scaling.
+
+:warning: <em>The composability of float8 with rowwise scaling with Tensor Parallelism is WIP, please see https://github.com/pytorch/ao/issues/1732 for more details.</em>
+
+```python
+import torch
+import torch.nn as nn
+from torchao.float8 import convert_to_float8_training, Float8LinearConfig
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
+if not TORCH_VERSION_AT_LEAST_2_5:
+    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
+
+# create model and sample input
+m = nn.Sequential(
+    nn.Linear(2048, 4096),
+    nn.Linear(4096, 128),
+).bfloat16().cuda()
+x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)