Conversation

@IvanKobzarev (Contributor)

Merge branch 'main' into whc/merge_autoparallel

Validated that the llama3 debugmodel works, but ds3 crashes because `build_optimizers_with_moe_load_balancing` traverses the original model structure, which is now an `AutoParallelModule` and isn't compatible. We'll have to disable this optimization for now and think about what to do.

Note: paths have changed; update your run commands.

Working (llama3):
```
CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4
```

Failing (ds3):
```
CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name deepseekv3_auto_parallel
```

asaiacai and others added 30 commits July 11, 2025 22:57
DO NOT MERGE: WIP

This is a baseline for multi-node pretraining on H200s, since there don't currently seem to be any published numbers for H200.
passing dtype argument to preprocess_data in generate_image
In a prior PR we added `_init_filter_fn()` to configure a module filter function at Float8 component init time, but didn't actually use it. This went unnoticed because the existing module filter (`partial(module_filter_fn, filter_fqns=self.filter_fqns)`) behaves the same way except when the user uses `auto_filter_small_kn`. In this PR we fix that by using `self.filter_fn`.

## Test plan
- Test auto_filter_small_kn and verify the wk/wv are filtered for Llama3
8b: `NGPU=4
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml"
./run_train.sh --training.steps=100 --training.compile
--model.converters="float8" --float8.recipe_name="rowwise"
--parallelism.tensor_parallel_degree=2
--float8.filter_fqns="auto_filter_small_kn"
--model.print-after-conversion`
- Test without auto_filter_small_kn and verify all linears are
converted: `NGPU=4
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml"
./run_train.sh --training.steps=100 --training.compile
--model.converters="float8" --float8.recipe_name="rowwise"
--parallelism.tensor_parallel_degree=2
--float8.filter_fqns="auto_filter" --model.print-after-conversion`
This PR fixes what seems to be a wrong estimate of the peak FLOPS for the B200. With the current code the peak FLOPS of the B200 is 4.5x that of the H100, which seems off. It appears the numbers reported in https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703 are with 2:4 sparsity?
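
For context, a minimal sketch of the reasoning (assuming the datasheet BF16 figure of ~4.5 PFLOPS is the 2:4-sparse number; the function and exact values below are illustrative, not the torchtitan code):

```python
# Illustrative only: why a figure quoted "with 2:4 sparsity" should be halved
# before being used as the dense peak for MFU.

def peak_flops_bf16(device_name: str) -> float:
    if "B200" in device_name:
        sparse_peak = 4.5e15    # datasheet BF16 figure, reported with 2:4 sparsity
        return sparse_peak / 2  # dense peak is half the 2:4-sparse figure
    if "H100" in device_name:
        return 989e12           # dense BF16 peak torchtitan already uses for H100
    return 312e12               # conservative A100-style fallback for unknown GPUs

ratio = peak_flops_bf16("NVIDIA B200") / peak_flops_bf16("NVIDIA H100")
print(f"B200/H100 dense BF16 ratio: {ratio:.2f}x")  # ~2.3x instead of 4.5x
```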
This PR does the following:
1. move `world_mesh` into `ParallelDims`, as they have a close
relationship
2. move `enable_loss_parallel` out of `ParallelDims` constructor
3. add a convenient property `seq_len_divisor` to `ParallelDims`
4. set `dataloader` and `ft_manager` as optional in `CheckpointManager`
5. some minor improvements on typing and code organization
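
A rough sketch of the shape this gives `ParallelDims` (items 1-3 above); the property names follow the PR, but the bodies below are illustrative assumptions rather than the actual torchtitan code:

```python
from dataclasses import dataclass

@dataclass
class ParallelDims:
    dp_shard: int = 1
    cp: int = 1
    tp: int = 1
    pp: int = 1
    world_size: int = 1

    @property
    def world_mesh(self):
        # In the real code this would lazily build (and cache) a DeviceMesh
        # from the enabled dims; a plain dict stands in for it here.
        return {"dp_shard": self.dp_shard, "cp": self.cp, "tp": self.tp, "pp": self.pp}

    @property
    def seq_len_divisor(self) -> int:
        # Convenience property: sequence length is assumed here to need
        # divisibility by TP and by 2*CP when context parallel is enabled.
        return self.tp * (2 * self.cp)

dims = ParallelDims(dp_shard=8, tp=2, world_size=16)
assert 8192 % dims.seq_len_divisor == 0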
If checkpoint.last_save_in_safetensors_format is set, then save the
checkpoint with DCP HF components that will save the checkpoint in
.safetensors files instead of regular DCP format on final save. On load,
we can decide which type of load to do based on checkpoint type.
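
A hedged sketch of that branching; the DCP HF storage classes live in `torch.distributed.checkpoint.hf_storage` (the module the load log below points at), but the exact class names, constructor arguments, and helper functions here are assumptions rather than the actual `CheckpointManager` code:

```python
import os

import torch.distributed.checkpoint as dcp
# Assumed import path and class names; see torch/distributed/checkpoint/hf_storage.py.
from torch.distributed.checkpoint.hf_storage import (
    HuggingFaceStorageReader,
    HuggingFaceStorageWriter,
)

def save_final_checkpoint(state_dict, path: str, last_save_in_safetensors_format: bool):
    if last_save_in_safetensors_format:
        # final save lands as .safetensors files via the DCP HF components
        dcp.save(state_dict, storage_writer=HuggingFaceStorageWriter(path=path))
    else:
        dcp.save(state_dict, checkpoint_id=path)  # regular DCP format

def load_checkpoint(state_dict, path: str):
    # pick the loader based on what kind of checkpoint is actually on disk
    if any(f.endswith(".safetensors") for f in os.listdir(path)):
        dcp.load(state_dict, storage_reader=HuggingFaceStorageReader(path=path))
    else:
        dcp.load(state_dict, checkpoint_id=path)
```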

Successful save:
```
(titan) [[email protected] /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
+ NGPU=8
+ export LOG_RANK=0,1,2
+ LOG_RANK=0,1,2
+ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0,1,2 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] 
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] *****************************************
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0710 19:20:37.727000 1310423 site-packages/torch/distributed/run.py:774] *****************************************
[rank0]:[titan] 2025-07-10 19:20:49,848 - root - INFO - Starting job: Llama 3 8B training
[rank1]:[titan] 2025-07-10 19:20:49,985 - root - INFO - Starting job: Llama 3 8B training
[rank2]:[titan] 2025-07-10 19:20:51,188 - root - INFO - Starting job: Llama 3 8B training
[rank0]:[titan] 2025-07-10 19:20:52,644 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-10 19:20:52,646 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-07-10 19:20:52,650 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:NCCL version 2.27.5+cuda12.9
[rank1]:[titan] 2025-07-10 19:20:52,976 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank1]:[titan] 2025-07-10 19:20:52,979 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank1]:[titan] 2025-07-10 19:20:52,984 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank2]:[titan] 2025-07-10 19:20:53,902 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank2]:[titan] 2025-07-10 19:20:53,905 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank2]:[titan] 2025-07-10 19:20:53,910 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:[titan] 2025-07-10 19:20:56,568 - root - INFO - Preparing c4 dataset from allenai/c4
[rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank2]:[titan] 2025-07-10 19:20:56,593 - root - INFO - Preparing c4 dataset from allenai/c4
[rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank1]:[titan] 2025-07-10 19:20:56,616 - root - INFO - Preparing c4 dataset from allenai/c4
[rank2]:[titan] 2025-07-10 19:21:02,550 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank2]:[titan] 2025-07-10 19:21:02,944 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank2]:[titan] 2025-07-10 19:21:02,968 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank2]:[titan] 2025-07-10 19:21:02,969 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank2]:[titan] 2025-07-10 19:21:02,970 - root - INFO - Applied selective activation checkpointing to the model
[rank1]:[titan] 2025-07-10 19:21:03,101 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank0]:[titan] 2025-07-10 19:21:03,142 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank2]:[titan] 2025-07-10 19:21:03,123 - root - INFO - Applied FSDP to the model
[rank1]:[titan] 2025-07-10 19:21:03,491 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank1]:[titan] 2025-07-10 19:21:03,515 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank1]:[titan] 2025-07-10 19:21:03,516 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank1]:[titan] 2025-07-10 19:21:03,517 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-07-10 19:21:03,550 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-1921
[rank0]:[titan] 2025-07-10 19:21:03,551 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank0]:[titan] 2025-07-10 19:21:03,574 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 19:21:03,575 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:[titan] 2025-07-10 19:21:03,576 - root - INFO - Applied selective activation checkpointing to the model
[rank1]:[titan] 2025-07-10 19:21:03,675 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-10 19:21:03,732 - root - INFO - Applied FSDP to the model
[rank2]:[titan] 2025-07-10 19:21:03,813 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank2]:[titan] 2025-07-10 19:21:03,813 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank2]:[titan] 2025-07-10 19:21:03,814 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank2]:[titan] 2025-07-10 19:21:03,817 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Mixed precision training is handled by fully_shard
[rank2]:[titan] 2025-07-10 19:21:03,876 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Training starts at step 1.
[rank2]:[titan] 2025-07-10 19:21:03,877 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank1]:[titan] 2025-07-10 19:21:04,369 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank1]:[titan] 2025-07-10 19:21:04,370 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank1]:[titan] 2025-07-10 19:21:04,373 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank0]:[titan] 2025-07-10 19:21:04,335 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank0]:[titan] 2025-07-10 19:21:04,336 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank0]:[titan] 2025-07-10 19:21:04,340 - root - WARNING - Warmup steps (200) exceed total training steps (2). Adjusting warmup steps to 2.
[rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank1]:[titan] 2025-07-10 19:21:04,430 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:[titan] 2025-07-10 19:21:04,415 - root - INFO - Mixed precision training is handled by fully_shard
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Training starts at step 1.
[rank1]:[titan] 2025-07-10 19:21:04,431 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 2 (warmup 200).
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-10 19:21:04,416 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - step:  1  loss: 12.2520  grad_norm:  4.0543  memory: 42.12GiB(53.23%)  tps: 1,046  tflops: 60.58  mfu: 19.42%
[rank0]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1
[rank0]:[titan] 2025-07-10 19:21:11,408 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step:  1  loss: 12.2520  grad_norm:  4.0543  memory: 42.12GiB(53.23%)  tps: 971  tflops: 56.23  mfu: 18.02%
[rank2]:[titan] 2025-07-10 19:21:11,406 - root - INFO - Calling checkpoint save after step 1
[rank2]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank1]:[titan] 2025-07-10 19:21:11,406 - root - INFO - step:  1  loss: 12.2520  grad_norm:  4.0543  memory: 42.12GiB(53.23%)  tps: 1,038  tflops: 60.13  mfu: 19.27%
[rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Calling checkpoint save after step 1
[rank1]:[titan] 2025-07-10 19:21:11,407 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank2]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Calling checkpoint save after step 2
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291
[rank2]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank0]:[titan] 2025-07-10 19:21:14,015 - root - INFO - Calling checkpoint save after step 2
[rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank0]:[titan] 2025-07-10 19:21:14,016 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - Num keys before parsing 291, after 291
[rank0]:[titan] 2025-07-10 19:21:14,017 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank1]:[titan] 2025-07-10 19:21:14,023 - root - INFO - Calling checkpoint save after step 2
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving the checkpoint (or staging if async is enabled).
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Saving a model weights only checkpoint in torch.float32 at last step, step 2.
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - Num keys before parsing 291, after 291
[rank1]:[titan] 2025-07-10 19:21:14,024 - root - INFO - key tok_embeddings.weight, shape torch.Size([128256, 4096])
[rank0]:Done writing metadata. Took %.2f secs. 0.026559114456176758
[rank0]:Done writing data. Took %.2f secs. 66.62590146064758
[rank0]:Done consolidating. Took %.2f secs. 66.62735033035278
[rank0]:time taken for all reduce:  141.72666668891907
[rank1]:time taken for all reduce:  141.73284125328064
[rank2]:time taken for all reduce:  141.72900009155273
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds.
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Training completed
[rank2]:[titan] 2025-07-10 19:23:36,832 - root - INFO - Destroying the purge thread.
[rank0]:[titan] 2025-07-10 19:23:36,827 - root - INFO - [GC] GC collection invoked by checkpointer. 0.02 seconds.
[rank0]:[titan] 2025-07-10 19:23:36,828 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - [GC] GC collection invoked by checkpointer. 0.03 seconds.
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Training completed
[rank1]:[titan] 2025-07-10 19:23:36,837 - root - INFO - Destroying the purge thread.
[rank2]:[titan] 2025-07-10 19:23:37,243 - root - INFO - Process group destroyed.
[rank0]:[titan] 2025-07-10 19:23:38,828 - root - INFO - Training completed
[rank0]:[titan] 2025-07-10 19:23:38,829 - root - INFO - Destroying the purge thread.
[rank1]:[titan] 2025-07-10 19:23:39,503 - root - INFO - Process group destroyed.
[rank0]:[titan] 2025-07-10 19:23:39,705 - root - INFO - Process group destroyed.
```

Successful load:
```
(titan) [[email protected] /data/users/ankitageorge/torchtitan (dcp-hf)]$ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh
+ NGPU=8
+ export LOG_RANK=0
+ LOG_RANK=0
+ CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510/
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] 
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] *****************************************
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0710 20:56:16.459000 1982701 site-packages/torch/distributed/run.py:774] *****************************************
[rank0]:[titan] 2025-07-10 20:56:24,765 - root - INFO - Starting job: Llama 3 8B training
[rank0]:NCCL version 2.27.5+cuda12.9
[rank0]:[titan] 2025-07-10 20:56:27,746 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-10 20:56:27,748 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-07-10 20:56:27,753 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:[titan] 2025-07-10 20:56:30,608 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:[titan] 2025-07-10 20:56:36,070 - root - INFO - Building llama3 8B with TransformerModelArgs(_enforced='This field is used to enforce all fields have defaults.', dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_seq_len=8192, depth_init=True, use_flex_attn=False, attn_mask_type='causal', eos_id=128001)
[rank0]:[titan] 2025-07-10 20:56:36,430 - root - INFO - TensorBoard logging enabled. Logs will be saved at ./outputs/tb/20250710-2056
[rank0]:[titan] 2025-07-10 20:56:36,431 - root - INFO - CUDA capacity: NVIDIA PG509-210 with 79.14GiB memory
[rank0]:[titan] 2025-07-10 20:56:36,452 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 20:56:36,454 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:[titan] 2025-07-10 20:56:36,455 - root - INFO - Applied selective activation checkpointing to the model
[rank0]:[titan] 2025-07-10 20:56:36,598 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - WARNING - Peak flops undefined for: NVIDIA PG509-210, fallback to A100
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - Peak FLOPS used for computing MFU: 3.120e+14
[rank0]:[titan] 2025-07-10 20:56:37,138 - root - INFO - CUDA memory usage for model: 3.77GiB(4.77%)
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Checkpointing active. Checkpoints will be loaded from and saved to ./outputs/checkpoint
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Trainer is initialized with local batch size 1, global batch size 8, gradient accumulation steps 1, sequence length 8192, total steps 1000 (warmup 200).
[rank0]:[titan] 2025-07-10 20:56:37,190 - root - INFO - Loading the checkpoint from ./outputs/checkpoint/step-3.
[rank0]:/home/ankitageorge/.conda/envs/titan/lib/python3.13/site-packages/torch/distributed/checkpoint/hf_storage.py:259: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. This means you can write to the underlying (supposedly non-writable) buffer using the tensor. You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:1579.)
[rank0]:  tensor = torch.frombuffer(
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - [GC] GC collection for checkpoint loading. 0.01 seconds.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Finished loading the checkpoint in 27.21 seconds.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-10 20:57:04,398 - root - INFO - Profiling active. Traces will be saved at ./outputs/profile_trace
[rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - step:  1  loss: 12.0247  grad_norm: 42.7524  memory: 42.12GiB(53.23%)  tps: 236  tflops: 13.67  mfu: 4.38%
[rank0]:[titan] 2025-07-10 20:57:11,168 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
```

---------

Co-authored-by: ankitageorge <[email protected]>
Co-authored-by: ankitageorge <[email protected]>
Co-authored-by: tianyu-l <[email protected]>
Integrated the validator with the metrics processor for better metrics logging.

Key changes:
- The metrics processor is passed to the validator within the training loop
- The validator can reuse the metrics processor's built-in functionality, such as memory profiling, throughput tracking, and TensorBoard/W&B logging (see the sketch below)
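
A minimal sketch of this wiring, with illustrative class and method names rather than the exact torchtitan interfaces:

```python
class Validator:
    def __init__(self, model, val_dataloader, metrics_processor):
        self.model = model
        self.val_dataloader = val_dataloader
        # reuse the training-side metrics processor instead of duplicating
        # memory profiling, throughput tracking, and TB/W&B logging
        self.metrics = metrics_processor

    def validate(self, step: int) -> None:
        losses = [self._eval_step(batch) for batch in self.val_dataloader]
        avg_loss = sum(losses) / max(len(losses), 1)
        self.metrics.log({"validation/loss": avg_loss}, step=step)

    def _eval_step(self, batch) -> float:
        # placeholder: run a no-grad forward pass and return the loss value
        return 0.0
```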

This is how the new logging looks from terminal:
<img width="959" height="374" alt="Screenshot 2025-07-14 at 3 22 56 PM"
src="https://github.com/user-attachments/assets/b16a9e00-3ab2-46ed-a42a-0c92d13697cb"
/>
This PR refactors `FTManager` to:
1. simplify construction logic
2. expose a simpler interface to `train.py`
3. make it optional when building the optimizer (see the sketch below)

and some other minor improvements.
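
A small sketch of what item 3 looks like at the optimizer-building call site; the names are illustrative, not the exact torchtitan signatures:

```python
from typing import Optional

import torch

def build_optimizers(model_parts, lr: float, ft_manager: Optional[object] = None):
    params = [p for m in model_parts for p in m.parameters()]
    optimizer = torch.optim.AdamW(params, lr=lr)
    if ft_manager is not None and getattr(ft_manager, "enabled", False):
        # only wrap with fault-tolerance hooks when an FTManager is provided
        # and enabled; ft_manager.wrap_optimizer is an assumed helper
        optimizer = ft_manager.wrap_optimizer(optimizer)
    return optimizer
```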
Call `pre-commit run --all-files` 

Lint job did not run on #1361?
Changes
1. New helper method to explicitly specify the modules to include
2. Update model.py to handle `None` attributes

Can run the 16B DSV3 on 8 GPUs with PP:
```
NGPU=8 LOG_RANK=0,7 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.pipeline_parallel_degree 8 --parallelism.pipeline_parallel_schedule Interleaved1F1B
```
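
For reference, a toy sketch of the "explicitly specify the modules to include" idea behind change 1 above, and why change 2 (tolerating `None` attributes) is needed; this is not the actual `pipeline_module_split` implementation:

```python
import copy

import torch.nn as nn

def split_by_module_names(model: nn.Module, stage_module_names: list[list[str]]) -> list[nn.Module]:
    """Toy split: each stage keeps only the explicitly listed top-level modules."""
    stages = []
    for keep in stage_module_names:
        stage = copy.deepcopy(model)
        for name, _ in model.named_children():
            if name not in keep:
                # modules not assigned to this stage become None, which is why
                # the model's forward has to handle None attributes
                setattr(stage, name, None)
        stages.append(stage)
    return stages
```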

DP + TP + PP example run:

https://meta.wandb.io/howardhuang_meta/torchtitan/reports/DeepSeekV3-16B-8-GPU-DP-TP-PP---VmlldzozMTAz?accessToken=ltsuxu4atlsmtk5u1g1zt04xcb6q1cs4mm9ianq8mlqlpq4ppm3lfpu1p53ei4pg

TODO:
- upstream `pipeline_module_split` to `torch.distributed.pipelining`?
Also see discussion in #1372

This PR:
- Adds a new config for SAC, with the default such that per-op SAC automatically skips all mms whose args[1].shape matches that of the Linear at fqn "moe.router.gate" (sketched below)
- Adds general flop/activation-memory/correctness tests for AC as well as the new config
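
A hedged sketch of that policy using the selective-checkpoint API in `torch.utils.checkpoint`; the fqn and the shape comparison come from the description above, while the surrounding structure is illustrative:

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

def make_sac_policy(model: torch.nn.Module):
    # assumption for illustration: the router gate Linear lives at this fqn
    gate = model.get_submodule("moe.router.gate")
    # the shape an mm sees as args[1] when this Linear runs (weight transposed)
    gate_mm_shape = gate.weight.t().shape

    def policy_fn(ctx, op, *args, **kwargs):
        if op == torch.ops.aten.mm.default:
            if len(args) > 1 and getattr(args[1], "shape", None) == gate_mm_shape:
                # skip (recompute) router-gate mms instead of saving them
                return CheckpointPolicy.PREFER_RECOMPUTE
            return CheckpointPolicy.MUST_SAVE  # default per-op SAC: save mm outputs
        return CheckpointPolicy.PREFER_RECOMPUTE

    # pass the returned callable as context_fn= to torch.utils.checkpoint.checkpoint
    return lambda: create_selective_checkpoint_contexts(policy_fn)
```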
…uffer (#1403)

As titled. 

Tested on llama4 debugging model (dp=8, ep=2):
<img width="1188" height="226" alt="Screenshot 2025-07-15 at 8 05 12 PM"
src="https://github.com/user-attachments/assets/24a1bf87-b038-481e-b40b-96e2123c96fc"
/>
Summary:
Enable creating a separate outer optimizer for each of the parameter fragments for streaming DiLoCo.
## Changes in this diff:
1. Pass `softmax_scale` to `sdpa()` forward (see the sketch below).
2. Change some default parameters for debug_model.toml.
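
A small sketch of change 1: PyTorch's `scaled_dot_product_attention` accepts an explicit `scale` argument, so the model's softmax scale can be passed through instead of relying on the default `1/sqrt(head_dim)`. The call site below is illustrative rather than the actual DeepSeek-V3 attention code:

```python
import torch
import torch.nn.functional as F

def attention(q, k, v, softmax_scale: float | None = None):
    # when softmax_scale is None, SDPA falls back to 1/sqrt(head_dim); MLA-style
    # attention needs an explicit scale because the q/k head dim (192 here)
    # differs from the dimension the scale should be derived from
    return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=softmax_scale)

q = torch.randn(2, 4, 64, 192)
k = torch.randn(2, 4, 64, 192)
v = torch.randn(2, 4, 64, 128)
out = attention(q, k, v, softmax_scale=0.0722)  # matches SM_SCALE seen in the logs
```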
Summary:
- log the tensorboard for ft replicas to a separate folder
- log profiles for ft replicas to a separate folder

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1410).
* #1411
* __->__ #1410
Summary:
- allow using gloo process group
- add a parameter to the ft config
- only nccl and gloo will be supported for now

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1411).
* __->__ #1411
* #1410
Stacked PRs:
 * __->__#1408


--- --- ---

Remove flex+compile restriction

Recently enabled this with pytorch/pytorch#150080.
Before the change I got:

```Shell

[rank0]:[rank0]:   File "/home/drisspg/meta/torchtitan/torchtitan/models/llama3/model/args.py", line 48, in update_from_config
[rank0]:[rank0]:     raise ValueError(
[rank0]:[rank0]: ValueError: FlexAttention is not compatible with selective AC yet. See pytorch/pytorch#147879
```

I tried running this locally with:
```
 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" NGPU=1 ./run_train.sh --model.flavor
   debugmodel_flex_attn --training.compile
 ```
 
 Got:
 
<img width="858" height="568" alt="Screenshot 2025-07-16 at 8 25 53 PM" src="https://github.com/user-attachments/assets/f540b01f-62ad-4f81-b5e4-e734e2b09081" />

If there's other, more robust testing we can do, I'm down. I'm trying to unblock some internal users and ensure I can close this issue: pytorch/pytorch#147879
This PR adds a new field `StateDictAdapter` in `TrainSpec`.
- Currently it only contains a pair of static methods, `to_hf` and `from_hf`. Later we could add other pairs like `to_meta` / `from_meta`, `to_vllm` / `from_vllm`, etc.
- It is passed to `CheckpointManager` to convert between the torchtitan model and the HF model during checkpoint save / load.
- It could also potentially be used by downstream inference engines which only support HF models.

In order to save / load in HF format, a model is required to have a
corresponding `StateDictAdapter` subclass implementation. For Llama3, I
created a placeholder `Llama3StateDictAdapter` to be implemented. cc
@wesleytruong
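
A rough sketch of the shape of this field; the `to_hf`/`from_hf` pair comes from the description above, while the one-entry key mapping shown for Llama3 is purely illustrative (the real adapter is still a placeholder to be implemented):

```python
from typing import Any

class StateDictAdapter:
    """Converts between the torchtitan state dict and other formats (HF here)."""

    @staticmethod
    def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError

    @staticmethod
    def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        raise NotImplementedError


class Llama3StateDictAdapter(StateDictAdapter):
    # illustrative single-entry mapping: torchtitan fqn -> HF fqn
    _to_hf = {"tok_embeddings.weight": "model.embed_tokens.weight"}

    @staticmethod
    def to_hf(state_dict):
        return {Llama3StateDictAdapter._to_hf.get(k, k): v for k, v in state_dict.items()}

    @staticmethod
    def from_hf(hf_state_dict):
        inv = {v: k for k, v in Llama3StateDictAdapter._to_hf.items()}
        return {inv.get(k, k): v for k, v in hf_state_dict.items()}
```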


This PR also renames the checkpoint config options `initial_load_model_weights_only` and `last_save_model_weights_only` to simply `initial_load_model_only` and `last_save_model_only`, respectively. It seems to me that the original names were made to correspond to `torch.load(..., weights_only=True)`. As long as we document and test clearly when this correspondence holds, I prefer the names in torchtitan to be simpler and less ambiguous.
Stacked PRs:
 * __->__#1437


--- --- ---

Add flex as impl, debugging
#1412

### Running debug model

```
❯ NGPU=8 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.flavor debugmodel_flex_attn
+ NGPU=8
+ export LOG_RANK=0
+ LOG_RANK=0
+ CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml
+ overrides=
+ '[' 2 -ne 0 ']'
+ overrides='--model.flavor debugmodel_flex_attn'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/debug_model.toml --model.flavor debugmodel_flex_attn

*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
[rank0]:[titan] 2025-07-21 15:57:07,261 - root - INFO - Starting job: DeepSeek-V3 debug training
[rank0]:[titan] 2025-07-21 15:57:08,850 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-21 15:57:08,852 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-07-21 15:57:08,853 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:NCCL version 2.27.5+cuda12.9
[rank0]:[titan] 2025-07-21 15:57:12,369 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-07-21 15:57:12,371 - root - INFO - Preparing c4_test dataset from tests/assets/c4_test
[rank0]:[titan] 2025-07-21 15:57:12,449 - root - INFO - Building deepseek_v3 debugmodel_flex_attn with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=2048, dtype='bf16', vocab_size=2000, dim=256, inter_dim=1024, moe_inter_dim=256, n_layers=3, n_dense_layers=1, n_heads=16, norm_eps=1e-05, n_routed_experts=8, n_shared_experts=2, n_activated_experts=3, n_expert_groups=1, n_limited_groups=1, score_func='softmax', route_scale=1.0, use_grouped_mm=True, load_balance_coeff=0.001, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=True, attn_mask_type='block_causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7, eos_id=0)
[rank0]:[titan] 2025-07-21 15:57:12,471 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory
[rank0]:[titan] 2025-07-21 15:57:12,593 - root - INFO - Total parameter count: dense 12,479,744, sparse 3,936,256, active 14,449,920
[rank0]:[titan] 2025-07-21 15:57:12,593 - root - INFO - Model deepseek_v3 debugmodel_flex_attn size: 16,416,000 total parameters
[rank0]:[titan] 2025-07-21 15:57:12,615 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-21 15:57:12,858 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-07-21 15:57:12,858 - root - INFO - CUDA memory usage for model: 0.00GiB(0.00%)
[rank0]:[titan] 2025-07-21 15:57:12,859 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-21 15:57:12,859 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 2048, total steps 10 (warmup 2).
[rank0]:[titan] 2025-07-21 15:57:12,859 - root - INFO - Training starts at step 1.
[rank0]:/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:78.)
[rank0]:  return torch._C._get_cublas_allow_tf32()
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 361472 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 263168 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 328704 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:AUTOTUNE flex_attention(8x16x2048x192, 8x16x2048x192, 8x16x2048x128, 8x16x2048, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x2048)
[rank0]:strides: [6291456, 192, 3072, 1], [6291456, 192, 3072, 1], [8388608, 256, 4096, 1], [32768, 2048, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [2048, 1]
[rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32
[rank0]:  triton_flex_attention_4 2.7331 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4
[rank0]:  triton_flex_attention_0 5.9175 ms 46.2% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4
[rank0]:  triton_flex_attention_1 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4
[rank0]:  triton_flex_attention_2 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=8
[rank0]:  triton_flex_attention_3 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4
[rank0]:SingleProcess AUTOTUNE benchmarking takes 0.1680 seconds and 2.7997 seconds precompiling for 5 choices
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 248832 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 297984 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 247808 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 247808 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 297984 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 297984 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 347136 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 347136 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:AUTOTUNE flex_attention_backward(8x16x2048x192, 8x16x2048x192, 8x16x2048x128, 8x16x2048, 8x16x2048, 8x16x2048x128, 8x16x2048x192, 8x16x2048x128, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x1x16, 8x1x16x16, 8x2048)
[rank0]:strides: [6291456, 192, 3072, 1], [6291456, 192, 3072, 1], [8388608, 256, 4096, 1], [32768, 2048, 1], [32768, 2048, 1], [4194304, 262144, 128, 1], [6291456, 192, 3072, 1], [4194304, 128, 2048, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [16, 16, 1], [256, 256, 16, 1], [2048, 1]
[rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32
[rank0]:  triton_flex_attention_backward_16 6.9023 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=8
[rank0]:  triton_flex_attention_backward_18 7.2956 ms 94.6% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=8
[rank0]:  triton_flex_attention_backward_20 7.9480 ms 86.8% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=8
[rank0]:  triton_flex_attention_backward_14 8.2263 ms 83.9% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8
[rank0]:  triton_flex_attention_backward_9 10.0881 ms 68.4% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=4
[rank0]:  triton_flex_attention_backward_10 10.2952 ms 67.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4
[rank0]:  triton_flex_attention_backward_11 10.6428 ms 64.9% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=4
[rank0]:  triton_flex_attention_backward_12 11.6085 ms 59.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=4
[rank0]:  triton_flex_attention_backward_26 12.3810 ms 55.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8
[rank0]:  triton_flex_attention_backward_33 13.3964 ms 51.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=4
[rank0]:SingleProcess AUTOTUNE benchmarking takes 4.0514 seconds and 7.8970 seconds precompiling for 29 choices
[rank0]:[titan] 2025-07-21 15:57:41,765 - root - INFO - step:  1  loss:  7.9132  grad_norm:  2.5572  memory:  0.00GiB(0.00%)  tps: 562  tflops: 0.06  mfu: 0.01%
[rank0]:[titan] 2025-07-21 15:57:41,765 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-21 15:57:42,241 - root - INFO - step:  2  loss:  6.6131  grad_norm:  3.5341  memory:  0.00GiB(0.00%)  tps: 34,387  tflops: 3.52  mfu: 0.36%
[rank0]:[titan] 2025-07-21 15:57:42,774 - root - INFO - step:  3  loss:  4.9925  grad_norm:  2.6422  memory:  0.00GiB(0.00%)  tps: 30,768  tflops: 3.15  mfu: 0.32%
[rank0]:[titan] 2025-07-21 15:57:43,242 - root - INFO - step:  4  loss:  4.6671  grad_norm:  2.5457  memory:  0.00GiB(0.00%)  tps: 35,051  tflops: 3.59  mfu: 0.36%
[rank0]:[titan] 2025-07-21 15:57:43,713 - root - INFO - step:  5  loss:  4.4394  grad_norm:  2.2714  memory:  0.00GiB(0.00%)  tps: 34,824  tflops: 3.57  mfu: 0.36%
[rank0]:[titan] 2025-07-21 15:57:44,183 - root - INFO - step:  6  loss:  4.2412  grad_norm:  2.0883  memory:  0.00GiB(0.00%)  tps: 34,852  tflops: 3.57  mfu: 0.36%
[rank0]:[titan] 2025-07-21 15:57:44,656 - root - INFO - step:  7  loss:  4.0939  grad_norm:  1.9913  memory:  0.00GiB(0.00%)  tps: 34,692  tflops: 3.56  mfu: 0.36%
[rank0]:[titan] 2025-07-21 15:57:45,126 - root - INFO - step:  8  loss:  3.9942  grad_norm:  1.8520  memory:  0.00GiB(0.00%)  tps: 34,879  tflops: 3.58  mfu: 0.36%
[rank0]:[titan] 2025-07-21 15:57:45,616 - root - INFO - step:  9  loss:  4.0337  grad_norm:  1.6383  memory:  0.00GiB(0.00%)  tps: 33,464  tflops: 3.43  mfu: 0.35%
[rank0]:[titan] 2025-07-21 15:57:46,090 - root - INFO - step: 10  loss:  3.9119  grad_norm:  1.6558  memory:  0.00GiB(0.00%)  tps: 34,574  tflops: 3.54  mfu: 0.36%
[rank0]:[titan] 2025-07-21 15:57:46,090 - root - INFO - Sleeping 2 seconds for other ranks to complete
[rank0]:[titan] 2025-07-21 15:57:48,090 - root - INFO - Training completed
[rank0]:[titan] 2025-07-21 15:57:48,407 - r

```

With these changes, on H100:

```Shell
NGPU=8 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml ./run_train.sh
+ NGPU=8
+ export LOG_RANK=0
+ LOG_RANK=0
+ CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
*****************************************
[rank0]:[titan] 2025-07-21 14:32:24,245 - root - INFO - Starting job: DeepSeek-V3 16B model training
[rank0]:[titan] 2025-07-21 14:32:25,879 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:[titan] 2025-07-21 14:32:25,880 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[rank0]:[titan] 2025-07-21 14:32:25,882 - root - INFO - [GC] Initial GC collection. 0.00 seconds.
[rank0]:NCCL version 2.27.5+cuda12.9
[rank0]:[titan] 2025-07-21 14:32:29,510 - root - INFO - Loading tokenizer from tokenizer.json
[rank0]:[titan] 2025-07-21 14:32:29,716 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:[titan] 2025-07-21 14:32:33,715 - root - INFO - Building deepseek_v3 16B with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=4096, dtype='bf16', vocab_size=128815, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, norm_eps=1e-05, n_routed_experts=64, n_shared_experts=2, n_activated_experts=6, n_expert_groups=1, n_limited_groups=1, score_func='softmax', route_scale=1.0, use_grouped_mm=True, load_balance_coeff=0.001, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=True, attn_mask_type='causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7, eos_id=0)
[rank0]:[titan] 2025-07-21 14:32:33,892 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory
[rank0]:[titan] 2025-07-21 14:32:33,946 - root - INFO - Total parameter count: dense 966,581,760, sparse 14,848,098,304, active 2,769,346,048
[rank0]:[titan] 2025-07-21 14:32:33,946 - root - INFO - Model deepseek_v3 16B size: 15,814,680,064 total parameters
[rank0]:[titan] 2025-07-21 14:32:33,947 - root - INFO - Applied full activation checkpointing to the model
[rank0]:[titan] 2025-07-21 14:32:34,034 - root - INFO - Applied FSDP to the model
[rank0]:[titan] 2025-07-21 14:32:34,431 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank0]:[titan] 2025-07-21 14:32:34,431 - root - INFO - CUDA memory usage for model: 0.00GiB(0.00%)
[rank0]:[titan] 2025-07-21 14:32:34,433 - root - INFO - Mixed precision training is handled by fully_shard
[rank0]:[titan] 2025-07-21 14:32:34,433 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 4096, total steps 1000 (warmup 200).
[rank0]:[titan] 2025-07-21 14:32:34,433 - root - INFO - Training starts at step 1.
[rank0]:/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:78.)
[rank0]:  return torch._C._get_cublas_allow_tf32()
[rank0]:Exception No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_drisspg/uu/cuuzcmwjztg3bl3ujzslc4ma26j6dinag6nocsnpg3dtmquqyno2.py, BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=8)
[rank0]:Traceback (most recent call last):
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/concurrent/futures/thread.py", line 59, in run
[rank0]:    result = self.fn(*self.args, **self.kwargs)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 2598, in precompile_with_captured_stdout
[rank0]:    choice.precompile()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1875, in precompile
[rank0]:    self.bmreq.precompile()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/autotune_process.py", line 657, in precompile
[rank0]:    getattr(mod, self.kernel_name).precompile()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile
[rank0]:    self._make_launchers()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers
[rank0]:    raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
[rank0]:RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.
[rank0]:Exception No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 327680 Hardware limit:232448 Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_drisspg/rg/crg27sodhizvda2jjyznte4u2sxyexstosyzgyfj2lfuekd2u7fx.py, BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4)
[rank0]:Traceback (most recent call last):
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/concurrent/futures/thread.py", line 59, in run
[rank0]:    result = self.fn(*self.args, **self.kwargs)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 2598, in precompile_with_captured_stdout
[rank0]:    choice.precompile()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1875, in precompile
[rank0]:    self.bmreq.precompile()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/autotune_process.py", line 657, in precompile
[rank0]:    getattr(mod, self.kernel_name).precompile()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile
[rank0]:    self._make_launchers()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers
[rank0]:    raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
[rank0]:RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 327680 Hardware limit:232448 Reducing block sizes or `num_stages` may help.
[rank0]:Exception No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 360448 Hardware limit:232448 Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/tmp/torchinductor_drisspg/7c/c7cznn55lunqd7ln454lhdbpef3s6vnbrkieghrh4l74inbpmkgo.py, BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4)
[rank0]:Traceback (most recent call last):
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/concurrent/futures/thread.py", line 59, in run
[rank0]:    result = self.fn(*self.args, **self.kwargs)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 2598, in precompile_with_captured_stdout
[rank0]:    choice.precompile()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/select_algorithm.py", line 1875, in precompile
[rank0]:    self.bmreq.precompile()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/autotune_process.py", line 657, in precompile
[rank0]:    getattr(mod, self.kernel_name).precompile()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 440, in precompile
[rank0]:    self._make_launchers()
[rank0]:  File "/home/drisspg/.conda/envs/nightly/lib/python3.12/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 608, in _make_launchers
[rank0]:    raise RuntimeError(f"No valid triton configs. {type(exc).__name__}: {exc}")
[rank0]:RuntimeError: No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 360448 Hardware limit:232448 Reducing block sizes or `num_stages` may help.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 360448 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 262144 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention Required: 327680 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:AUTOTUNE flex_attention(8x16x4096x192, 8x16x4096x192, 8x16x4096x128, 8x16x4096, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32)
[rank0]:strides: [12582912, 192, 3072, 1], [12582912, 192, 3072, 1], [16777216, 256, 4096, 1], [65536, 4096, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1]
[rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.int32, torch.int32, torch.int32, torch.int32
[rank0]:  triton_flex_attention_4 9.7972 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4
[rank0]:  triton_flex_attention_0 18.6366 ms 52.6% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4
[rank0]:  triton_flex_attention_1 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4
[rank0]:  triton_flex_attention_2 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=128, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=8
[rank0]:  triton_flex_attention_3 inf ms 0.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M=64, BLOCK_N=128, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, USE_TMA=False, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4
[rank0]:SingleProcess AUTOTUNE benchmarking takes 0.2523 seconds and 4.4028 seconds precompiling for 5 choices
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 247808 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 296960 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 246784 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 246784 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 296960 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 296960 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 346112 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:Runtime error during autotuning: 
[rank0]:No valid triton configs. OutOfMemoryError: out of resource: triton_flex_attention_backward Required: 346112 Hardware limit:232448 Reducing block sizes or `num_stages` may help.. 
[rank0]:Ignoring this choice.
[rank0]:AUTOTUNE flex_attention_backward(8x16x4096x192, 8x16x4096x192, 8x16x4096x128, 8x16x4096, 8x16x4096, 8x16x4096x128, 8x16x4096x192, 8x16x4096x128, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32, 1x1x32, 1x1x32x32)
[rank0]:strides: [12582912, 192, 3072, 1], [12582912, 192, 3072, 1], [16777216, 256, 4096, 1], [65536, 4096, 1], [65536, 4096, 1], [8388608, 524288, 128, 1], [12582912, 192, 3072, 1], [8388608, 128, 2048, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1], [32, 32, 1], [1024, 1024, 32, 1]
[rank0]:dtypes: torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.float32, torch.float32, torch.bfloat16, torch.bfloat16, torch.bfloat16, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32, torch.int32
[rank0]:  triton_flex_attention_backward_16 28.0839 ms 100.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=8
[rank0]:  triton_flex_attention_backward_18 28.5011 ms 98.5% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=8
[rank0]:  triton_flex_attention_backward_20 30.6411 ms 91.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=8
[rank0]:  triton_flex_attention_backward_14 31.6448 ms 88.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8
[rank0]:  triton_flex_attention_backward_10 36.6753 ms 76.6% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=3, num_warps=4
[rank0]:  triton_flex_attention_backward_11 39.1248 ms 71.8% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=4, num_warps=4
[rank0]:  triton_flex_attention_backward_9 39.5692 ms 71.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=4
[rank0]:  triton_flex_attention_backward_12 44.0589 ms 63.7% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=32, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=32, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=5, num_warps=4
[rank0]:  triton_flex_attention_backward_26 44.6748 ms 62.9% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=128, BLOCK_N1=128, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=1, num_warps=8
[rank0]:  triton_flex_attention_backward_33 49.2283 ms 57.0% BLOCKS_ARE_CONTIGUOUS=False, BLOCK_M1=64, BLOCK_M2=64, BLOCK_N1=64, BLOCK_N2=64, FLOAT32_PRECISION="'ieee'", GQA_SHARED_HEADS=1, HAS_FULL_BLOCKS=True, IS_DIVISIBLE=True, OUTPUT_LOGSUMEXP=True, PRESCALE_QK=False, QK_HEAD_DIM=192, QK_HEAD_DIM_ROUNDED=256, ROWS_GUARANTEED_SAFE=False, SAFE_HEAD_DIM=False, SM_SCALE=0.07216878364870322, SPARSE_KV_BLOCK_SIZE=128, SPARSE_Q_BLOCK_SIZE=128, V_HEAD_DIM=128, V_HEAD_DIM_ROUNDED=128, WRITE_DQ=True, num_stages=2, num_warps=4
[rank0]:SingleProcess AUTOTUNE benchmarking takes 13.1756 seconds and 7.4547 seconds precompiling for 29 choices
[rank0]:[titan] 2025-07-21 14:33:33,285 - root - INFO - step:  1  loss: 12.2446  grad_norm:  1.2270  memory:  0.00GiB(0.00%)  tps: 552  tflops: 9.80  mfu: 0.99%
[rank0]:[titan] 2025-07-21 14:33:33,286 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-21 14:36:41,972 - root - INFO - step: 10  loss: 11.4945  grad_norm:  1.5629  memory:  0.00GiB(0.00%)  tps: 1,563  tflops: 27.74  mfu: 2.81%
[rank0]:[titan] 2025-07-21 14:40:12,773 - root - INFO - step: 20  loss:  9.8251  grad_norm:  7.4642  memory:  0.00GiB(0.00%)  tps: 1,554  tflops: 27.59  mfu: 2.79%
[rank0]:[titan] 2025-07-21 14:43:42,445 - root - INFO - step: 30  loss:  8.9616  grad_norm:  2.4638  memory:  0.00GiB(0.00%)  tps: 1,563  tflops: 27.74  mfu: 2.81%
...

```
Based off of @fegin's previous PR:
#907

This sets up CI for torchft in torchtitan. It only runs when
`torchtitan/components/ft.py` is changed, but we can expand it as
necessary. This also makes it easier to run torchft tests / configs by
just calling:

`python tests/integration_tests_ft.py --test_id ...`

TODO:
- There is an issue with our CI where we cannot set
`CUDA_VISIBLE_DEVICES` to partition the devices per replica group (I get
the error "Cuda failure 217 'peer access is not supported between these two
devices'"), so this PR only runs with 1 replica group. Follow-up
PRs should look into manually setting the CUDA device when using
torchft.
Without this PR, the FlexAttention mask function will receive a mix of
plain tensors and DTensors.

Given that the FlexAttention + DTensor story is not yet clear, let's always
convert to plain tensors when feeding things into FlexAttention / SDPA.
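
The conversion itself is a small helper; a minimal sketch of the idea (not the exact torchtitan code):

```
from torch.distributed.tensor import DTensor


def to_plain(t):
    # Hypothetical helper: unwrap a DTensor into its local shard so that
    # FlexAttention / SDPA only ever sees plain torch.Tensor inputs.
    return t.to_local() if isinstance(t, DTensor) else t
```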
Previously we initialized ModelArgs and then updated it dynamically to:
1. get `vocab_size` from tokenizer (used for specifying shapes of
embedding / output module)
2. get `eos_id` from tokenizer (used for generating block causal
attention mask)
3. update `max_seq_len` according to training job `seq_len` (used for
precomputing `freqs_cis`)

-------------
Item 1 is causing trouble, as we found the `vocab_size` for model checkpoints
(embedding / output layer) loaded from HF may not always be the same as
`tokenizer.get_vocab_size()`. In fact, as long as `vocab_size` in the
embedding / output layer is larger than the tokenizer's `vocab_size`,
training is still OK.

In addition, there have been requests to not let `ModelArgs` and model
init depend on tokenizers, so this PR removes 2 and instead sends
`eos_id` as an input to the model.

----------------

For 3, there is a caveat that when torchtitan is used for continuing
training from a checkpoint, users should be aware that the original
model has an intrinsic limit on `max_seq_len`. E.g. for llama 3 it's 8k,
for llama 4 Scout it's 1M. Currently torchtitan users could break this
limit by specifying an arbitrarily large `--training.seq_len`, whether
intentionally or not.

This PR keeps this flexibility for two reasons:
1. When `seq_len` is less than the intrinsic `max_seq_len`, we only need to
generate `freqs_cis` of length `seq_len`, both for resource considerations
and for CP compatibility.
2. When users intentionally want to test long-context training / extend
to longer-context capability, they can still do that.

Instead, this PR adds a warning when it gets a `seq_len` larger than the
original `max_seq_len`.

I noticed that the "official" llama implementation also allows this:
https://github.com/meta-llama/llama-models/blob/main/models/llama4/generation.py#L72
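
A minimal sketch of the warning described above (names are illustrative, not the exact torchtitan code):

```
import logging

logger = logging.getLogger(__name__)


def apply_training_seq_len(model_args, seq_len: int):
    # Hypothetical helper: keep the flexibility of training with a longer
    # sequence length, but warn that it exceeds the model's intrinsic limit.
    if seq_len > model_args.max_seq_len:
        logger.warning(
            f"Sequence length {seq_len} exceeds the original max_seq_len "
            f"{model_args.max_seq_len}; the model may not perform well "
            "beyond its original context length."
        )
    model_args.max_seq_len = seq_len
```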
This PR takes job_config out of the CheckpointManager class. Why?
JobConfig is a monolith -- it has knowledge of every part of a titan
training job. As a result, it is hard to actually use CheckpointManager
in a standalone fashion. In practice the job config is mostly only used
for its checkpoint config, plus two other usages as far as I can tell:

1) Getting the replica_id from the FTManager
2) Taking the dump_folder from the job field and joining it with the
checkpoint folder

For (1) we can just get this directly from FTManager without accessing
the JobConfig field. For (2) we can pass `job_config.job.dump_folder`
explicitly as a base folder, then join to `checkpoint_config.folder`.
Personally I would try to consolidate `job.dump_folder` and
`checkpoint.folder` (though I understand there are cases where only the
former is needed) under Checkpoint, but not sure if this is preferable
from titan's pov.
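
A minimal sketch of point (2), with illustrative names rather than the actual CheckpointManager signature:

```
import os


def checkpoint_dir(base_folder: str, checkpoint_folder: str) -> str:
    # Hypothetical sketch: CheckpointManager receives the base folder
    # (job.dump_folder) explicitly and joins it with checkpoint.folder,
    # so it never needs the full JobConfig.
    return os.path.join(base_folder, checkpoint_folder)


# e.g. checkpoint_dir("./outputs", "checkpoint") -> "./outputs/checkpoint"
```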
This PR creates a new folder `torchtitan/config` to host `job_config.py`
and `manager.py`, for the reasons below:
- Both are complicated enough to be worth their own files.
- The convention in torchtitan for extending a custom `JobConfig` is to create
a file under the model folder called `job_config.py` (see
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/job_config.py).
This PR makes the original `JobConfig` consistent with that convention.
- (minor) Creating a more succinct `torchtitan.config` namespace is more
readable than importing from `torchtitan.config_manager`.
lckr and others added 24 commits August 19, 2025 13:35
In this PR, I updated the outdated docstring for DeepSeekV3ModelArgs.
## Context

Since PP doesn't need a persistent buffer, and `torch.compile` now works with
non-persistent buffers, change `freqs_cis` from a persistent buffer to a
non-persistent buffer. This way, the checkpointer doesn't need to
explicitly exclude `freqs_cis` when loading.
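
For reference, switching a buffer to non-persistent is a one-argument change; a minimal sketch (a hypothetical module, not the exact torchtitan code):

```
import torch
import torch.nn as nn


class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, max_seq_len: int, theta: float = 10000.0):
        super().__init__()
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(torch.arange(max_seq_len).float(), freqs)
        # persistent=False keeps freqs_cis out of state_dict(), so checkpoint
        # save/load ignores it and it is simply recomputed at model init.
        self.register_buffer(
            "freqs_cis", torch.polar(torch.ones_like(angles), angles), persistent=False
        )
```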


## Test
1. llama3 model with torch.compile ✅
2. llama4 model with torch.compile ✅
3. deepseek-v3 model with torch.compile ✅
This PR adds the `FluxStateDictAdapter`, allowing us to convert
checkpoints to and from HF.

Additional changes:
- Modifies `download_hf_assets` script to support downloading
diffusion-type safetensor files
- Registers Flux's `TrainSpec` in `convert_from_hf` and `convert_to_hf`
so that conversion script can be reused
- e.g. `python ./scripts/checkpoint_conversion/convert_from_hf.py
./assets/hf/FLUX.1-dev/transformer ./outputs/temp --model_name flux
--model_flavor flux-dev`

Tests:
Performing KL divergence test on the forward pass of converted weights
loaded in `torchtitan` and HF weights loaded with HF
`FluxTransformer2DModel`, we get:
```
Average loss for test from_hf is 7.233546986222528e-13
```

Additionally, we can now run inference with HF weights to verify changes
made in #1548

### Batched Inference on TorchTitan:
| | prompt0 | prompt1 | prompt2 |
| --- | --- | --- | --- |
| no CFG | <img width="1024" height="1024" alt="prompt0_nocfg"
src="https://github.com/user-attachments/assets/421fab49-239a-4ca2-b51a-16823d89acfd"
/> | <img width="1024" height="1024" alt="prompt1_nocfg"
src="https://github.com/user-attachments/assets/534b557e-7b93-4f2e-b3b3-3a0c7cf57c40"
/> | <img width="1024" height="1024" alt="prompt2_nocfg"
src="https://github.com/user-attachments/assets/d0f33526-f95d-47db-b5a6-6200bfa151f9"
/> |
| CFG | <img width="1024" height="1024" alt="prompt0_cfg"
src="https://github.com/user-attachments/assets/83234675-eb47-4785-abe1-0f07dd854f1c"
/> | <img width="1024" height="1024" alt="prompt1_cfg"
src="https://github.com/user-attachments/assets/5e76f3e7-0ca3-47a4-a0ef-3c7e983e8c2c"
/> | <img width="1024" height="1024" alt="prompt2_cfg"
src="https://github.com/user-attachments/assets/c8cbe367-d96e-4559-a201-48e8dc3d18ee"
/> |
This PR removes the `convert_from_llama.py` script since it is
superseded by `convert_from_hf.py` instead.
Creating a new field in `JobConfig`, with the default being
```
[compile]
enable=false
components = ["model", "loss"]
```

This way we get to compile loss separately to get memory reduction, even
when the model is not ready to be compiled.

This PR also applies loss compilation to DeepSeek 16B and 671B.
Tested loading weights from
https://huggingface.co/deepseek-ai/DeepSeek-V3.1-Base

<img width="1296" height="605" alt="Screenshot 2025-08-20 at 10 28
20 PM"
src="https://github.com/user-attachments/assets/cc5bc9ef-0afd-45c9-bdf6-7cf36d9729e8"
/>
We want to include this PR in our next release ASAP. Created another
branch and reverted CODE_OF_CONDUCT.md from @BioGeek's #1583. Much
appreciation for @BioGeek's contribution!

---------

Co-authored-by: Jeroen Van Goey <[email protected]>
We put all experts' usage into a single buffer so that we only need one
reduce rather than one per layer.

Additionally, handle cases where tokens per expert are counted twice
during full recompute.
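
A minimal sketch of the buffering idea (names are illustrative, not the actual torchtitan code):

```
import torch
import torch.distributed as dist


def reduce_expert_usage(tokens_per_expert_per_layer: list[torch.Tensor]) -> torch.Tensor:
    # Hypothetical sketch: stack each layer's expert-usage counts into one
    # buffer and issue a single all_reduce, instead of one reduce per layer.
    buf = torch.stack(tokens_per_expert_per_layer)  # [n_layers, n_experts]
    dist.all_reduce(buf)
    return buf
```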

Co-authored-by: wang55 <[email protected]>
Adds a `ModelProtocol.get_extra_metrics` method for more flexible custom
metric reporting, as discussed in #1576

Probably this should be an abstract method, but I was wary of making
this a breaking change for users who inherit this commit.

The current signature is `get_extra_metrics(self, parallel_dims:
ParallelDims) -> None | dict`. I also considered adding some subset of
`JobConfig`, `TrainSpec`, and `pp_has_{first,last}_stage`; not sure what
else might be useful.
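
A minimal sketch of how a model might implement the hook (the metric name and the `_aux_loss` attribute are illustrative):

```
import torch.nn as nn


class MyModel(nn.Module):
    # ParallelDims is referenced as a string annotation here to keep the
    # sketch self-contained.
    def get_extra_metrics(self, parallel_dims: "ParallelDims") -> None | dict:
        # Return custom scalars to log, or None for "nothing extra this step".
        aux_loss = getattr(self, "_aux_loss", None)
        return None if aux_loss is None else {"moe/aux_loss": float(aux_loss)}
```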

Tested via running the debugmodel with print statements. 

CC @rakkit @wwwjn
One perspective on the attention mask is that it should be coupled with
the dataloader rather than the modeling component. Therefore, this PR
moves the creation of the attention mask to the trainer, removing it
from the model itself.

This PR also fixes #1612
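
A rough sketch of the trainer-side construction, assuming FlexAttention's block-mask API (names and shapes are illustrative, not the exact torchtitan code):

```
import torch
from torch.nn.attention.flex_attention import create_block_mask


def build_block_causal_mask(tokens: torch.Tensor, eos_id: int):
    # Hypothetical sketch: derive document boundaries from eos positions in
    # the batch so the model itself no longer needs the tokenizer's eos_id.
    doc_id = (tokens == eos_id).cumsum(dim=-1)  # tokens: [batch, seq_len]

    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_doc = doc_id[b, q_idx] == doc_id[b, kv_idx]
        return causal & same_doc

    B, S = tokens.shape
    return create_block_mask(mask_mod, B, None, S, S)
```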


```
-> % LOG_RANK=6 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --training.steps=100 --parallelism.pipeline_parallel_degree=4

+ NGPU=8
+ export LOG_RANK=6
+ LOG_RANK=6
+ CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ PYTORCH_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 6 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml --training.steps=100 --parallelism.pipeline_parallel_degree=4
W0821 10:14:43.689000 1062175 torch/distributed/run.py:803]
W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] *****************************************
W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] *****************************************
[rank6]:[titan] 2025-08-21 10:14:50,681 - root - INFO - Starting job: DeepSeek-V3 16B model training
[rank6]:[titan] 2025-08-21 10:14:53,248 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank6]:[titan] 2025-08-21 10:14:53,250 - root - INFO - Building 2-D device mesh with ['pp', 'dp_shard'], [4, 2]
[rank6]:[titan] 2025-08-21 10:14:53,265 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank6]:[titan] 2025-08-21 10:14:56,937 - root - INFO - Loading tokenizer from tokenizer.json
[rank6]:[titan] 2025-08-21 10:14:57,076 - root - INFO - Preparing c4 dataset from allenai/c4
[rank6]:[titan] 2025-08-21 10:15:00,743 - root - INFO - Building deepseek_v3 16B with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=4096, vocab_size=102400, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, norm_eps=1e-05, moe_args=MoEArgs(num_experts=64, num_shared_experts=2, score_func='softmax', route_norm=True, route_scale=1.0, score_before_experts=False, top_k=6, use_grouped_mm=True, load_balance_coeff=0.001), n_expert_groups=1, n_limited_groups=1, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=True, attn_mask_type='block_causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7)
[rank6]:[titan] 2025-08-21 10:15:00,966 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory
[rank6]:[titan] 2025-08-21 10:15:01,008 - root - INFO - Total parameter count: dense 858,385,920, sparse 14,848,098,304, active 2,661,150,208
[rank6]:[titan] 2025-08-21 10:15:01,008 - root - INFO - Model deepseek_v3 16B size: 15,706,484,224 total parameters
[rank6]:Stage 3: Modules to keep: {'layers.14', 'layers.13', 'layers.11', 'layers.12'}
[rank6]:Stage 7: Modules to keep: {'output', 'norm', 'layers.26', 'layers.25'}
[rank6]:[titan] 2025-08-21 10:15:01,029 - root - INFO - PP rank 3 is building stage_idx 3 with modules ['layers.11', 'layers.12', 'layers.13', 'layers.14']
[rank6]:[titan] 2025-08-21 10:15:01,048 - root - INFO - PP rank 3 is building stage_idx 7 with modules ['layers.25', 'layers.26', 'norm', 'output']
[rank6]:[titan] 2025-08-21 10:15:01,048 - root - INFO - Applied full activation checkpointing to the model
[rank6]:[titan] 2025-08-21 10:15:01,072 - root - INFO - Applied FSDP to the model

[rank6]:[titan] 2025-08-21 10:15:01,072 - root - INFO - Applied full activation checkpointing to the model
[rank6]:[titan] 2025-08-21 10:15:01,080 - root - INFO - Applied FSDP to the model
[rank6]:[titan] 2025-08-21 10:15:01,080 - root - INFO - Using pipeline schedule Interleaved1F1B with 8 microbatches and 8 stages.
[rank6]:[titan] 2025-08-21 10:15:01,488 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank6]:[titan] 2025-08-21 10:15:01,488 - root - INFO - CUDA memory usage for model: 6.94GiB(7.31%)
[rank6]:[titan] 2025-08-21 10:15:01,489 - root - WARNING - Warmup steps (200) exceed total training steps (100). Adjusting warmup steps to 100.
[rank6]:[titan] 2025-08-21 10:15:01,489 - root - WARNING - Warmup (100) + decay (80) steps exceed total
training steps (100). Adjusting decay steps to 0.
[rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Mixed precision training is handled by fully_shard
[rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Trainer is initialized with local batch size 8,
global batch size 16, gradient accumulation steps 1, sequence length 4096, total steps 100 (warmup 200)
[rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Training starts at step 1
[rank6]:[rank6]:[W821 10:15:10.781655306 ProcessGroupNCCL.cpp:3993] Warning: An unbatched P2P op (send/recv) was called on this ProcessGroup with size 4.  In lazy initialization mode, this will result in a new 2-rank NCCL communicator to be created. (function operator())
[rank6]:NCCL version 2.27.5+cuda12.6
[rank6]:[rank6]:[W821 10:15:16.977607954 ProcessGroupNCCL.cpp:3993] Warning: An unbatched P2P op (send/recv) was called on this ProcessGroup with size 4.  In lazy initialization mode, this will result in a new 2-rank NCCL communicator to be created. (function operator())
[rank6]:/data/users/chienchin/mywork/pytorch/torch/__init__.py:1539: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 =
True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated
after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /data/users/chienchin/mywork/pytorch/aten/src/ATen/Context.cpp:80.)
[rank6]:  return _C._get_float32_matmul_precision()
[rank6]:[titan] 2025-08-21 10:15:28,674 - root - INFO - step:  1  loss: 12.0194  grad_norm:  1.8958  memory: 53.94GiB(56.78%)  tps: 296  tflops: 5.16  mfu: 0.52%
[rank6]:[titan] 2025-08-21 10:15:28,674 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
no^[^[[rank6]:[titan] 2025-08-21 10:15:43,154 - root - INFO - step: 10  loss: 10.3629  grad_norm:  3.0762  memory: 67.11GiB(70.64%)  tps: 5,092  tflops: 88.73  mfu: 8.97%
[rank6]:[titan] 2025-08-21 10:15:59,017 - root - INFO - step: 20  loss:  8.9238  grad_norm:  2.5020  memory: 67.11GiB(70.64%)  tps: 5,165  tflops: 90.00  mfu: 9.10%
[rank6]:[titan] 2025-08-21 10:16:15,051 - root - INFO - step: 30  loss:  7.8167  grad_norm:  1.7460  memory: 67.11GiB(70.64%)  tps: 5,109  tflops: 89.04  mfu: 9.00%
[rank6]:[titan] 2025-08-21 10:16:31,989 - root - INFO - step: 40  loss:  7.1761  grad_norm:  1.1432  memory: 67.11GiB(70.64%)  tps: 4,837  tflops: 84.29  mfu: 8.52%
[rank6]:[titan] 2025-08-21 10:16:48,455 - root - INFO - step: 50  loss:  6.7850  grad_norm:  1.4950  memory: 67.11GiB(70.64%)  tps: 4,975  tflops: 86.70  mfu: 8.77%
[rank6]:[titan] 2025-08-21 10:17:04,602 - root - INFO - step: 60  loss:  6.8310  grad_norm:  1.2972  memory: 67.11GiB(70.64%)  tps: 5,074  tflops: 88.42  mfu: 8.94%
[rank6]:[titan] 2025-08-21 10:17:22,231 - root - INFO - step: 70  loss:  6.6627  grad_norm:  1.1630  memory: 67.11GiB(70.64%)  tps: 4,647  tflops: 80.98  mfu: 8.19%
[rank6]:[titan] 2025-08-21 10:17:41,358 - root - INFO - step: 80  loss:  6.3542  grad_norm:  0.8215  memory: 67.11GiB(70.64%)  tps: 4,283  tflops: 74.64  mfu: 7.55%
[rank6]:[titan] 2025-08-21 10:17:58,336 - root - INFO - step: 90  loss:  6.4442  grad_norm:  1.2542  memory: 67.11GiB(70.64%)  tps: 4,825  tflops: 84.09  mfu: 8.50%
[rank6]:[titan] 2025-08-21 10:18:12,542 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds
[rank6]:[titan] 2025-08-21 10:18:14,566 - root - INFO - step: 100  loss:  6.7519  grad_norm:  1.3966  memory: 67.11GiB(70.64%)  tps: 5,048  tflops: 87.97  mfu: 8.89%
[rank6]:[titan] 2025-08-21 10:18:14,566 - root - INFO - Training completed
[rank6]:[titan] 2025-08-21 10:18:17,159 - root - INFO - Process group destroyed
```
Currently, the only available SDPA backend for DeepSeekV3 is the
efficient attention kernel. For FlashAttentionV2 (what the current SDPA
supports), the V embedding dimension must be the same as Q and K. For
cuDNN attention, it complains that the head dimension is too large.

The reason for defaulting the attention to SDPA in TorchTitan is that
FlexCP is not yet ready. However, the combination of SDPA + CP +
DeepSeekV3 is also not functional.

This PR updates all DeepSeekV3 configurations to use FlexAttention,
which significantly improves the overall performance. **Document masking
also contributes to MFU improvement, but the majority is from
FlexAttention itself**.

```
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --training.steps=100 --parallelism.expert_parallel_degree=8
```

SDPA:

```
[rank0]:[titan] 2025-08-20 18:28:42,047 - root - INFO - Trainer is initialized with local batch size 8, global batch size 64, gradient accumulation steps 1, sequence length 4096, total steps 100 (warmup 200)
[rank0]:[titan] 2025-08-20 18:28:42,047 - root - INFO - Training starts at step 1
[rank0]:[titan] 2025-08-20 18:29:04,053 - root - INFO - step:  1  loss: 12.0401  grad_norm:  1.7464  memory: 63.55GiB(66.89%)  tps: 1,416  tflops: 24.67  mfu: 2.49%
[rank0]:[titan] 2025-08-20 18:29:04,053 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-08-20 18:29:46,138 - root - INFO - step: 10  loss: 10.3087  grad_norm:  3.1896  memory: 78.14GiB(82.25%)  tps: 7,008  tflops: 122.12  mfu: 12.35%
[rank0]:[titan] 2025-08-20 18:30:33,628 - root - INFO - step: 20  loss:  8.7601  grad_norm:  2.5195  memory: 78.14GiB(82.25%)  tps: 6,900  tflops: 120.24  mfu: 12.16%
[rank0]:[titan] 2025-08-20 18:31:22,497 - root - INFO - step: 30  loss:  7.7450  grad_norm:  1.9296  memory: 78.14GiB(82.25%)  tps: 6,705  tflops: 116.85  mfu: 11.82%
[rank0]:[titan] 2025-08-20 18:32:19,709 - root - INFO - step: 40  loss:  6.9795  grad_norm:  0.6893  memory: 78.14GiB(82.25%)  tps: 5,728  tflops: 99.81  mfu: 10.09%
[rank0]:[titan] 2025-08-20 18:33:34,343 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds
[rank0]:[titan] 2025-08-20 18:33:43,863 - root - INFO - step: 50  loss:  6.8381  grad_norm:  1.1848  memory: 78.14GiB(82.25%)  tps: 3,894  tflops: 67.86  mfu: 6.86%
[rank0]:[titan] 2025-08-20 18:34:37,289 - root - INFO - step: 60  loss:  6.5727  grad_norm:  0.9871  memory: 78.14GiB(82.25%)  tps: 6,133  tflops: 106.88  mfu: 10.81%
[rank0]:[titan] 2025-08-20 18:35:27,959 - root - INFO - step: 70  loss:  6.5041  grad_norm:  1.5895  memory: 78.14GiB(82.25%)  tps: 6,467  tflops: 112.70  mfu: 11.40%
[rank0]:[titan] 2025-08-20 18:36:16,732 - root - INFO - step: 80  loss:  6.3179  grad_norm:  0.9556  memory: 78.14GiB(82.25%)  tps: 6,719  tflops: 117.08  mfu: 11.84%
[rank0]:[titan] 2025-08-20 18:37:05,604 - root - INFO - step: 90  loss:  6.2124  grad_norm:  0.8286  memory: 78.14GiB(82.25%)  tps: 6,705  tflops: 116.85  mfu: 11.81%
[rank0]:[titan] 2025-08-20 18:37:49,285 - root - INFO - [GC] Peforming periodical GC collection 0.04 seconds
[rank0]:[titan] 2025-08-20 18:37:54,361 - root - INFO - step: 100  loss:  6.2596  grad_norm:
1.5143  memory: 78.14GiB(82.25%)  tps: 6,721  tflops: 117.12  mfu: 11.84%
[rank0]:[titan] 2025-08-20 18:37:54,361 - root - INFO - Sleeping 2 seconds for other ranks to
complete
[rank0]:[titan] 2025-08-20 18:37:56,364 - root - INFO - Training completed
[rank0]:[titan] 2025-08-20 18:37:57,535 - root - INFO - Process group destroyed
```

FlexAttention (now)
```
[rank0]:/data/users/chienchin/mywork/pytorch/torch/__init__.py:1539: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /data/users/chienchin/mywork/pytorch/aten/src/ATen/Context.cpp:80.)
[rank0]:  return _C._get_float32_matmul_precision()
[rank0]:[titan] 2025-08-20 22:16:59,699 - root - INFO - step:  1  loss: 11.9984  grad_norm:  1.7288  memory: 63.55GiB(66.89%)  tps: 727  tflops: 12.67  mfu: 1.28%
[rank0]:[titan] 2025-08-20 22:16:59,699 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-08-20 22:17:32,228 - root - INFO - step: 10  loss: 10.3101  grad_norm:  2.9111  memory: 78.14GiB(82.25%)  tps: 9,066  tflops: 157.99  mfu: 15.97%
[rank0]:[titan] 2025-08-20 22:18:08,957 - root - INFO - step: 20  loss:  8.7431  grad_norm:  2.5391  memory: 78.14GiB(82.25%)  tps: 8,922  tflops: 155.47  mfu: 15.72%
[rank0]:[titan] 2025-08-20 22:18:46,981 - root - INFO - step: 30  loss:  7.7133  grad_norm:  1.7743  memory: 78.14GiB(82.25%)  tps: 8,618  tflops: 150.18  mfu: 15.19%
[rank0]:[titan] 2025-08-20 22:19:26,672 - root - INFO - step: 40  loss:  6.9643  grad_norm:  0.7227  memory: 78.14GiB(82.25%)  tps: 8,256  tflops: 143.88  mfu: 14.55%
[rank0]:[titan] 2025-08-20 22:20:01,975 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds
[rank0]:[titan] 2025-08-20 22:20:06,015 - root - INFO - step: 50  loss:  6.8046  grad_norm:  1.0556  memory: 78.14GiB(82.25%)  tps: 8,329  tflops: 145.15  mfu: 14.68%
[rank0]:[titan] 2025-08-20 22:20:45,784 - root - INFO - step: 60  loss:  6.5364  grad_norm:  1.7141  memory: 78.14GiB(82.25%)  tps: 8,240  tflops: 143.59  mfu: 14.52%
[rank0]:[titan] 2025-08-20 22:21:25,078 - root - INFO - step: 70  loss:  6.4709  grad_norm:  1.2385  memory: 78.14GiB(82.25%)  tps: 8,340  tflops: 145.33  mfu: 14.69%
[rank0]:[titan] 2025-08-20 22:22:03,088 - root - INFO - step: 80  loss:  6.2786  grad_norm:  2.2534  memory: 78.14GiB(82.25%)  tps: 8,621  tflops: 150.24  mfu: 15.19%
[rank0]:[titan] 2025-08-20 22:22:41,254 - root - INFO - step: 90  loss:  6.1441  grad_norm:  0.6878  memory: 78.14GiB(82.25%)  tps: 8,586  tflops: 149.62  mfu: 15.13%
[rank0]:[titan] 2025-08-20 22:23:15,059 - root - INFO - [GC] Peforming periodical GC collection 0.05 seconds
[rank0]:[titan] 2025-08-20 22:23:19,063 - root - INFO - step: 100  loss:  6.1348  grad_norm:
1.2875  memory: 78.14GiB(82.25%)  tps: 8,667  tflops: 151.04  mfu: 15.27%
[rank0]:[titan] 2025-08-20 22:23:19,064 - root - INFO - Sleeping 2 seconds for other ranks to
complete
[rank0]:[titan] 2025-08-20 22:23:21,065 - root - INFO - Training completed
[rank0]:[titan] 2025-08-20 22:23:22,436 - root - INFO - Process group destroyed
```
This PR addresses duplicated code related to enabling async TP across
different parts of the codebase. It introduces a new API,
`maybe_enable_async_tp()`, which centralizes the enablement logic and is
reused consistently in all models.

Note that while this PR fixes one async TP bug in TorchTitan, it does
not fully resolve #1613, as
there appear to be additional bugs in PyTorch's async TP implementation.
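
For context, enabling async TP in PyTorch boils down to a couple of global switches; a rough sketch of what such a helper typically does (not necessarily the exact torchtitan implementation):

```
import torch
from torch.distributed._symmetric_memory import enable_symm_mem_for_group


def maybe_enable_async_tp(job_config, tp_mesh):
    # Hypothetical sketch: gate on the config flag, then flip the inductor
    # micro-pipelining knob and enable symmetric memory for the TP group.
    if not job_config.parallelism.enable_async_tensor_parallel:
        return
    torch._inductor.config._micro_pipeline_tp = True
    enable_symm_mem_for_group(tp_mesh.get_group().group_name)
```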
This PR makes several miscellaneous refactors to clean up `torchtitan`
before release.

Changes:
- Sets each of `model_parts` to eval mode in `Validator` class to
support PP (Bug fix)
- Refactor `checkpoint.enable_checkpoint -> checkpoint.enable`
(Refactor)
- Refactor `validation.enabled -> validation.enable` (Refactor)
Follow-up of #1619 to fix the remaining errors.

Also fixes a TODO.
…1627)

The DataloaderStopIteration exception previously inherited from StopIteration.
According to PEP 479, raising a StopIteration subclass from inside a generator
causes a RuntimeError in Python 3.7+.

This change modifies the base class to `Exception` to ensure it can be
caught correctly by user code without triggering this behavior.
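
The PEP 479 interaction is easy to reproduce; a minimal sketch of the fixed behavior (class and generator names are illustrative):

```
class DataloaderStopIteration(Exception):  # previously a StopIteration subclass
    """Raised when the dataloader is exhausted mid-epoch."""


def batches():
    # If this exception still subclassed StopIteration, PEP 479 would turn it
    # into "RuntimeError: generator raised StopIteration" for the caller.
    yield 1
    raise DataloaderStopIteration


try:
    list(batches())
except DataloaderStopIteration:
    print("caught cleanly")
```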

Fixes ISSUE #1626
…1633)

As titled. Only enable weight tying for the smaller model.
self.score_function should be self.score_func
Add some basic documentation on how to use Titan with TorchFt for
DiLoCo.

LMK if anything needs clarification @vishal9-team
In this PR, I'm adding the StateDictAdapter for Qwen3 to enable loading
HF checkpoints. We can use this script to adapt a checkpoint from HF
to the format that we can load into the torchtitan model, and vice versa.
This enables us to do a parity test against the HF implementation and
make sure our results are aligned.

---------

Co-authored-by: Hossein Kavianihamedani <[email protected]>
Summary: Allow user to configure WandB entity (aka team) and run name
through environment variables.

Differential Revision: D80499210
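
A minimal sketch of reading these from the environment (the variable names are illustrative, not necessarily the ones this diff uses):

```
import os

# Hypothetical sketch: fall back to the existing config values when the
# environment variables are not set.
wandb_entity = os.environ.get("WANDB_ENTITY") or None
wandb_run_name = os.environ.get("WANDB_RUN_NAME") or None
```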
Per the
[discussion](#1618 (comment)),
I added:
- a warning message when the validation `steps`=-1, in both the comment and
the logger
- a change of the default `steps` to reasonable values for the common setup
(world size = 8)
- infinite-loop support for the validator to avoid a hang when `steps` is
large enough to exhaust the dataset (see the sketch below)
- the same fix for flux
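
A minimal sketch of the infinite-loop behavior referenced above (illustrative, not the exact validator code):

```
def iter_forever(dataloader):
    # Hypothetical helper: restart iteration whenever the dataset is
    # exhausted, so a large validation `steps` re-loops instead of hanging.
    while True:
        yield from dataloader
```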


## Test
- 8 GPUs with `steps=-1`: hangs around step 1270
- 8 GPUs with `steps=1200`: good
- 8 GPUs with `steps=1500`: `infinite` automatically set to true.
Exhausts the dataset and re-iterates, but won't hang
- Flux: `steps=-1` doesn't hang
- Flux: `steps=60` doesn't hang; re-loops the dataset.

Full thread: #1618

cc @ebsmothers @tianyu-l
Follow up for #1634
Updated the warning message according to [Jiani's
suggestion](#1634 (review))
cc @wwwjn