torchao plans to deprecate delayed scaling, so this PR removes it from torchtitan.
Fixes #654.
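
With delayed scaling gone, the float8 swap only needs torchao's dynamic-scaling defaults. Below is a minimal sketch of that pattern, assuming torchao's public float8 API (`Float8LinearConfig`, `convert_to_float8_training`); it is an illustration, not the torchtitan code, and the helper name is made up.

```python
# Minimal sketch (not the torchtitan code): once delayed scaling is removed,
# swapping to Float8Linear only relies on torchao's dynamic-scaling defaults.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def swap_to_float8_linear(model: nn.Module, enable_fsdp_float8_all_gather: bool = False) -> nn.Module:
    # Dynamic scaling is the default in Float8LinearConfig, so no
    # ScalingType.DELAYED / amax-history plumbing is needed anymore.
    config = Float8LinearConfig(
        enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
    )
    # Swaps eligible nn.Linear modules for Float8Linear in place.
    convert_to_float8_training(model, config=config)
    return model
```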
Here are the logs from runs with `enable_float8_linear = true`:
1. `compile = false`
```
[rank0]:2025-01-31 10:12:50,551 - root - INFO - Float8 training active
[rank0]:2025-01-31 10:12:50,571 - root - INFO - Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=False
[rank0]:2025-01-31 10:12:50,572 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2025-01-31 10:12:50,572 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:2025-01-31 10:12:50,635 - root - INFO - Applied FSDP to the model
[rank0]:2025-01-31 10:12:50,835 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-01-31 10:12:50,835 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:2025-01-31 10:12:50,837 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250131-1012
[rank0]:2025-01-31 10:12:50,837 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 5 (warmup 200)
[rank0]:2025-01-31 10:12:50,837 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:2025-01-31 10:13:02,460 - root - INFO - step: 1 loss: 12.2581 memory: 74.27GiB(78.18%) tps: 705 mfu: 4.13%
[rank0]:2025-01-31 10:13:02,460 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-01-31 10:13:04,973 - root - INFO - step: 2 loss: 12.0754 memory: 81.77GiB(86.07%) tps: 3,262 mfu: 19.10%
[rank0]:2025-01-31 10:13:07,033 - root - INFO - step: 3 loss: 11.7432 memory: 81.77GiB(86.07%) tps: 3,980 mfu: 23.30%
[rank0]:2025-01-31 10:13:09,089 - root - INFO - step: 4 loss: 11.3079 memory: 81.77GiB(86.07%) tps: 3,986 mfu: 23.34%
[rank0]:2025-01-31 10:13:11,146 - root - INFO - step: 5 loss: 10.9303 memory: 81.77GiB(86.07%) tps: 3,985 mfu: 23.33%
[rank0]:2025-01-31 10:13:11,147 - root - INFO - Saving a full checkpoint at last step, step 5.
[rank0]:2025-01-31 10:13:31,549 - root - INFO - Finished saving the checkpoint (or staging if async is enabled)in 20.40 seconds.
[rank0]:2025-01-31 10:13:31,549 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-01-31 10:13:33,551 - root - INFO - Training completed
```
2. `compile = true`
```
[rank0]:2025-01-31 10:18:55,527 - root - INFO - Float8 training active
[rank0]:2025-01-31 10:18:55,547 - root - INFO - Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=False
[rank0]:2025-01-31 10:18:55,548 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2025-01-31 10:18:55,549 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:2025-01-31 10:18:55,591 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2025-01-31 10:18:55,656 - root - INFO - Applied FSDP to the model
[rank0]:2025-01-31 10:18:56,530 - root - INFO - CUDA memory usage for model: 3.77GiB(3.97%)
[rank0]:2025-01-31 10:18:56,532 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250131-1018
[rank0]:2025-01-31 10:18:56,533 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 5 (warmup 200)
[rank0]:2025-01-31 10:18:56,533 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[rank0]:W0131 10:19:01.052000 1427728 torch/_logging/_internal.py:1093] [0/0]
[rank0]:[rank0]:W0131 10:19:01.052000 1427728 torch/_logging/_internal.py:1093] [0/0] Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
[rank0]:[rank0]:W0131 10:19:01.052000 1427728 torch/_logging/_internal.py:1093] [0/0] Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
[rank0]:[rank0]:W0131 10:19:01.052000 1427728 torch/_logging/_internal.py:1093] [0/0]
[rank0]:/data/users/yifanmao/pytorch/torch/_inductor/lowering.py:1903: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:2025-01-31 10:19:15,619 - root - INFO - step: 1 loss: 12.2476 memory: 40.21GiB(42.32%) tps: 429 mfu: 2.51%
[rank0]:2025-01-31 10:19:15,619 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2025-01-31 10:19:16,747 - root - INFO - step: 2 loss: 12.0860 memory: 47.77GiB(50.28%) tps: 7,267 mfu: 42.55%
[rank0]:2025-01-31 10:19:17,852 - root - INFO - step: 3 loss: 11.7620 memory: 47.77GiB(50.28%) tps: 7,420 mfu: 43.45%
[rank0]:2025-01-31 10:19:18,953 - root - INFO - step: 4 loss: 11.3075 memory: 47.77GiB(50.28%) tps: 7,449 mfu: 43.62%
[rank0]:2025-01-31 10:19:20,054 - root - INFO - step: 5 loss: 10.9359 memory: 47.77GiB(50.28%) tps: 7,448 mfu: 43.61%
[rank0]:2025-01-31 10:19:20,054 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:2025-01-31 10:19:22,056 - root - INFO - Training completed
```
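
For reference, the `compile = true` run compiles each TransformerBlock individually, as the "Compiling each TransformerBlock with torch.compile" log line above shows. A rough sketch of that pattern follows; the `model.layers` container name is an assumption for illustration.

```python
# Rough sketch of per-block compilation, matching the
# "Compiling each TransformerBlock with torch.compile" log line above.
# The `model.layers` container is an assumed name, not torchtitan's exact code.
import torch


def compile_transformer_blocks(model: torch.nn.Module) -> None:
    for name, block in model.layers.named_children():
        # Compile one TransformerBlock at a time to keep graphs small;
        # fullgraph=True surfaces graph breaks inside a block early.
        model.layers.register_module(name, torch.compile(block, fullgraph=True))
```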