
Commit ed16fe7

float8 training: add README.md entry for rowwise scaling (#1733)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent c6c388b commit ed16fe7


torchao/float8/README.md

Lines changed: 53 additions & 2 deletions
@@ -17,9 +17,9 @@ throughput speedups of up to 1.5x on 128 GPU LLaMa 3 70B pretraining jobs.

 We provide three per-tensor scaling strategies: dynamic, delayed and static. See https://arxiv.org/pdf/2209.05433.pdf, Section 4.3 for more details. These strategies are configurable separately for activations (`input`), weights (`weight`) and gradients (`grad_output`).

-## float8 linear with dynamic scaling for `input`, `weight` and `grad_output`
+## float8 linear with dynamic tensorwise scaling

-This is the most accurate recipe as every tensor is scaled dynamically.
+This is the default recipe, with a good balance of performance and accuracy.

 ```python
 import torch
@@ -63,6 +63,57 @@ for _ in range(10):
     optimizer.step()
 ```

+## float8 linear with rowwise scaling
+
+This is a more accurate recipe than tensorwise, with more granular scaling.
+
+:warning: <em>The composability of float8 with rowwise scaling and Tensor Parallelism is a work in progress; please see https://github.com/pytorch/ao/issues/1732 for more details.</em>
+
+```python
+import torch
+import torch.nn as nn
+from torchao.float8 import convert_to_float8_training, Float8LinearConfig
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+
+if not TORCH_VERSION_AT_LEAST_2_5:
+    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
+
+# create model and sample input
+m = nn.Sequential(
+    nn.Linear(2048, 4096),
+    nn.Linear(4096, 128),
+).bfloat16().cuda()
+x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
+optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
+
+# optional: filter modules from being eligible for float8 conversion
+def module_filter_fn(mod: torch.nn.Module, fqn: str):
+    # don't convert the last module
+    if fqn == "1":
+        return False
+    # don't convert linear modules with weight dimensions not divisible by 16
+    if isinstance(mod, torch.nn.Linear):
+        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
+            return False
+    return True
+
+# configure rowwise scaling
+config = Float8LinearConfig.from_recipe_name("rowwise")
+
+# convert specified `torch.nn.Linear` modules to `Float8Linear`
+convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn)
+
+# enable torch.compile for competitive performance
+m = torch.compile(m)
+
+# toy training loop
+for _ in range(10):
+    optimizer.zero_grad()
+    y = m(x)
+    y.sum().backward()
+    optimizer.step()
+```
+
 ## float8 linear with delayed scaling

 :warning: <em>We plan to deprecate delayed scaling in a future release, see https://github.com/pytorch/ao/issues/1680 for more details.</em>
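The context lines above note that the per-tensor scaling strategies (dynamic, delayed, static) are configurable separately for `input`, `weight` and `grad_output`. As a rough, minimal sketch of what that per-tensor configuration can look like, assuming the `Float8LinearConfig` / `CastConfig` / `ScalingType` API from `torchao.float8` (the specific combination shown is illustrative and is not part of this diff):

```python
import torch.nn as nn

# assumed imports; these names come from torchao.float8's config API,
# not from the lines changed in this commit
from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

# illustrative only: pick a scaling strategy per tensor role.
# here `input` and `grad_output` use dynamic scaling, `weight` uses delayed scaling.
config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

# convert the linear layers of a toy model using the per-tensor config
m = nn.Sequential(nn.Linear(2048, 4096), nn.Linear(4096, 128)).bfloat16().cuda()
convert_to_float8_training(m, config=config)
```

If any tensor uses delayed scaling, the README's delayed scaling section (whose deprecation is flagged in the warning above) additionally calls for syncing the amax/scale history each training iteration, e.g. via `sync_float8_amax_and_scale_history`.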
