
Training Extremely Slow on Qwen3.5-35B-A3B + 8×B300 (280s/step), py-spy Shows FlashAttention-4 CUTLASS JIT Compilation During Backward #2635

Description

@yszhli

Environment
Hardware
1 node
8 × NVIDIA B300 GPUs
Software
Python: 3.12.13
ms-swift (latest)
Megatron backend
FlashAttention installed from source
pip list | grep flash

flash_attn 2.8.4
flash-attn-4 4.0.0b17.dev0+gb02b07e.d20260608
flash-linear-attention 0.5.0
flashinfer-cubin 0.6.8.post1
flashinfer-python 0.6.8.post1

Training Command
--model /mnt/wfs2139/model/Qwen3.5-35B-A3B
--cached_dataset /opt/wfs2139/data/vlm/cache/train
--use_distributed_optimizer true
--save_safetensors true
--load_from_cache_file true
--tensor_model_parallel_size 1
--pipeline_model_parallel_size 1
--expert_model_parallel_size 8
--freeze_llm false
--moe_permute_fusion true
--moe_grouped_gemm true
--moe_shared_expert_overlap true
--moe_aux_loss_coeff 1e-6
--micro_batch_size 4
--global_batch_size 32
--recompute_granularity full
--recompute_method uniform
--recompute_num_layers 1
--num_train_epochs 1
--finetune true
--cross_entropy_loss_fusion true
--lr 1e-5
--lr_warmup_fraction 0.05
--min_lr 5e-8
--output_dir /mnt/wfs2139/zli/megatron_output/Qwen3.5-35B-A3B-605
--eval_steps 500000
--save_steps 1500
--max_length 32000
--dataloader_num_workers 1
--dataset_num_proc 32
--attention_backend flash
--no_save_optim true
--no_save_rng true
--sequence_parallel true
--padding_free true
--packing true
--dataloader_pin_memory false
--gradient_accumulation_fusion false
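
For reference, the batch arithmetic implied by these flags can be checked with a small sketch. It assumes the usual Megatron convention that the data-parallel size is the world size divided by TP × PP × CP, with expert parallelism carved out of the data-parallel ranks rather than reducing them. The helper name `grad_accum_steps` is hypothetical and not part of ms-swift or Megatron.

```python
def grad_accum_steps(world_size, tp, pp, micro_batch_size, global_batch_size, cp=1):
    """Return (data_parallel_size, accumulation_steps) for a Megatron-style layout.

    Assumption: expert parallelism does not shrink the data-parallel size,
    so only TP, PP and CP divide the world size.
    """
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError("world_size must be divisible by tp * pp * cp")
    dp = world_size // model_parallel

    samples_per_pass = micro_batch_size * dp
    if global_batch_size % samples_per_pass != 0:
        raise ValueError("global_batch_size must be divisible by micro_batch_size * dp")
    return dp, global_batch_size // samples_per_pass


# Values taken from the command above: 8 GPUs, TP=1, PP=1, micro=4, global=32.
dp, accum = grad_accum_steps(world_size=8, tp=1, pp=1,
                             micro_batch_size=4, global_batch_size=32)
print(dp, accum)
```

Under that assumption each rank runs a single forward/backward pass per step, so the 280 s step time is unlikely to come from gradient accumulation alone.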

Problem

Training is extremely slow.

Observed step time
~280 seconds / step

This is much slower than expected for:
Qwen3.5-35B-A3B
EP=8
TP=1
PP=1
8×B300

py-spy Observation

I attached py-spy to one training process.
The main thread is blocked in backward:
custom_backward
└── backward_step
    └── forward_backward_no_pipelining
        └── train_step
However, one other Python thread is actively consuming CPU and holding the GIL.
Its stack trace repeatedly enters:
flash_attn/cute/interface.py
_flash_attn_bwd

cutlass/base_dsl/compiler.py
_compile

cutlass/base_dsl/dsl.py
generate_mlir

flash_attn/cute/sm100_hd256_2cta_fmha_backward_dkdvkernel.py
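
One way to quantify "repeatedly" is to take several stack dumps over time and count how many show the FlashAttention backward path inside the CUTLASS DSL compiler. The sketch below is only an illustration: it does plain substring matching on dump text using the file paths quoted above, so it makes no assumption about the exact py-spy output format. The function name `compile_fraction` and the sample strings are hypothetical.

```python
# File paths copied from the stack trace in this report.
ATTN_MARKER = "flash_attn/cute/interface.py"
COMPILE_MARKERS = (
    "cutlass/base_dsl/compiler.py",
    "cutlass/base_dsl/dsl.py",
)


def compile_fraction(dumps):
    """Return the fraction of stack dumps that show flash-attn inside JIT compilation.

    `dumps` is an iterable of strings, one per captured stack dump.
    """
    dumps = list(dumps)
    if not dumps:
        return 0.0
    hits = sum(
        1
        for text in dumps
        if ATTN_MARKER in text and any(marker in text for marker in COMPILE_MARKERS)
    )
    return hits / len(dumps)


# Made-up sample dumps for illustration only.
compiling = "_flash_attn_bwd (flash_attn/cute/interface.py)\n_compile (cutlass/base_dsl/compiler.py)"
idle = "train_step (trainer.py)\nbackward_step (schedules.py)"
print(compile_fraction([compiling, idle, compiling, compiling]))
```

If the fraction stays high well after the first few steps, that would suggest the kernels are being recompiled rather than served from a cache, although this check alone does not show why.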
