Commit 9b37408
add torch.compile + FSDP2 float8 all-gather in CI (#468)
fixed my bug in float8_experimental (meta-pytorch/float8_experimental#321). now we can torch.compile transformer blocks with FSDP float8 all-gather.
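For context, a minimal sketch of the pattern being exercised here, assuming a torchtitan-style model where `model.layers` holds the transformer blocks (the float8 linear swap itself is done by float8_experimental beforehand and is not shown):

```python
# Sketch, not the actual torchtitan code: compile each transformer block
# individually, then apply FSDP2 (fully_shard) per block so the float8
# all-gather operates on that compiled block's parameters.
import torch
from torch.distributed._composable.fsdp import fully_shard

def apply_compile_and_fsdp2(model: torch.nn.Module) -> torch.nn.Module:
    for name, block in list(model.layers.named_children()):
        block = torch.compile(block)   # one torch.compile region per block
        fully_shard(block)             # FSDP2 shards this block; float8 all-gather at use time
        model.layers.register_module(name, block)
    fully_shard(model)                 # root wrap for embeddings / norm / output
    return model
```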
local test:

```bash
CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh \
  --training.enable_float8_linear \
  --training.enable_fsdp_float8_all_gather \
  --training.precompute_float8_dynamic_scale_for_fsdp \
  --training.compile
```
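The `--training.precompute_float8_dynamic_scale_for_fsdp` flag corresponds to a hook that runs after each optimizer step. A rough sketch, under the assumption that float8_experimental exposes `precompute_float8_dynamic_scale_for_fsdp` in its FSDP utilities (not the exact torchtitan call site):

```python
# Sketch: after the optimizer step, precompute dynamic float8 scales for all
# FSDP-managed float8 weights in one pass, so the scales are ready before the
# next forward's float8 all-gather. The import path is an assumption.
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

def training_step(model, optimizer, loss_fn, batch, labels):
    loss = loss_fn(model(batch), labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    precompute_float8_dynamic_scale_for_fsdp(model)  # amortized scale computation across params
    return loss
```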
profiler traces: I can see the compiled region in the CPU thread and the float8 matmul `sm90_xmma_gemm_e4m3bf16...` in the CUDA stream
<img width="1468" alt="Screenshot 2024-07-18 at 4 22 17 PM"
src="https://github.com/user-attachments/assets/0cf58dee-aae1-4582-a3f1-b8aa48b45129">1 parent 2937167 commit 9b37408
1 file changed: 12 insertions (+), 0 deletions (−); the additions span lines 308–319 of the changed file.