fix: scatter Qwen3-VL text-model embeddings under SP for standalone LM forwards #4629
kevalmorabia97 wants to merge 1 commit
…M forwards

The Qwen3-VL text model is built with scatter_embedding_sequence_parallel=False because the outer VLM model manually scatters the merged vision+text embeddings for sequence parallelism. When the language model is run standalone (e.g. distilling only the LM, with no outer VLM), that manual scatter is bypassed: under SP the embeddings stay unscattered, the decoder runs the full sequence on every TP rank, and the output-side sequence gather then multiplies the sequence length by TP_size (doubling it at TP=2). This surfaces downstream as a TP_size x seq_length vs seq_length shape mismatch (e.g. a KD loss-mask error during language-model distillation).

Scatter the internally-embedded sequence in the main forward when no external decoder_input is provided and sequence parallelism is enabled, mirroring the existing MTP-path handling. The normal VLM forward (which passes pre-scattered decoder_input from the outer model) is unaffected.

Validated on nemo:26.06: Qwen3.5-VL language-model distillation at TP=2 + SP passes with this change (previously failed at the loss-mask step).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Light review - LGTM. The fix is minimal and correct:
Minor notes (non-blocking):
Suggested test cases
Superseded by a general fix in Megatron-LM (mcore): NVIDIA/Megatron-LM#5628. Investigation showed this is not Qwen3-VL-specific: every Megatron-Bridge VLM/omni/audio model builds its language tower with … The Qwen3-VL text-model change here only covered Qwen3-VL/3.5-VL and would double-scatter if combined with the mcore fix. The mcore …
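To make the double-scatter concern concrete, here is a toy, single-process sketch in plain Python. It is not Megatron code: the `scatter` helper and the list-based "ranks" are hypothetical stand-ins, and it only assumes that a sequence-parallel scatter hands each rank one contiguous chunk of the sequence.

```python
# Toy illustration of scattering the same sequence twice.
# `scatter` is a hypothetical stand-in, not a Megatron API.

def scatter(seq, rank, tp_size):
    """Return the contiguous chunk of `seq` owned by `rank`."""
    chunk = len(seq) // tp_size
    return seq[rank * chunk:(rank + 1) * chunk]

seq_length, tp_size = 16, 2
tokens = list(range(seq_length))

# One scatter: each rank holds seq_length / tp_size tokens.
once = [scatter(tokens, r, tp_size) for r in range(tp_size)]

# A second scatter applied on top of the first drops half of each rank's chunk.
twice = [scatter(once[r], r, tp_size) for r in range(tp_size)]

# Length seen after an output-side gather (concatenation over ranks).
gathered_once = sum(len(chunk) for chunk in once)
gathered_twice = sum(len(chunk) for chunk in twice)
print(gathered_once, gathered_twice)
```

With a single scatter the gathered length matches the original sequence; with two, tokens are silently dropped, which is why the two fixes should not be combined.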
What
Fix a sequence-parallel bug in the Qwen3-VL text model that breaks standalone language-model forwards (e.g. distilling only the LM tower of a VLM) under tensor + sequence parallelism.
Why
Qwen3VLModel builds its text model with scatter_embedding_sequence_parallel=False because the outer VLM model manually scatters the merged vision+text embeddings for sequence parallelism (modelling_qwen3_vl/model.py). When the language model is run on its own (no outer VLM, e.g. ModelOpt language-model distillation of a VLM), that manual scatter is bypassed. Under SP the embeddings therefore stay unscattered, the decoder runs the full sequence on every TP rank, and the output-side sequence-parallel gather then multiplies the sequence length by TP_size (doubling it at TP=2).
This surfaces downstream as a TP_size × seq_length vs seq_length shape mismatch. Concretely, ModelOpt VLM language-model distillation at TP=2 + SP fails in the KD loss-mask step with a length of 32 where 16 is expected (32 = TP_size (2) × seq_length (16)). Plain TP (no SP) works; the plain LM path works; only the standalone-VLM-LM + SP combination is affected.
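The shape arithmetic above can be reproduced with a small single-process simulation. This is only a sketch under stated assumptions: the function names are hypothetical, no torch or distributed runtime is involved, and the "decoder" is assumed to preserve the local sequence length on each rank.

```python
# Toy simulation of the sequence-length arithmetic under sequence parallelism.
# All names are hypothetical; this is not the Megatron implementation.

def scatter_sequence(tokens, rank, tp_size):
    """Return the contiguous sequence chunk owned by `rank`."""
    chunk = len(tokens) // tp_size
    return tokens[rank * chunk:(rank + 1) * chunk]

def gather_sequence(per_rank_outputs):
    """Concatenate every rank's output along the sequence dimension."""
    gathered = []
    for out in per_rank_outputs:
        gathered.extend(out)
    return gathered

def lm_output_length(seq_length, tp_size, scatter_embeddings):
    """Sequence length seen after the output-side gather."""
    tokens = list(range(seq_length))
    per_rank = []
    for rank in range(tp_size):
        if scatter_embeddings:
            local = scatter_sequence(tokens, rank, tp_size)
        else:
            # Bug scenario: every rank keeps the full, unscattered sequence.
            local = tokens
        # The decoder is assumed to keep the local sequence length unchanged.
        per_rank.append(local)
    return len(gather_sequence(per_rank))

buggy_len = lm_output_length(16, 2, scatter_embeddings=False)
fixed_len = lm_output_length(16, 2, scatter_embeddings=True)
print(buggy_len, fixed_len)  # prints "32 16"
```

Without the scatter, the gathered length is TP_size × seq_length; with it, the gather restores exactly seq_length, which is what a loss mask of the original length expects.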
Fix
In Qwen3VLTextModel.forward, when no external decoder_input is provided (i.e. the LM embeds internally) and sequence parallelism is enabled, scatter the internally-embedded sequence, mirroring the existing MTP-path scatter a few lines below. The normal VLM forward (which passes pre-scattered decoder_input from the outer model) is unaffected.
Testing
Validated on nvcr.io/nvidia/nemo:26.06: Qwen3.5-VL language-model distillation at TP=2 + SP passes with this change (previously failed at the loss-mask step). Single-GPU and TP-without-SP were already passing and remain so.
🤖 Generated with Claude Code
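For readers who want the control flow of the fix described under "Fix" at a glance, a minimal sketch follows. It is not the actual diff: `ToyConfig`, `toy_embed`, `toy_scatter`, and `prepare_decoder_input` are hypothetical names, and plain Python lists stand in for tensors. Only the branching condition (no external decoder_input, and sequence parallelism enabled) is taken from the description.

```python
from dataclasses import dataclass

# Hypothetical stand-ins; the real change lives in Qwen3VLTextModel.forward.

@dataclass
class ToyConfig:
    sequence_parallel: bool
    tp_size: int
    tp_rank: int

def toy_embed(input_ids):
    """Stand-in embedding layer: one single-element 'vector' per token."""
    return [[float(token)] for token in input_ids]

def toy_scatter(seq, cfg):
    """Keep only this rank's contiguous chunk of the sequence."""
    chunk = len(seq) // cfg.tp_size
    return seq[cfg.tp_rank * chunk:(cfg.tp_rank + 1) * chunk]

def prepare_decoder_input(input_ids, decoder_input, cfg):
    if decoder_input is not None:
        # VLM path: the outer model already scattered the merged embeddings,
        # so they are passed through untouched.
        return decoder_input
    # Standalone LM path: the model embeds internally.
    embedded = toy_embed(input_ids)
    if cfg.sequence_parallel:
        # The scatter that was previously skipped on this path.
        embedded = toy_scatter(embedded, cfg)
    return embedded

cfg = ToyConfig(sequence_parallel=True, tp_size=2, tp_rank=1)
standalone = prepare_decoder_input(list(range(16)), None, cfg)
pre_scattered = [[0.0]] * 8
passthrough = prepare_decoder_input(None, pre_scattered, cfg)
print(len(standalone), passthrough is pre_scattered)
```

Guarding on `decoder_input is None` is what keeps the normal VLM forward from being scattered a second time.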