Merge branch 'main' into whc/merge_autoparallel #1834
DO NOT MERGE: WIP. This is a baseline for multi-node pretraining on H200s, since currently there don't seem to be any published numbers for the H200.
passing dtype argument to preprocess_data in generate_image
In a prior PR we added `_init_filter_fn()` to configure a module filter function at Float8 component init time, but didn't actually use it. This went unnoticed because the existing module filter (`partial(module_filter_fn, filter_fqns=self.filter_fqns)`) behaves the same way except when the user uses `auto_filter_small_kn`. In this PR we fix that by using `self.filter_fn` (see the sketch below).

## Test plan
- Test `auto_filter_small_kn` and verify the wk/wv are filtered for Llama3 8B: `NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --training.compile --model.converters="float8" --float8.recipe_name="rowwise" --parallelism.tensor_parallel_degree=2 --float8.filter_fqns="auto_filter_small_kn" --model.print-after-conversion`
- Test without `auto_filter_small_kn` and verify all linears are converted: `NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --training.steps=100 --training.compile --model.converters="float8" --float8.recipe_name="rowwise" --parallelism.tensor_parallel_degree=2 --float8.filter_fqns="auto_filter" --model.print-after-conversion`
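To make the fix concrete, here is a minimal sketch of the pattern, not torchtitan's actual code: the class name, the 4096 thresholds, and the conversion call are illustrative assumptions.

```python
from functools import partial

import torch.nn as nn


def module_filter_fn(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
    # Convert only nn.Linear modules whose FQN does not match any skip entry.
    if not isinstance(mod, nn.Linear):
        return False
    return not any(skip in fqn for skip in filter_fqns)


def auto_filter_small_kn(mod: nn.Module, fqn: str) -> bool:
    # Skip small GEMMs (e.g. wk/wv in Llama3 8B) where float8 overhead can
    # outweigh the speedup; the 4096 threshold here is illustrative only.
    if not isinstance(mod, nn.Linear):
        return False
    return mod.in_features >= 4096 and mod.out_features >= 4096


class Float8Component:  # hypothetical name, stands in for the real converter
    def __init__(self, filter_fqns: list[str]):
        self.filter_fqns = filter_fqns
        self.filter_fn = self._init_filter_fn()

    def _init_filter_fn(self):
        if "auto_filter_small_kn" in self.filter_fqns:
            return auto_filter_small_kn
        return partial(module_filter_fn, filter_fqns=self.filter_fqns)

    def convert(self, model: nn.Module) -> None:
        # The fix: pass the filter configured at init time, e.g.
        # convert_to_float8_training(model, module_filter_fn=self.filter_fn),
        # instead of rebuilding the plain FQN-based partial here.
        ...
```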
This PR fixes what seems to be a wrong estimation of the peak FLOPS for the B200. With the current code the peak FLOPS of the B200 is 4.5x higher than the H100, which seems off. It seems that the numbers reported in https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703 are with 2:4 sparsity?
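For illustration, a hedged sketch of a peak-FLOPS lookup where datasheet numbers quoted with 2:4 sparsity are halved before being used for MFU. The dictionary, values, and helper name are assumptions rather than the repository's actual table; the H100 and A100 figures do match the 9.890e+14 and 3.120e+14 values that appear in the logs elsewhere on this page.

```python
# Illustrative device-name -> dense bf16 peak FLOPS lookup used for MFU.
# Datasheet numbers quoted "with sparsity" (2:4) should be halved to get
# the dense figure a training run can actually hit.
PEAK_FLOPS_BF16_DENSE = {
    "NVIDIA H100": 989e12,   # dense bf16 (SXM); 1,979 TFLOPS on the datasheet is with sparsity
    "NVIDIA H200": 989e12,   # same compute class as H100 SXM (assumption)
    "NVIDIA B200": 2250e12,  # datasheet ~4.5 PFLOPS is with 2:4 sparsity; halve it
}


def get_peak_flops(device_name: str, fallback: float = 312e12) -> float:
    for key, flops in PEAK_FLOPS_BF16_DENSE.items():
        if key in device_name:
            return flops
    # Unknown device: fall back to A100 dense bf16 peak (312 TFLOPS).
    return fallback
```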
This PR does the following:
1. move `world_mesh` into `ParallelDims`, as they have a close relationship
2. move `enable_loss_parallel` out of the `ParallelDims` constructor
3. add a convenient property `seq_len_divisor` to `ParallelDims` (sketched below)
4. make `dataloader` and `ft_manager` optional in `CheckpointManager`
5. some minor improvements to typing and code organization
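A rough sketch of the reorganized `ParallelDims` under these changes; the field names, the mesh-building method, and the exact divisor rule are assumptions for illustration, not the final code.

```python
from dataclasses import dataclass


@dataclass
class ParallelDims:
    dp_replicate: int
    dp_shard: int
    cp: int
    tp: int
    pp: int
    world_size: int

    def build_mesh(self, device_type: str):
        # Builds the DeviceMesh from the dims above and keeps it on the
        # instance, so callers reach the mesh through ParallelDims rather
        # than passing world_mesh around separately (details elided).
        ...

    @property
    def seq_len_divisor(self) -> int:
        # Assumed rule: sequence length must be divisible by tp (sequence
        # parallelism) and by 2 * cp (context-parallel load balancing).
        return self.tp * self.cp * 2
```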
If checkpoint.last_save_in_safetensors_format is set, then save the checkpoint with DCP HF components that will save the checkpoint in .safetensors files instead of regular DCP format on final save. On load, we can decide which type of load to do based on checkpoint type. Successful save: ``` (titan) [[email protected] /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh + NGPU=8 + export LOG_RANK=0,1,2 + LOG_RANK=0,1,2 + CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml + overrides= + '[' 0 -ne 0 ']' + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0,1,2 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] ***************************************** W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] ***************************************** [rank0]:[titan] 2025-07-10 19:20:49,848 - root - INFO - Starting job: Llama 3 8B training [rank1]:[titan] 2025-07-10 19:20:49,985 - root - INFO - Starting job: Llama 3 8B training [rank2]:[titan] 2025-07-10 19:20:51,188 - root - INFO - Starting job: Llama 3 8B training [rank0]:[titan] 2025-07-10 19:20:52,644 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-10 19:20:52,646 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-07-10 19:20:52,650 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank0]:NCCL version 2.27.5+cuda12.9 [rank1]:[titan] 2025-07-10 19:20:52,976 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank1]:[titan] 2025-07-10 19:20:52,979 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank1]:[titan] 2025-07-10 19:20:52,984 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank2]:[titan] 2025-07-10 19:20:53,902 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank2]:[titan] 2025-07-10 19:20:53,905 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank2]:[titan] 2025-07-10 19:20:53,910 - root - INFO - [GC] Initial GC collection. 0.00 seconds. 
[rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - Preparing c4 dataset from allenai/c4 [rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - Preparing c4 dataset from allenai/c4 [rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - Preparing c4 dataset from allenai/c4 [rank2]:[titan] 2025-07-10 19:21:02,550 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank2]:[titan] 2025-07-10 19:21:02,944 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank2]:[titan] 2025-07-10 19:21:02,968 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank2]:[titan] 2025-07-10 19:21:02,969 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank2]:[titan] 2025-07-10 19:21:02,970 - root - INFO - Applied selective activation checkpointing to the model [rank1]:[titan] 2025-07-10 19:21:03,101 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank0]:[titan] 2025-07-10 19:21:03,142 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank2]:[titan] 2025-07-10 19:21:03,123 - root - INFO - Applied FSDP to the model [rank1]:[titan] 2025-07-10 19:21:03,491 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank1]:[titan] 2025-07-10 19:21:03,515 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank1]:[titan] 2025-07-10 19:21:03,516 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank1]:[titan] 2025-07-10 19:21:03,517 - root - INFO - Applied selective activation checkpointing to the model [rank0]:[titan] 2025-07-10 19:21:03,550 - root - INFO - TensorBoard logging enabled. 
Logs will be saved at ./outputs/tb/20250710-1921 [rank0]:[titan] 2025-07-10 19:21:03,551 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank0]:[titan] 2025-07-10 19:21:03,574 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 19:21:03,575 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank0]:[titan] 2025-07-10 19:21:03,576 - root - INFO - Applied selective activation checkpointing to the model [rank1]:[titan] 2025-07-10 19:21:03,675 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-10 19:21:03,732 - root - INFO - Applied FSDP to the model [rank2]:[titan] 2025-07-10 19:21:03,813 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank2]:[titan] 2025-07-10 19:21:03,813 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank2]:[titan] 2025-07-10 19:21:03,814 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank2]:[titan] 2025-07-10 19:21:03,817 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2. [rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Mixed precision training is handled by fully_shard [rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200). [rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Training starts at step 1. [rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank1]:[titan] 2025-07-10 19:21:04,369 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank1]:[titan] 2025-07-10 19:21:04,373 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2. [rank0]:[titan] 2025-07-10 19:21:04,335 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank0]:[titan] 2025-07-10 19:21:04,340 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2. [rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Mixed precision training is handled by fully_shard [rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200). [rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Training starts at step 1. 
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200). [rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace [rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 1,046 tflops: 60.58 mfu: 19.42% [rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1 [rank0]:[titan] 2025-07-10 19:21:11,408 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 971 tflops: 56.23 mfu: 18.02% [rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - Calling checkpoint save after step 1 [rank2]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank1]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step: 1 loss: 12.2520 grad_norm: 4.0543 memory: 42.12GiB(53.23%) tps: 1,038 tflops: 60.13 mfu: 19.27% [rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1 [rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank2]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Calling checkpoint save after step 2 [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2. [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291 [rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096]) [rank0]:[titan] 2025-07-10 19:21:14,015 - root - INFO - Calling checkpoint save after step 2 [rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2. [rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291 [rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096]) [rank1]:[titan] 2025-07-10 19:21:14,023 - root - INFO - Calling checkpoint save after step 2 [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving the checkpoint (or staging if async is enabled). [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2. [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Num keys before parsing 291, after 291 [rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096]) [rank0]:Done writing metadata. Took %.2f secs. 0.026559114456176758 [rank0]:Done writing data. Took %.2f secs. 66.62590146064758 [rank0]:Done consolidating. Took %.2f secs. 
66.62735033035278 [rank0]:time taken for all reduce: 141.72666668891907 [rank1]:time taken for all reduce: 141.73284125328064 [rank2]:time taken for all reduce: 141.72900009155273 [rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds. [rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Training completed [rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Destroying the purge thread. [rank0]:[titan] 2025-07-10 19:23:36,827 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds. [rank0]:[titan] 2025-07-10 19:23:36,828 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds. [rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Training completed [rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Destroying the purge thread. [rank2]:[titan] 2025-07-10 19:23:37,243 - root - INFO - Process group destroyed. [rank0]:[titan] 2025-07-10 19:23:38,828 - root - INFO - Training completed [rank0]:[titan] 2025-07-10 19:23:38,829 - root - INFO - Destroying the purge thread. [rank1]:[titan] 2025-07-10 19:23:39,503 - root - INFO - Process group destroyed. [rank0]:[titan] 2025-07-10 19:23:39,705 - root - INFO - Process group destroyed. ``` Successful load: ``` (titan) [[email protected] /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh + NGPU=8 + export LOG_RANK=0 + LOG_RANK=0 + CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml + overrides= + '[' 0 -ne 0 ']' + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510/ + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] ***************************************** W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] ***************************************** [rank0]:[titan] 2025-07-10 20:56:24,765 - root - INFO - Starting job: Llama 3 8B training [rank0]:NCCL version 2.27.5+cuda12.9 [rank0]:[titan] 2025-07-10 20:56:27,746 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-10 20:56:27,748 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-07-10 20:56:27,753 - root - INFO - [GC] Initial GC collection. 0.00 seconds. 
[rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001 [rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - Preparing c4 dataset from allenai/c4 [rank0]:[titan] 2025-07-10 20:56:36,070 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001) [rank0]:[titan] 2025-07-10 20:56:36,430 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-2056 [rank0]:[titan] 2025-07-10 20:56:36,431 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory [rank0]:[titan] 2025-07-10 20:56:36,452 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 20:56:36,454 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters [rank0]:[titan] 2025-07-10 20:56:36,455 - root - INFO - Applied selective activation checkpointing to the model [rank0]:[titan] 2025-07-10 20:56:36,598 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-10 20:56:37,138 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100 [rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14 [rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%) [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 1000 (warmup 200). [rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Loading the checkpoint from ./outputs/checkpoint/step-3. [rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/checkpoint/hf_storage.py:259: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:1579.) [rank0]: tensor = torch.frombuffer( [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds. [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Finished loading the checkpoint in 27.21 seconds. [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Training starts at step 1. [rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Profiling active. 
Traces will be saved at ./outputs/profile_trace [rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - step: 1 loss: 12.0247 grad_norm: 42.7524 memory: 42.12GiB(53.23%) tps: 236 tflops: 13.67 mfu: 4.38% [rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 ``` --------- Co-authored-by: ankitageorge <[email protected]> Co-authored-by: ankitageorge <[email protected]> Co-authored-by: tianyu-l <[email protected]>
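For context, the option described above might be enabled in a job config roughly like this; only `last_save_in_safetensors_format` comes from this PR, and the surrounding keys are assumptions shown for illustration.

```toml
[checkpoint]
enable_checkpoint = true                # assumed key, shown for context
folder = "checkpoint"                   # assumed key, shown for context
last_save_model_weights_only = true     # the final save stores model weights only
last_save_in_safetensors_format = true  # final save written as .safetensors via the DCP HF components
```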
Integrated the validator together with metrics processor for better metrics logging. Key changes: - Metrics processor is passed to validator within training loop - Validator can reuse metrics processor's built-in functionalities such as memory profiling, throughput tracking, and tensorboard/wandb logging This is how the new logging looks from terminal: <img width="959" height="374" alt="Screenshot 2025-07-14 at 3 22 56 PM" src="https://github.com/user-attachments/assets/b16a9e00-3ab2-46ed-a42a-0c92d13697cb" />
This PR refactors `FTManager` to:
1. simplify construction logic
2. expose a simpler interface to `train.py`
3. make it optional when building the optimizer (sketched below)

plus some other minor improvements.
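A sketch under assumed names of what "optional when building the optimizer" could look like; the function signature and the `wrap` hook are hypothetical, the point is only that non-fault-tolerant runs pass `None`.

```python
import torch


def build_optimizers(model_parts, optimizer_config, ft_manager=None):
    # Plain optimizer construction, one optimizer per model part.
    optimizers = [
        torch.optim.AdamW(m.parameters(), lr=optimizer_config.lr)
        for m in model_parts
    ]
    if ft_manager is not None and ft_manager.enabled:
        # Only wire in TorchFT-aware behavior when fault tolerance is on
        # (hypothetical hook; the real integration differs).
        optimizers = ft_manager.wrap(optimizers)
    return optimizers
```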
Changes
1. New helper method to explicitly specify the modules to include (see the sketch below)
2. Update model.py to handle `None` attributes

Can run the 16B DSV3 on 8 GPUs with PP:
```
NGPU=8 LOG_RANK=0,7 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.pipeline_parallel_degree 8 --parallelism.pipeline_parallel_schedule Interleaved1F1B
```
DP + TP + PP example run: https://meta.wandb.io/howardhuang_meta/torchtitan/reports/DeepSeekV3-16B-8-GPU-DP-TP-PP---VmlldzozMTAz?accessToken=ltsuxu4atlsmtk5u1g1zt04xcb6q1cs4mm9ianq8mlqlpq4ppm3lfpu1p53ei4pg

TODO:
- upstream `pipeline_module_split` to `torch.distributed.pipelining`?
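A hypothetical, deliberately simplified sketch of what change 1 points at; the function name and the top-level-children approach are assumptions, and real FQN-level splitting is more involved than this.

```python
import copy

import torch.nn as nn


def module_split(model: nn.Module, stage_module_names: list[list[str]]) -> list[nn.Module]:
    # stage_module_names explicitly lists, per pipeline stage, which top-level
    # modules that stage keeps. Everything else is set to None, which is why
    # model.py must tolerate None attributes in forward().
    stages = []
    for keep in stage_module_names:
        stage = copy.deepcopy(model)
        for name, _ in list(stage.named_children()):
            if name not in keep:
                setattr(stage, name, None)  # e.g. drop tok_embeddings on non-first stages
        stages.append(stage)
    return stages
```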
Also see discussion in #1372

This PR:
- Adds a new config for SAC with the default such that per-op SAC automatically skips all mms with `args[1].shape` matching that of the Linear at fqn `"moe.router.gate"` (see the sketch below)
- Adds general flop/act-mem/correctness tests for AC as well as the new config
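A hedged sketch of the underlying idea, not torchtitan's actual config or code (the wrapper name and the example shape are assumptions): a per-op SAC policy that saves matmul outputs except those whose second argument matches the router gate's shape, which are cheap to recompute.

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def make_policy(router_gate_weight_shape):
    def policy_fn(ctx, op, *args, **kwargs):
        if op == torch.ops.aten.mm.default:
            # args[1] is the second mm operand; if it matches the router gate's
            # shape, prefer recomputing this (tiny) matmul instead of saving it.
            if len(args) > 1 and tuple(args[1].shape) == tuple(router_gate_weight_shape):
                return CheckpointPolicy.PREFER_RECOMPUTE
            return CheckpointPolicy.MUST_SAVE
        return CheckpointPolicy.PREFER_RECOMPUTE

    return policy_fn


def sac_wrap(layer, *inputs, router_gate_weight_shape=(256, 8)):  # shape is illustrative
    # Per-op selective activation checkpointing with the policy above.
    context_fn = partial(
        create_selective_checkpoint_contexts, make_policy(router_gate_weight_shape)
    )
    return checkpoint(layer, *inputs, use_reentrant=False, context_fn=context_fn)
```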
…uffer (#1403) As titled. Tested on llama4 debugging model (dp=8, ep=2): <img width="1188" height="226" alt="Screenshot 2025-07-15 at 8 05 12 PM" src="https://github.com/user-attachments/assets/24a1bf87-b038-481e-b40b-96e2123c96fc" />
Summary: enable creating a separate outer optimizer for each of the parameter fragments for streaming DiLoCo.
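An illustrative sketch of the idea (the function name and hyperparameters are assumptions, loosely following the SGD-with-Nesterov outer optimizer from the DiLoCo paper): one outer optimizer per parameter fragment, so streaming DiLoCo can apply its outer update to each fragment independently as that fragment's synchronization finishes.

```python
import torch


def build_outer_optimizers(fragments, lr=0.7, momentum=0.9):
    # fragments: list of parameter lists, one per fragment that is streamed
    # (synchronized) separately. Each fragment gets its own outer optimizer
    # instead of sharing a single one across the whole model.
    return [
        torch.optim.SGD(frag, lr=lr, momentum=momentum, nesterov=True)
        for frag in fragments
    ]
```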
## Changes in this diff:
1. Pass softmax_scale to sdpa() forward.
2. Change some default parameters for debug_model.toml
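A minimal example of what change 1 enables: forwarding an explicit softmax scale down to `F.scaled_dot_product_attention` via its `scale` argument. The shapes and the scale value here are just for illustration.

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) query/key/value tensors.
q = torch.randn(2, 16, 128, 64)
k = torch.randn(2, 16, 128, 64)
v = torch.randn(2, 16, 128, 64)

# A model-specific scale (e.g. an mscale-adjusted value) instead of the
# default 1/sqrt(head_dim).
softmax_scale = 0.07216878364870322
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=softmax_scale)
```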
Summary:
- log the tensorboard for ft replicas to a separate folder
- log profiles for ft replicas to a separate folder

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1410).
* #1411
* __->__ #1410
Summary:
- allow using gloo process group
- add a parameter to the ft config
- only nccl and gloo will be supported for now

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1411).
* __->__ #1411
* #1410
Stacked PRs:
* __->__ #1408

---

Remove flex+compile restriction. Recently enabled this w/ pytorch/pytorch#150080.

Before the change I got:
```Shell
[rank0]:[rank0]: File "/home/drisspg/meta/torchtitan/torchtitan/models/llama3/model/args.py", line 48, in update_from_config
[rank0]:[rank0]: raise ValueError(
[rank0]:[rank0]: ValueError: FlexAttention is not compatible with selective AC yet. See pytorch/pytorch#147879
```
I tried running this locally with:
```
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" NGPU=1 ./run_train.sh --model.flavor debugmodel_flex_attn --training.compile
```
Got:
<img width="858" height="568" alt="Screenshot 2025-07-16 at 8 25 53 PM" src="https://github.com/user-attachments/assets/f540b01f-62ad-4f81-b5e4-e734e2b09081" />

If there is other, more robust testing we can do, I am down. I am trying to unblock some internal users and ensure I can close this issue: pytorch/pytorch#147879
This PR adds a new field `StateDictAdapter` in `TrainSpec`.
- Currently it only contains a pair of static methods `to_hf` and `from_hf`. Later we could add other pairs like `to_meta` / `from_meta`, `to_vllm` / `from_vllm`, etc.
- It is passed to `CheckpointManager` to convert between the torchtitan model and the HF model during checkpoint save / load.
- It could also potentially be used by downstream inference engines which only support HF models.

In order to save / load in HF format, a model is required to have a corresponding `StateDictAdapter` subclass implementation. For Llama3, I created a placeholder `Llama3StateDictAdapter` to be implemented (sketched below). cc @wesleytruong

This PR also renames the checkpoint config options `initial_load_model_weights_only` and `last_save_model_weights_only` to simply `initial_load_model_only` and `last_save_model_only`, respectively. It seems to me that the original names were made to correspond to `torch.load(..., weights_only=True)`. As long as we document & test clearly when this correspondence holds, I prefer the names in torchtitan to be simple and less ambiguous.
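A sketch of the new field's shape as described above; the method names `to_hf` / `from_hf` and the `Llama3StateDictAdapter` placeholder come from the PR, while the signatures and the abstract base class are assumptions.

```python
from abc import ABC, abstractmethod
from typing import Any


class StateDictAdapter(ABC):
    """Converts between the torchtitan state dict and the HF state dict."""

    @staticmethod
    @abstractmethod
    def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]:
        ...

    @staticmethod
    @abstractmethod
    def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        ...


class Llama3StateDictAdapter(StateDictAdapter):
    # Placeholder, to be implemented: map torchtitan names like
    # "tok_embeddings.weight" to HF names like "model.embed_tokens.weight"
    # and back.
    @staticmethod
    def to_hf(state_dict):
        raise NotImplementedError

    @staticmethod
    def from_hf(hf_state_dict):
        raise NotImplementedError
```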
Fix issue mentioned in #1418.
Deepseek reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L592
Huggingface reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-0324/blob/main/modeling_deepseek.py#L468
Stacked PRs: * __->__#1437 --- --- --- Add flex as impl, debugging #1412 ### Running debug model ``` ❯ NGPU=8 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.flavor debugmodel_flex_attn + NGPU=8 + export LOG_RANK=0 + LOG_RANK=0 + CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml + overrides= + '[' 2 -ne 0 ']' + overrides='--model.flavor debugmodel_flex_attn' + TORCHFT_LIGHTHOUSE=http://localhost:29510 + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510 + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/debug_model.toml --model.flavor debugmodel_flex_attn ***************************************** Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. ***************************************** [rank0]:[titan] 2025-07-21 15:57:07,261 - root - INFO - Starting job: DeepSeek-V3 debug training [rank0]:[titan] 2025-07-21 15:57:08,850 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-21 15:57:08,852 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-07-21 15:57:08,853 - root - INFO - [GC] Initial GC collection. 0.00 seconds. [rank0]:NCCL version 2.27.5+cuda12.9 [rank0]:[titan] 2025-07-21 15:57:12,369 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-07-21 15:57:12,371 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test [rank0]:[titan] 2025-07-21 15:57:12,449 - root - INFO - Building deepseek_v3 debugmodel_flex_attn with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=2048, dtype='bf16', vocab_size=2000, dim=256, inter_dim=1024, moe_inter_dim=256, n_layers=3, n_dense_layers=1, n_heads=16, norm_eps=1e-05, n_routed_experts=8, n_shared_experts=2, n_activated_experts=3, n_expert_groups=1, n_limited_groups=1, score_func='softmax', route_scale=1.0, use_grouped_mm=True, load_balance_coeff=0.001, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=True, attn_mask_type='block_causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7, eos_id=0) [rank0]:[titan] 2025-07-21 15:57:12,471 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory [rank0]:[titan] 2025-07-21 15:57:12,593 - root - INFO - Total parameter count: dense 12,479,744, sparse 3,936,256, active 14,449,920 [rank0]:[titan] 2025-07-21 15:57:12,593 - root - INFO - Model deepseek_v3 debugmodel_flex_attn size: 16,416,000 total parameters [rank0]:[titan] 2025-07-21 15:57:12,615 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-21 15:57:12,858 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-07-21 15:57:12,858 - root - INFO - CUDA memory usage for model: 0.00GiB(0.00%) [rank0]:[titan] 2025-07-21 15:57:12,859 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-21 15:57:12,859 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 
10 (warmup 2). [rank0]:[titan] 2025-07-21 15:57:12,859 - root - INFO - Training starts at step 1. [rank0]:/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:78.) [rank0]: return torch._C._get_cublas_allow_tf32() [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 361472 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 263168 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 328704 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:AUTOTUNE flex_attention(8x16x2048x192, 8x16x2048x192, 8x16x2048x128, 8x16x2048, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x2048) [rank0]:strides: [6291456, 192, 3072, 1], [6291456, 192, 3072, 1], [8388608, 256, 4096, 1], [32768, 2048, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [2048, 1] [rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32 [rank0]: triton_flex_attention_4 2.7331 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_0 5.9175 ms 46.2% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_1 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_2 inf ms 0.0% 
BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=8 [rank0]: triton_flex_attention_3 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]:SingleProcess AUTOTUNE benchmarking takes 0.1680 seconds and 2.7997 seconds precompiling for 5 choices [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 248832 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 297984 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 247808 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 247808 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 297984 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 297984 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 347136 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 347136 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. 
[rank0]:AUTOTUNE flex_attention_backward(8x16x2048x192, 8x16x2048x192, 8x16x2048x128, 8x16x2048, 8x16x2048, 8x16x2048x128, 8x16x2048x192, 8x16x2048x128, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x2048) [rank0]:strides: [6291456, 192, 3072, 1], [6291456, 192, 3072, 1], [8388608, 256, 4096, 1], [32768, 2048, 1], [32768, 2048, 1], [4194304, 262144, 128, 1], [6291456, 192, 3072, 1], [4194304, 128, 2048, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [2048, 1] [rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32 [rank0]: triton_flex_attention_backward_16 6.9023 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=8 [rank0]: triton_flex_attention_backward_18 7.2956 ms 94.6% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=8 [rank0]: triton_flex_attention_backward_20 7.9480 ms 86.8% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=8 [rank0]: triton_flex_attention_backward_14 8.2263 ms 83.9% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8 [rank0]: triton_flex_attention_backward_9 10.0881 ms 68.4% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=4 [rank0]: triton_flex_attention_backward_10 10.2952 ms 67.0% 
BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_backward_11 10.6428 ms 64.9% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=4 [rank0]: triton_flex_attention_backward_12 11.6085 ms 59.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=4 [rank0]: triton_flex_attention_backward_26 12.3810 ms 55.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8 [rank0]: triton_flex_attention_backward_33 13.3964 ms 51.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=4 [rank0]:SingleProcess AUTOTUNE benchmarking takes 4.0514 seconds and 7.8970 seconds precompiling for 29 choices [rank0]:[titan] 2025-07-21 15:57:41,765 - root - INFO - step: 1 loss: 7.9132 grad_norm: 2.5572 memory: 0.00GiB(0.00%) tps: 562 tflops: 0.06 mfu: 0.01% [rank0]:[titan] 2025-07-21 15:57:41,765 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-21 15:57:42,241 - root - INFO - step: 2 loss: 6.6131 grad_norm: 3.5341 memory: 0.00GiB(0.00%) tps: 34,387 tflops: 3.52 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:42,774 - root - INFO - step: 3 loss: 4.9925 grad_norm: 2.6422 memory: 0.00GiB(0.00%) tps: 30,768 tflops: 3.15 mfu: 0.32% [rank0]:[titan] 2025-07-21 15:57:43,242 - root - INFO - step: 4 loss: 4.6671 grad_norm: 2.5457 memory: 0.00GiB(0.00%) tps: 35,051 tflops: 3.59 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:43,713 - root - INFO - step: 5 loss: 4.4394 grad_norm: 2.2714 memory: 0.00GiB(0.00%) tps: 34,824 tflops: 
3.57 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:44,183 - root - INFO - step: 6 loss: 4.2412 grad_norm: 2.0883 memory: 0.00GiB(0.00%) tps: 34,852 tflops: 3.57 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:44,656 - root - INFO - step: 7 loss: 4.0939 grad_norm: 1.9913 memory: 0.00GiB(0.00%) tps: 34,692 tflops: 3.56 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:45,126 - root - INFO - step: 8 loss: 3.9942 grad_norm: 1.8520 memory: 0.00GiB(0.00%) tps: 34,879 tflops: 3.58 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:45,616 - root - INFO - step: 9 loss: 4.0337 grad_norm: 1.6383 memory: 0.00GiB(0.00%) tps: 33,464 tflops: 3.43 mfu: 0.35% [rank0]:[titan] 2025-07-21 15:57:46,090 - root - INFO - step: 10 loss: 3.9119 grad_norm: 1.6558 memory: 0.00GiB(0.00%) tps: 34,574 tflops: 3.54 mfu: 0.36% [rank0]:[titan] 2025-07-21 15:57:46,090 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank0]:[titan] 2025-07-21 15:57:48,090 - root - INFO - Training completed [rank0]:[titan] 2025-07-21 15:57:48,407 - r ``` With these changes On H100 ```Shell NGPU=8 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ./run_train.sh + NGPU=8 + export LOG_RANK=0 + LOG_RANK=0 + CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml + overrides= + '[' 0 -ne 0 ']' + TORCHFT_LIGHTHOUSE=http://localhost:29510 + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510 + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ***************************************** Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. ***************************************** [rank0]:[titan] 2025-07-21 14:32:24,245 - root - INFO - Starting job: DeepSeek-V3 16B model training [rank0]:[titan] 2025-07-21 14:32:25,879 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-21 14:32:25,880 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8] [rank0]:[titan] 2025-07-21 14:32:25,882 - root - INFO - [GC] Initial GC collection. 0.00 seconds. 
[rank0]:NCCL version 2.27.5+cuda12.9 [rank0]:[titan] 2025-07-21 14:32:29,510 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-07-21 14:32:29,716 - root - INFO - Preparing c4 dataset from allenai/c4 [rank0]:[titan] 2025-07-21 14:32:33,715 - root - INFO - Building deepseek_v3 16B with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=4096, dtype='bf16', vocab_size=128815, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, norm_eps=1e-05, n_routed_experts=64, n_shared_experts=2, n_activated_experts=6, n_expert_groups=1, n_limited_groups=1, score_func='softmax', route_scale=1.0, use_grouped_mm=True, load_balance_coeff=0.001, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=True, attn_mask_type='causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7, eos_id=0) [rank0]:[titan] 2025-07-21 14:32:33,892 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory [rank0]:[titan] 2025-07-21 14:32:33,946 - root - INFO - Total parameter count: dense 966,581,760, sparse 14,848,098,304, active 2,769,346,048 [rank0]:[titan] 2025-07-21 14:32:33,946 - root - INFO - Model deepseek_v3 16B size: 15,814,680,064 total parameters [rank0]:[titan] 2025-07-21 14:32:33,947 - root - INFO - Applied full activation checkpointing to the model [rank0]:[titan] 2025-07-21 14:32:34,034 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-21 14:32:34,431 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-07-21 14:32:34,431 - root - INFO - CUDA memory usage for model: 0.00GiB(0.00%) [rank0]:[titan] 2025-07-21 14:32:34,433 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-21 14:32:34,433 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 4096, total steps 1000 (warmup 200). [rank0]:[titan] 2025-07-21 14:32:34,433 - root - INFO - Training starts at step 1. [rank0]:/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:78.) [rank0]: return torch._C._get_cublas_allow_tf32() [rank0]:Exception No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help. 
for benchmark choice TritonTemplateCaller(/tmp/torchinductor_drisspg/uu/cuuzcmwjztg3bl3ujzslc4ma26j6dinag6nocsnpg3dtmquqyno2.py, BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=8) [rank0]:Traceback (most recent call last): [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/concurrent/futures/thread.py", line 59, in run [rank0]: result = self.fn(*self.args, **self.kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 2598, in precompile_with_captured_stdout [rank0]: choice.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1875, in precompile [rank0]: self.bmreq.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/autotune_process.py", line 657, in precompile [rank0]: getattr(mod, self.kernel_name).precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile [rank0]: self._make_launchers() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers [rank0]: raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") [rank0]:RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help. [rank0]:Exception No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 327680 Hardware limit:232448 Reducing block sizes or `num_stages` may help. 
for benchmark choice TritonTemplateCaller(/tmp/torchinductor_drisspg/rg/crg27sodhizvda2jjyznte4u2sxyexstosyzgyfj2lfuekd2u7fx.py, BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4) [rank0]:Traceback (most recent call last): [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/concurrent/futures/thread.py", line 59, in run [rank0]: result = self.fn(*self.args, **self.kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 2598, in precompile_with_captured_stdout [rank0]: choice.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1875, in precompile [rank0]: self.bmreq.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/autotune_process.py", line 657, in precompile [rank0]: getattr(mod, self.kernel_name).precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile [rank0]: self._make_launchers() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers [rank0]: raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") [rank0]:RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 327680 Hardware limit:232448 Reducing block sizes or `num_stages` may help. [rank0]:Exception No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 360448 Hardware limit:232448 Reducing block sizes or `num_stages` may help. 
for benchmark choice TritonTemplateCaller(/tmp/torchinductor_drisspg/7c/c7cznn55lunqd7ln454lhdbpef3s6vnbrkieghrh4l74inbpmkgo.py, BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4) [rank0]:Traceback (most recent call last): [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/concurrent/futures/thread.py", line 59, in run [rank0]: result = self.fn(*self.args, **self.kwargs) [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 2598, in precompile_with_captured_stdout [rank0]: choice.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1875, in precompile [rank0]: self.bmreq.precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/autotune_process.py", line 657, in precompile [rank0]: getattr(mod, self.kernel_name).precompile() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile [rank0]: self._make_launchers() [rank0]: File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers [rank0]: raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}") [rank0]:RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 360448 Hardware limit:232448 Reducing block sizes or `num_stages` may help. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 360448 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 327680 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. 
[rank0]:AUTOTUNE flex_attention(8x16x4096x192, 8x16x4096x192, 8x16x4096x128, 8x16x4096, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32) [rank0]:strides: [12582912, 192, 3072, 1], [12582912, 192, 3072, 1], [16777216, 256, 4096, 1], [65536, 4096, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1] [rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int32, torch.int32, torch.int32, torch.int32 [rank0]: triton_flex_attention_4 9.7972 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_0 18.6366 ms 52.6% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_1 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_2 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=8 [rank0]: triton_flex_attention_3 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]:SingleProcess AUTOTUNE benchmarking takes 0.2523 seconds and 4.4028 seconds precompiling for 5 choices [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 247808 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 296960 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 246784 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 246784 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 296960 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 296960 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 346112 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:Runtime error during autotuning: [rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 346112 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. [rank0]:Ignoring this choice. [rank0]:AUTOTUNE flex_attention_backward(8x16x4096x192, 8x16x4096x192, 8x16x4096x128, 8x16x4096, 8x16x4096, 8x16x4096x128, 8x16x4096x192, 8x16x4096x128, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32) [rank0]:strides: [12582912, 192, 3072, 1], [12582912, 192, 3072, 1], [16777216, 256, 4096, 1], [65536, 4096, 1], [65536, 4096, 1], [8388608, 524288, 128, 1], [12582912, 192, 3072, 1], [8388608, 128, 2048, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1] [rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32 [rank0]: triton_flex_attention_backward_16 28.0839 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=8 [rank0]: triton_flex_attention_backward_18 28.5011 ms 98.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=8 [rank0]: triton_flex_attention_backward_20 30.6411 ms 91.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, 
BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=8 [rank0]: triton_flex_attention_backward_14 31.6448 ms 88.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8 [rank0]: triton_flex_attention_backward_10 36.6753 ms 76.6% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4 [rank0]: triton_flex_attention_backward_11 39.1248 ms 71.8% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=4 [rank0]: triton_flex_attention_backward_9 39.5692 ms 71.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=4 [rank0]: triton_flex_attention_backward_12 44.0589 ms 63.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=4 [rank0]: triton_flex_attention_backward_26 44.6748 ms 62.9% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, 
num_stages=1, num_warps=8 [rank0]: triton_flex_attention_backward_33 49.2283 ms 57.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=4 [rank0]:SingleProcess AUTOTUNE benchmarking takes 13.1756 seconds and 7.4547 seconds precompiling for 29 choices [rank0]:[titan] 2025-07-21 14:33:33,285 - root - INFO - step: 1 loss: 12.2446 grad_norm: 1.2270 memory: 0.00GiB(0.00%) tps: 552 tflops: 9.80 mfu: 0.99% [rank0]:[titan] 2025-07-21 14:33:33,286 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-21 14:36:41,972 - root - INFO - step: 10 loss: 11.4945 grad_norm: 1.5629 memory: 0.00GiB(0.00%) tps: 1,563 tflops: 27.74 mfu: 2.81% [rank0]:[titan] 2025-07-21 14:40:12,773 - root - INFO - step: 20 loss: 9.8251 grad_norm: 7.4642 memory: 0.00GiB(0.00%) tps: 1,554 tflops: 27.59 mfu: 2.79% [rank0]:[titan] 2025-07-21 14:43:42,445 - root - INFO - step: 30 loss: 8.9616 grad_norm: 2.4638 memory: 0.00GiB(0.00%) tps: 1,563 tflops: 27.74 mfu: 2.81% ... ```
Based off of @fegin's previous PR: #907. This sets up CI for torchft in torchtitan; it only runs when `torchtitan/components/ft.py` is changed, but we can expand it as necessary. This also makes it easier to run torchft tests / configs by just calling: `python tests/integration_tests_ft.py --test_id ...` TODO: - There is an issue with our CI where we cannot set `CUDA_VISIBLE_DEVICES` to partition the devices per replica group (I get error Cuda failure 217 'peer access is not supported between these two devices'); as a result this PR only runs with 1 replica group. In follow-up PRs we need to look into manually setting the CUDA device when using torchft.
Without this PR, the FlexAttention mask function will receive a mix of plain tensors and DTensors. Given that the FlexAttention + DTensor story is not yet clear, let's always convert to plain tensors when feeding things into FlexAttention / SDPA.
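A minimal sketch of the unwrapping idea, assuming DTensor inputs and the public `DTensor.to_local()` API (the helper name and call site are illustrative, not the exact torchtitan code):

```
import torch
from torch.distributed.tensor import DTensor


def to_plain_tensor(t: torch.Tensor) -> torch.Tensor:
    # Unwrap DTensors to their local shards so that FlexAttention / SDPA
    # (and their mask functions) only ever see plain tensors.
    return t.to_local() if isinstance(t, DTensor) else t


# hypothetical call site inside the attention wrapper:
# q, k, v = (to_plain_tensor(x) for x in (q, k, v))
```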
Previously we initialized ModelArgs and then updated it dynamically to 1. get `vocab_size` from the tokenizer (used for specifying shapes of the embedding / output module) 2. get `eos_id` from the tokenizer (used for generating the block causal attention mask) 3. update `max_seq_len` according to the training job `seq_len` (used for precomputing `freqs_cis`) ------------- Item 1 is causing trouble, as we found the `vocab_size` for model checkpoints (embedding / output layer) loaded from HF may not always be the same as `tokenizer.get_vocab_size()`. In fact, as long as the `vocab_size` in the embedding / output layer is larger than the tokenizer's `vocab_size`, training is still OK. In addition, there have been requests to not let `ModelArgs` and model init depend on tokenizers, so this PR removes item 2 and instead sends `eos_id` as an input to the model. ---------------- For item 3, there is a caveat: when torchtitan is used for continuing training from a checkpoint, users should be aware that the original model has an intrinsic limit on `max_seq_len`. E.g. for Llama 3 it's 8k; for Llama 4 Scout it's 1M. Currently torchtitan users can break this limit by specifying an arbitrarily large `--training.seq_len`, whether intentionally or not. This PR keeps this flexibility for two reasons: 1. when `seq_len` is less than the intrinsic `max_seq_len`, we only need to generate `freqs_cis` of length `seq_len`, both for resource considerations and for CP compatibility. 2. when users intentionally want to test long-context training / extend to longer-context capability, they can still do that. Instead, this PR adds a warning when getting a `seq_len` larger than the original `max_seq_len`. I noticed that the llama "official" implementation also allows this: https://github.com/meta-llama/llama-models/blob/main/models/llama4/generation.py#L72
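A small sketch of the warning described above (function and variable names are assumptions, not the actual torchtitan code):

```
import logging

logger = logging.getLogger(__name__)


def check_seq_len(seq_len: int, original_max_seq_len: int) -> int:
    # Keep the flexibility to exceed the model's intrinsic limit,
    # but warn loudly instead of silently allowing it.
    if seq_len > original_max_seq_len:
        logger.warning(
            f"Sequence length {seq_len} exceeds the original max_seq_len "
            f"{original_max_seq_len}; training will proceed, but this goes "
            "beyond the model's intrinsic context limit."
        )
    return seq_len
```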
This PR takes job_config out of the CheckpointManager class. Why? JobConfig is a monolith -- it has knowledge of every part of a titan training job. As a result, it is hard to actually use CheckpointManager in a standalone fashion. In practice the job config is mostly only used for its checkpoint config, plus two other usages as far as I can tell: 1) Getting the replica_id from the FTManager 2) Taking the dump_folder from the job field and joining it with the checkpoint folder For (1) we can just get this directly from FTManager without accessing the JobConfig field. For (2) we can pass `job_config.job.dump_folder` explicitly as a base folder, then join to `checkpoint_config.folder`. Personally I would try to consolidate `job.dump_folder` and `checkpoint.folder` (though I understand there are cases where only the former is needed) under Checkpoint, but not sure if this is preferable from titan's pov.
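A rough sketch of the decoupling (class and field names are stand-ins for illustration, not the real signatures):

```
import os
from dataclasses import dataclass


@dataclass
class CheckpointConfig:
    # stand-in for job_config.checkpoint
    folder: str = "checkpoint"
    interval: int = 500


class CheckpointManager:
    def __init__(self, checkpoint_config: CheckpointConfig, base_folder: str, ft_manager=None):
        # (2) dump_folder is passed in explicitly and joined with the checkpoint folder
        self.save_folder = os.path.join(base_folder, checkpoint_config.folder)
        # (1) replica_id comes straight from the FTManager, not from JobConfig
        self.replica_id = getattr(ft_manager, "replica_id", None)
```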
This PR creates a new folder `torchtitan/config` to host `job_config.py` and `manager.py`, for the reasons below: - Both are complicated enough to warrant their own files. - The convention in torchtitan to extend a custom `JobConfig` is to create a file under the model folder called `job_config.py` (see https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/job_config.py). This PR makes the original `JobConfig` consistent with that convention. - (minor) Creating a more succinct `torchtitan.config` namespace is more readable than importing from `torchtitan.config_manager`.
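For example, imports would then look like this (assuming the package re-exports the config classes):

```
# before
# from torchtitan.config_manager import JobConfig
# after
from torchtitan.config import JobConfig
```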
In this PR, I updated the outdated docstring for DeepSeekV3ModelArgs.
## Context Since PP doesn't need a persistent buffer, and `torch.compile` now works with non-persistent buffers, change `freqs_cis` from a persistent buffer to a non-persistent one. This way, the checkpointer doesn't need to explicitly exclude `freqs_cis` when loading. ## Test 1. llama3 model with torch.compile ✅ 2. llama4 model with torch.compile ✅ 3. deepseek-v3 model with torch.compile ✅
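A minimal sketch of the buffer registration (the precompute shown here is a generic RoPE-style placeholder, not the model's exact math):

```
import torch
import torch.nn as nn


class TransformerSketch(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        # persistent=False keeps freqs_cis out of the state_dict, so the
        # checkpointer no longer needs to exclude it explicitly on load
        self.register_buffer(
            "freqs_cis", self._precompute_freqs_cis(dim, max_seq_len), persistent=False
        )

    @staticmethod
    def _precompute_freqs_cis(dim: int, max_seq_len: int) -> torch.Tensor:
        freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float()
        return torch.polar(torch.ones(max_seq_len, dim // 2), torch.outer(t, freqs))
```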
This PR adds the `FluxStateDictAdapter`, allowing us to convert checkpoints to and from HF. Additional changes: - Modifies the `download_hf_assets` script to support downloading diffusion-type safetensor files - Registers Flux's `TrainSpec` in `convert_from_hf` and `convert_to_hf` so that the conversion script can be reused - e.g. `python ./scripts/checkpoint_conversion/convert_from_hf.py ./assets/hf/FLUX.1-dev/transformer ./outputs/temp --model_name flux --model_flavor flux-dev` Tests: Performing a KL divergence test on the forward pass of converted weights loaded in `torchtitan` and HF weights loaded with HF `FluxTransformer2DModel`, we get: ``` Average loss for test from_hf is 7.233546986222528e-13 ``` Additionally, we can now run inference with HF weights to verify changes made in #1548 ### Batched Inference on TorchTitan: | | prompt0 | prompt1 | prompt2 | | --- | --- | --- | --- | | no CFG | <img width="1024" height="1024" alt="prompt0_nocfg" src="https://github.com/user-attachments/assets/421fab49-239a-4ca2-b51a-16823d89acfd" /> | <img width="1024" height="1024" alt="prompt1_nocfg" src="https://github.com/user-attachments/assets/534b557e-7b93-4f2e-b3b3-3a0c7cf57c40" /> | <img width="1024" height="1024" alt="prompt2_nocfg" src="https://github.com/user-attachments/assets/d0f33526-f95d-47db-b5a6-6200bfa151f9" /> | | CFG | <img width="1024" height="1024" alt="prompt0_cfg" src="https://github.com/user-attachments/assets/83234675-eb47-4785-abe1-0f07dd854f1c" /> | <img width="1024" height="1024" alt="prompt1_cfg" src="https://github.com/user-attachments/assets/5e76f3e7-0ca3-47a4-a0ef-3c7e983e8c2c" /> | <img width="1024" height="1024" alt="prompt2_cfg" src="https://github.com/user-attachments/assets/c8cbe367-d96e-4559-a201-48e8dc3d18ee" /> |
This PR removes the `convert_from_llama.py` script since it is superseded by `convert_from_hf.py`.
Creating a new field in `JobConfig`, with the default being ``` [compile] enable=false components = ["model", "loss"] ``` This way we get to compile the loss separately for memory reduction, even when the model is not ready to be compiled. This PR also applies loss compilation to DeepSeek 16B and 671B.
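A sketch of how the new field can gate compilation per component (the helper is illustrative; the real wiring lives in the trainer and parallelize code):

```
import torch


def maybe_compile(compile_cfg, model, loss_fn):
    # compile_cfg is assumed to carry `enable: bool` and `components: list[str]`
    if not compile_cfg.enable:
        return model, loss_fn
    if "model" in compile_cfg.components:
        model = torch.compile(model)
    if "loss" in compile_cfg.components:
        # compiling only the loss still gives a memory win even when the
        # model itself is not ready to be compiled
        loss_fn = torch.compile(loss_fn)
    return model, loss_fn
```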
Tested loading weights from https://huggingface.co/deepseek-ai/DeepSeek-V3.1-Base <img width="1296" height="605" alt="Screenshot 2025-08-20 at 10 28 20 PM" src="https://github.com/user-attachments/assets/cc5bc9ef-0afd-45c9-bdf6-7cf36d9729e8" />
We want to include this PR in our next release ASAP. Created another branch and reverted CODE_OF_CONDUCT.md from @BioGeek's #1583. @BioGeek's contribution is much appreciated! --------- Co-authored-by: Jeroen Van Goey <[email protected]>
We put all experts' usage into a single buffer so that we only need one reduce rather than one per layer. Additionally, this handles cases where tokens per expert are counted twice during full recompute. Co-authored-by: wang55 <[email protected]>
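A sketch of the single-reduce idea (shapes and names are assumptions, not the actual torchtitan code):

```
import torch
import torch.distributed as dist


def reduce_expert_usage(tokens_per_expert_per_layer: list[torch.Tensor]) -> torch.Tensor:
    # Stack every layer's per-expert token counts into one buffer so a single
    # all_reduce replaces one reduce per MoE layer.
    buf = torch.stack(tokens_per_expert_per_layer)  # [n_layers, n_experts]
    if dist.is_initialized():
        dist.all_reduce(buf)
    return buf
```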
fix compile config in parallelize.py, ref https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py
Adds a `ModelProtocol.get_extra_metrics` method for more flexible custom metric reporting, as discussed in #1576 Probably this should be an abstract method, but I was wary of making it a breaking change for users who already inherit from `ModelProtocol`. The current signature is `get_extra_metrics(self, parallel_dims: ParallelDims) -> None | dict`. I also considered adding some subset of `JobConfig`, `TrainSpec`, and `pp_has_{first,last}_stage`; not sure what else might be useful. Tested by running the debugmodel with print statements. CC @rakkit @wwwjn
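A hypothetical implementation of the hook on a user model (the metric name is made up):

```
import torch.nn as nn


class MyModel(nn.Module):
    """Hypothetical model that opts into the new metrics hook."""

    def get_extra_metrics(self, parallel_dims) -> dict | None:
        # Return None (or omit the method) to report nothing extra.
        return {"custom/num_dropped_tokens": 0.0}
```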
One perspective on the attention mask is that it should be coupled with the dataloader rather than the modeling component. Therefore, this PR moves the creation of the attention mask to the trainer, removing it from the model itself. This PR also fixes #1612 ``` -> % LOG_RANK=6 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --training.steps=100 --parallelism.pipeline_parallel_degree=4 + NGPU=8 + export LOG_RANK=6 + LOG_RANK=6 + CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml + TORCHFT_LIGHTHOUSE=http://localhost:29510 + PYTORCH_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510 + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 6 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml --training.steps=100 --parallelism.pipeline_parallel_degree=4 W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] ***************************************** W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] ***************************************** [rank6]:[titan] 2025-08-21 10:14:50,681 - root - INFO - Starting job: DeepSeek-V3 16B model training [rank6]:[titan] 2025-08-21 10:14:53,248 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank6]:[titan] 2025-08-21 10:14:53,250 - root - INFO - Building 2-D device mesh with ['pp', 'dp_shard'], [4, 2] [rank6]:[titan] 2025-08-21 10:14:53,265 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank6]:[titan] 2025-08-21 10:14:56,937 - root - INFO - Loading tokenizer from tokenizer.json [rank6]:[titan] 2025-08-21 10:14:57,076 - root - INFO - Preparing c4 dataset from allenai/c4 [rank6]:[titan] 2025-08-21 10:15:00,743 - root - INFO - Building deepseek_v3 16B with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=4096, vocab_size=102400, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, norm_eps=1e-05, moe_args=MoEArgs(num_experts=64, num_shared_experts=2, score_func='softmax', route_norm=True, route_scale=1.0, score_before_experts=False, top_k=6, use_grouped_mm=True, load_balance_coeff=0.001), n_expert_groups=1, n_limited_groups=1, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=True, attn_mask_type='block_causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7) [rank6]:[titan] 2025-08-21 10:15:00,966 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory [rank6]:[titan] 2025-08-21 10:15:01,008 - root - INFO - Total parameter count: dense 858,385,920, sparse 14,848,098,304, active 2,661,150,208 [rank6]:[titan] 2025-08-21 10:15:01,008 - root - INFO - Model deepseek_v3 16B size: 15,706,484,224 total parameters [rank6]:Stage 3: Modules to keep: {'layers.14', 'layers.13', 'layers.11', 'layers.12'} [rank6]:Stage 7: Modules to keep: {'output', 'norm', 'layers.26', 'layers.25'} [rank6]:[titan] 2025-08-21 10:15:01,029 - root - INFO - PP 
rank 3 is building stage_idx 3 with modules ['layers.11', 'layers.12', 'layers.13', 'layers.14'] [rank6]:[titan] 2025-08-21 10:15:01,048 - root - INFO - PP rank 3 is building stage_idx 7 with modules ['layers.25', 'layers.26', 'norm', 'output'] [rank6]:[titan] 2025-08-21 10:15:01,048 - root - INFO - Applied full activation checkpointing to the model [rank6]:[titan] 2025-08-21 10:15:01,072 - root - INFO - Applied FSDP to the model [rank6]:[titan] 2025-08-21 10:15:01,072 - root - INFO - Applied full activation checkpointing to the model [rank6]:[titan] 2025-08-21 10:15:01,080 - root - INFO - Applied FSDP to the model [rank6]:[titan] 2025-08-21 10:15:01,080 - root - INFO - Using pipeline schedule Interleaved1F1B with 8 microbatches and 8 stages. [rank6]:[titan] 2025-08-21 10:15:01,488 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank6]:[titan] 2025-08-21 10:15:01,488 - root - INFO - CUDA memory usage for model: 6.94GiB(7.31%) [rank6]:[titan] 2025-08-21 10:15:01,489 - root - WARNING - Warmup steps (200) exceed total training steps (100). Adjusting warmup steps to 100. [rank6]:[titan] 2025-08-21 10:15:01,489 - root - WARNING - Warmup (100) + decay (80) steps exceed total training steps (100). Adjusting decay steps to 0. [rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Mixed precision training is handled by fully_shard [rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Trainer is initialized with local batch size 8, global batch size 16, gradient accumulation steps 1, sequence length 4096, total steps 100 (warmup 200) [rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Training starts at step 1 [rank6]:[rank6]:[W821 10:15:10.781655306 ProcessGroupNCCL.cpp:3993] Warning: An unbatched P2P op (send/recv) was called on this ProcessGroup with size 4. In lazy initialization mode, this will result in a new 2-rank NCCL communicator to be created. (function operator()) [rank6]:NCCL version 2.27.5+cuda12.6 [rank6]:[rank6]:[W821 10:15:16.977607954 ProcessGroupNCCL.cpp:3993] Warning: An unbatched P2P op (send/recv) was called on this ProcessGroup with size 4. In lazy initialization mode, this will result in a new 2-rank NCCL communicator to be created. (function operator()) [rank6]:/data/users/chienchin/mywork/pytorch/torch/__init__.py:1539: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /data/users/chienchin/mywork/pytorch/aten/src/ATen/Context.cpp:80.) 
[rank6]: return _C._get_float32_matmul_precision() [rank6]:[titan] 2025-08-21 10:15:28,674 - root - INFO - step: 1 loss: 12.0194 grad_norm: 1.8958 memory: 53.94GiB(56.78%) tps: 296 tflops: 5.16 mfu: 0.52% [rank6]:[titan] 2025-08-21 10:15:28,674 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 no^[^[[rank6]:[titan] 2025-08-21 10:15:43,154 - root - INFO - step: 10 loss: 10.3629 grad_norm: 3.0762 memory: 67.11GiB(70.64%) tps: 5,092 tflops: 88.73 mfu: 8.97% [rank6]:[titan] 2025-08-21 10:15:59,017 - root - INFO - step: 20 loss: 8.9238 grad_norm: 2.5020 memory: 67.11GiB(70.64%) tps: 5,165 tflops: 90.00 mfu: 9.10% [rank6]:[titan] 2025-08-21 10:16:15,051 - root - INFO - step: 30 loss: 7.8167 grad_norm: 1.7460 memory: 67.11GiB(70.64%) tps: 5,109 tflops: 89.04 mfu: 9.00% [rank6]:[titan] 2025-08-21 10:16:31,989 - root - INFO - step: 40 loss: 7.1761 grad_norm: 1.1432 memory: 67.11GiB(70.64%) tps: 4,837 tflops: 84.29 mfu: 8.52% [rank6]:[titan] 2025-08-21 10:16:48,455 - root - INFO - step: 50 loss: 6.7850 grad_norm: 1.4950 memory: 67.11GiB(70.64%) tps: 4,975 tflops: 86.70 mfu: 8.77% [rank6]:[titan] 2025-08-21 10:17:04,602 - root - INFO - step: 60 loss: 6.8310 grad_norm: 1.2972 memory: 67.11GiB(70.64%) tps: 5,074 tflops: 88.42 mfu: 8.94% [rank6]:[titan] 2025-08-21 10:17:22,231 - root - INFO - step: 70 loss: 6.6627 grad_norm: 1.1630 memory: 67.11GiB(70.64%) tps: 4,647 tflops: 80.98 mfu: 8.19% [rank6]:[titan] 2025-08-21 10:17:41,358 - root - INFO - step: 80 loss: 6.3542 grad_norm: 0.8215 memory: 67.11GiB(70.64%) tps: 4,283 tflops: 74.64 mfu: 7.55% [rank6]:[titan] 2025-08-21 10:17:58,336 - root - INFO - step: 90 loss: 6.4442 grad_norm: 1.2542 memory: 67.11GiB(70.64%) tps: 4,825 tflops: 84.09 mfu: 8.50% [rank6]:[titan] 2025-08-21 10:18:12,542 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds [rank6]:[titan] 2025-08-21 10:18:14,566 - root - INFO - step: 100 loss: 6.7519 grad_norm: 1.3966 memory: 67.11GiB(70.64%) tps: 5,048 tflops: 87.97 mfu: 8.89% [rank6]:[titan] 2025-08-21 10:18:14,566 - root - INFO - Training completed [rank6]:[titan] 2025-08-21 10:18:17,159 - root - INFO - Process group destroyed ```
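A rough sketch of what a trainer/dataloader-side block-causal mask could look like with FlexAttention's `create_block_mask`, assuming documents in a packed batch are delimited by `eos_id` (an illustration of the idea, not the exact code in this PR):

```
import torch
from torch.nn.attention.flex_attention import create_block_mask


def build_block_causal_mask(input_ids: torch.Tensor, eos_id: int):
    # Document ids are derived from eos positions in the batch the dataloader
    # produced, so the mask is owned by the data pipeline, not the model.
    B, S = input_ids.shape
    docs = (input_ids == eos_id).cumsum(dim=-1)

    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_doc = docs[b, q_idx] == docs[b, kv_idx]
        return causal & same_doc

    return create_block_mask(mask_mod, B, None, S, S, device=input_ids.device)
```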
Currently, the only available backend for SDPA for DeepSeekV3 is efficient attention kernel. For FlashAttentionV2 (what current SDPA supports), the V embedding dimension must be the same as Q and K. For cuDNN attention, it is complaining the head dimension is too large. The reason for defaulting the attention to SDPA in TorchTitan is that FlexCP is not yet ready. However, the combination of SDPA + CP + DeepSeekV3 is also not functional. This PR updates all DeepSeekV3 configurations to use FlexAttention, which significantly improves the overall performance. **Document masking also contributes to MFU improvement, but the majority is from FlexAttention itself**. ``` CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --training.steps=100 --parallelism.expert_parallel_degree=8 ``` SDPA: ``` [rank0]:[titan] 2025-08-20 18:28:42,047 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 4096, total steps 100 (warmup 200) [rank0]:[titan] 2025-08-20 18:28:42,047 - root - INFO - Training starts at step 1 [rank0]:[titan] 2025-08-20 18:29:04,053 - root - INFO - step: 1 loss: 12.0401 grad_norm: 1.7464 memory: 63.55GiB(66.89%) tps: 1,416 tflops: 24.67 mfu: 2.49% [rank0]:[titan] 2025-08-20 18:29:04,053 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-08-20 18:29:46,138 - root - INFO - step: 10 loss: 10.3087 grad_norm: 3.1896 memory: 78.14GiB(82.25%) tps: 7,008 tflops: 122.12 mfu: 12.35% [rank0]:[titan] 2025-08-20 18:30:33,628 - root - INFO - step: 20 loss: 8.7601 grad_norm: 2.5195 memory: 78.14GiB(82.25%) tps: 6,900 tflops: 120.24 mfu: 12.16% [rank0]:[titan] 2025-08-20 18:31:22,497 - root - INFO - step: 30 loss: 7.7450 grad_norm: 1.9296 memory: 78.14GiB(82.25%) tps: 6,705 tflops: 116.85 mfu: 11.82% [rank0]:[titan] 2025-08-20 18:32:19,709 - root - INFO - step: 40 loss: 6.9795 grad_norm: 0.6893 memory: 78.14GiB(82.25%) tps: 5,728 tflops: 99.81 mfu: 10.09% [rank0]:[titan] 2025-08-20 18:33:34,343 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds [rank0]:[titan] 2025-08-20 18:33:43,863 - root - INFO - step: 50 loss: 6.8381 grad_norm: 1.1848 memory: 78.14GiB(82.25%) tps: 3,894 tflops: 67.86 mfu: 6.86% [rank0]:[titan] 2025-08-20 18:34:37,289 - root - INFO - step: 60 loss: 6.5727 grad_norm: 0.9871 memory: 78.14GiB(82.25%) tps: 6,133 tflops: 106.88 mfu: 10.81% [rank0]:[titan] 2025-08-20 18:35:27,959 - root - INFO - step: 70 loss: 6.5041 grad_norm: 1.5895 memory: 78.14GiB(82.25%) tps: 6,467 tflops: 112.70 mfu: 11.40% [rank0]:[titan] 2025-08-20 18:36:16,732 - root - INFO - step: 80 loss: 6.3179 grad_norm: 0.9556 memory: 78.14GiB(82.25%) tps: 6,719 tflops: 117.08 mfu: 11.84% [rank0]:[titan] 2025-08-20 18:37:05,604 - root - INFO - step: 90 loss: 6.2124 grad_norm: 0.8286 memory: 78.14GiB(82.25%) tps: 6,705 tflops: 116.85 mfu: 11.81% [rank0]:[titan] 2025-08-20 18:37:49,285 - root - INFO - [GC] Peforming periodical GC collection 0.04 seconds [rank0]:[titan] 2025-08-20 18:37:54,361 - root - INFO - step: 100 loss: 6.2596 grad_norm: 1.5143 memory: 78.14GiB(82.25%) tps: 6,721 tflops: 117.12 mfu: 11.84% [rank0]:[titan] 2025-08-20 18:37:54,361 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank0]:[titan] 2025-08-20 18:37:56,364 - root - INFO - Training completed [rank0]:[titan] 2025-08-20 18:37:57,535 - root - INFO - Process group destroyed ``` FlexAttention (now) ``` 
[rank0]:/data/users/chienchin/mywork/pytorch/torch/__init__.py:1539: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /data/users/chienchin/mywork/pytorch/aten/src/ATen/Context.cpp:80.) [rank0]: return _C._get_float32_matmul_precision() [rank0]:[titan] 2025-08-20 22:16:59,699 - root - INFO - step: 1 loss: 11.9984 grad_norm: 1.7288 memory: 63.55GiB(66.89%) tps: 727 tflops: 12.67 mfu: 1.28% [rank0]:[titan] 2025-08-20 22:16:59,699 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-08-20 22:17:32,228 - root - INFO - step: 10 loss: 10.3101 grad_norm: 2.9111 memory: 78.14GiB(82.25%) tps: 9,066 tflops: 157.99 mfu: 15.97% [rank0]:[titan] 2025-08-20 22:18:08,957 - root - INFO - step: 20 loss: 8.7431 grad_norm: 2.5391 memory: 78.14GiB(82.25%) tps: 8,922 tflops: 155.47 mfu: 15.72% [rank0]:[titan] 2025-08-20 22:18:46,981 - root - INFO - step: 30 loss: 7.7133 grad_norm: 1.7743 memory: 78.14GiB(82.25%) tps: 8,618 tflops: 150.18 mfu: 15.19% [rank0]:[titan] 2025-08-20 22:19:26,672 - root - INFO - step: 40 loss: 6.9643 grad_norm: 0.7227 memory: 78.14GiB(82.25%) tps: 8,256 tflops: 143.88 mfu: 14.55% [rank0]:[titan] 2025-08-20 22:20:01,975 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds [rank0]:[titan] 2025-08-20 22:20:06,015 - root - INFO - step: 50 loss: 6.8046 grad_norm: 1.0556 memory: 78.14GiB(82.25%) tps: 8,329 tflops: 145.15 mfu: 14.68% [rank0]:[titan] 2025-08-20 22:20:45,784 - root - INFO - step: 60 loss: 6.5364 grad_norm: 1.7141 memory: 78.14GiB(82.25%) tps: 8,240 tflops: 143.59 mfu: 14.52% [rank0]:[titan] 2025-08-20 22:21:25,078 - root - INFO - step: 70 loss: 6.4709 grad_norm: 1.2385 memory: 78.14GiB(82.25%) tps: 8,340 tflops: 145.33 mfu: 14.69% [rank0]:[titan] 2025-08-20 22:22:03,088 - root - INFO - step: 80 loss: 6.2786 grad_norm: 2.2534 memory: 78.14GiB(82.25%) tps: 8,621 tflops: 150.24 mfu: 15.19% [rank0]:[titan] 2025-08-20 22:22:41,254 - root - INFO - step: 90 loss: 6.1441 grad_norm: 0.6878 memory: 78.14GiB(82.25%) tps: 8,586 tflops: 149.62 mfu: 15.13% [rank0]:[titan] 2025-08-20 22:23:15,059 - root - INFO - [GC] Peforming periodical GC collection 0.05 seconds [rank0]:[titan] 2025-08-20 22:23:19,063 - root - INFO - step: 100 loss: 6.1348 grad_norm: 1.2875 memory: 78.14GiB(82.25%) tps: 8,667 tflops: 151.04 mfu: 15.27% [rank0]:[titan] 2025-08-20 22:23:19,064 - root - INFO - Sleeping 2 seconds for other ranks to complete [rank0]:[titan] 2025-08-20 22:23:21,065 - root - INFO - Training completed [rank0]:[titan] 2025-08-20 22:23:22,436 - root - INFO - Process group destroyed ```
This PR addresses duplicated code related to enabling async TP across different parts of the codebase. It introduces a new API, `maybe_enable_async_tp()`, which centralizes the enablement logic and is reused consistently in all models. Note that while this PR fixes one async TP bug in TorchTitan, it does not fully resolve #1613, as there appear to be additional bugs in PyTorch's async TP implementation.
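A condensed sketch of what the centralized helper could look like (flag and helper names here mirror the common async-TP setup but are assumptions and may differ from the actual code):

```
import torch
from torch.distributed._symmetric_memory import enable_symm_mem_for_group


def maybe_enable_async_tp(job_config, tp_mesh):
    # Single place for the enablement checks, reused by every model's
    # parallelize.py instead of duplicating the logic.
    if not job_config.parallelism.enable_async_tensor_parallel:
        return
    if not job_config.compile.enable:
        raise RuntimeError("Async TP requires torch.compile to be enabled")
    torch._inductor.config._micro_pipeline_tp = True
    enable_symm_mem_for_group(tp_mesh.get_group().group_name)
```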
This PR makes several miscellaneous refactors to clean up `torchtitan` before release. Changes: - Sets each of `model_parts` to eval mode in the `Validator` class to support PP (Bug fix) - Refactor `checkpoint.enable_checkpoint -> checkpoint.enable` (Refactor) - Refactor `validation.enabled -> validation.enable` (Refactor)
Follow-up of #1619 to fix remaining errors. Also fixes a TODO.
…1627) The DataloaderStopIteration exception inherited from StopIteration. According to PEP 479, raising a StopIteration subclass from a generator causes a RuntimeError in Python 3.7+. This change modifies the base class to `Exception` to ensure it can be caught correctly by user code without triggering this behavior. Fixes ISSUE #1626
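A small self-contained demonstration of why the base class matters under PEP 479 (the generator is just for illustration):

```
class DataloaderStopIteration(Exception):  # previously subclassed StopIteration
    pass


def batches():
    yield "batch-0"
    # With a StopIteration base class, PEP 479 would turn this raise into a
    # RuntimeError once it crosses the generator boundary; with Exception as
    # the base it propagates unchanged and user code can catch it.
    raise DataloaderStopIteration("dataset exhausted")


try:
    for b in batches():
        print(b)
except DataloaderStopIteration:
    print("caught cleanly")
```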
…1633) As titled. Only enable weight tying for smaller models.
self.score_function should be self.score_func
Add some basic documentation on how to use Titan with TorchFt for DiLoCo. LMK if anything needs clarification @vishal9-team
In this PR, I'm adding the StateDictAdapter for Qwen3 to enable loading HF checkpoints. We can use this script to adapt the checkpoint from HF to the format that we can load into the torchtitan model and vice versa. This can enable us to do a parity test with the HF implementation and make sure that our results are aligned with the HF implementation. --------- Co-authored-by: Hossein Kavianihamedani <[email protected]>
Summary: Allow users to configure the WandB entity (aka team) and run name through environment variables. Differential Revision: D80499210
As per the [discussion](#1618 (comment)), I added: - a warning message when the validation `steps`=-1, in both the comment and the logger - a change of the default `steps` to reasonable values for the common setup (world size = 8) - infinite loop support for the validator, to avoid a hang when `steps` is large enough to exhaust the dataset (see the sketch below) - the same fix for Flux ## Test - 8 GPUs with `steps=-1`: hang around step 1270 - 8 GPUs with `steps=1200`: good - 8 GPUs with `steps=1500`: `infinite` automatically set to true. Exhausts the dataset and re-iterates, but won't hang - Flux: `steps=-1` doesn't hang - Flux: `steps=60` doesn't hang; re-loops the dataset. Full thread: #1618 cc @ebsmothers @tianyu-l
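A minimal sketch of the infinite-loop behavior for the validation dataloader (illustrative; the real change lives in the validator):

```
def looping(dataloader):
    # Re-iterate the validation dataset instead of letting some ranks run dry
    # and hang when `steps` is large enough to exhaust it.
    while True:
        yield from dataloader


# hypothetical use inside the validator:
# val_iter = looping(val_dataloader) if infinite else iter(val_dataloader)
```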
Follow up for #1634 Updated the warning message according to [Jiani's suggestion](#1634 (review)) cc @wwwjn
Validated debugmodel llama3 works, but ds3 crashes because `build_optimizers_with_moe_load_balancing` does stuff that traverses the original model structure, only now it's an AutoParallelModule which isn't compatible; we'll have to disable this optimization for now and think about what to do. Note: paths have changed, update your run commands: `CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` Failing (ds3): `CONFIG_FILE=././torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name deepseekv3_auto_parallel`