This repository was archived by the owner on Aug 7, 2024. It is now read-only.
Commit f4812ee
support float8 weight caching for gradient accumulation/PP (#164)
Summary:
In the cases where the optimizer update does not happen after every forward such as microbatching/PP, we can save the casted weight to trade some time for memory.
For now I'm just testing out performance+accuracy. We can improve on the API in future PRs. The current code is torch.compile friendly which is nice.
In terms of accuracy this should be no change, I will validate this further if we want to land this.
For performance, on drisspg's LLaMa 7B pretrain script, with bsz==128 and micro_bsz == 1:
1. baseline bf16 + compile: 2.38 it/s
2. delayed scaling + compile: 2.80 it/s (1.18x over baseline)
3. delayed scaling + compile + this PR: 3.04 it/s (1.28x over baseline)
Pull Request resolved: #164
Test Plan:
```
pytest test/test_base.py -s -k test_weight_caching
```
Reviewed By: drisspg
Differential Revision: D52356785
Pulled By: vkuzo
fbshipit-source-id: e0173666a6c7639246dfde636734900b9fc1657e1 parent b099049 commit f4812ee
File tree
6 files changed
+152
-16
lines changed- float8_experimental
- test
6 files changed
+152
-16
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
64 | 64 | | |
65 | 65 | | |
66 | 66 | | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
67 | 89 | | |
68 | 90 | | |
69 | 91 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
16 | 16 | | |
17 | 17 | | |
18 | 18 | | |
| 19 | + | |
| 20 | + | |
19 | 21 | | |
20 | 22 | | |
21 | | - | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
22 | 27 | | |
23 | 28 | | |
24 | 29 | | |
| |||
172 | 177 | | |
173 | 178 | | |
174 | 179 | | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
175 | 189 | | |
176 | 190 | | |
177 | 191 | | |
| |||
228 | 242 | | |
229 | 243 | | |
230 | 244 | | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
231 | 265 | | |
232 | | - | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
233 | 272 | | |
234 | 273 | | |
235 | 274 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
156 | 156 | | |
157 | 157 | | |
158 | 158 | | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
159 | 162 | | |
160 | 163 | | |
161 | 164 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
15 | 25 | | |
16 | 26 | | |
17 | 27 | | |
| |||
25 | 35 | | |
26 | 36 | | |
27 | 37 | | |
| 38 | + | |
28 | 39 | | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
38 | 47 | | |
39 | 48 | | |
40 | 49 | | |
41 | 50 | | |
42 | 51 | | |
43 | | - | |
| 52 | + | |
44 | 53 | | |
45 | | - | |
| 54 | + | |
46 | 55 | | |
47 | 56 | | |
48 | 57 | | |
| |||
122 | 131 | | |
123 | 132 | | |
124 | 133 | | |
125 | | - | |
| 134 | + | |
126 | 135 | | |
127 | 136 | | |
128 | 137 | | |
| |||
136 | 145 | | |
137 | 146 | | |
138 | 147 | | |
139 | | - | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
140 | 156 | | |
141 | 157 | | |
142 | 158 | | |
| |||
149 | 165 | | |
150 | 166 | | |
151 | 167 | | |
152 | | - | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
153 | 174 | | |
154 | 175 | | |
155 | 176 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
13 | 16 | | |
14 | 17 | | |
15 | 18 | | |
| |||
231 | 234 | | |
232 | 235 | | |
233 | 236 | | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
234 | 267 | | |
235 | 268 | | |
236 | 269 | | |
| |||
0 commit comments