Skip to content

fix: scatter Qwen3-VL text-model embeddings under SP for standalone LM forwards#4629

Closed
kevalmorabia97 wants to merge 1 commit into
mainfrom
kmorabia/qwen3vl-lm-sp-embedding-scatter
Closed

fix: scatter Qwen3-VL text-model embeddings under SP for standalone LM forwards#4629
kevalmorabia97 wants to merge 1 commit into
mainfrom
kmorabia/qwen3vl-lm-sp-embedding-scatter

Conversation

@kevalmorabia97

Copy link
Copy Markdown
Contributor

What

Fix a sequence-parallel bug in the Qwen3-VL text model that breaks standalone language-model forwards (e.g. distilling only the LM tower of a VLM) under tensor + sequence parallelism.

Why

Qwen3VLModel builds its text model with scatter_embedding_sequence_parallel=False because the outer VLM model manually scatters the merged vision+text embeddings for sequence parallelism (modelling_qwen3_vl/model.py).

When the language model is run on its own — no outer VLM, e.g. ModelOpt language-model distillation of a VLM — that manual scatter is bypassed. Under SP the embeddings therefore stay unscattered, the decoder runs the full sequence on every TP rank, and the output-side sequence-parallel gather then doubles the sequence length.

This surfaces downstream as a TP_size × seq_length vs seq_length shape mismatch. Concretely, ModelOpt VLM language-model distillation at TP=2 + SP fails in the KD loss-mask step:

RuntimeError: The size of tensor a (32) must match the size of tensor b (16) at non-singleton dimension 0

(32 = TP_size(2) × seq_length(16).) Plain TP (no SP) works; the plain LM path works; only the standalone-VLM-LM + SP combination is affected.

Fix

In Qwen3VLTextModel.forward, when no external decoder_input is provided (i.e. the LM embeds internally) and sequence parallelism is enabled, scatter the internally-embedded sequence — mirroring the existing MTP-path scatter a few lines below. The normal VLM forward (which passes pre-scattered decoder_input from the outer model) is unaffected.

if not _embeddings_provided and self.pre_process and self.config.sequence_parallel:
    decoder_input = tensor_parallel.scatter_to_sequence_parallel_region(
        decoder_input, group=self.pg_collection.tp
    )

Testing

Validated on nvcr.io/nvidia/nemo:26.06: Qwen3.5-VL language-model distillation at TP=2 + SP passes with this change (previously failed at the loss-mask step). Single-GPU and TP-without-SP were already passing and remain so.

🤖 Generated with Claude Code

…M forwards

The Qwen3-VL text model is built with scatter_embedding_sequence_parallel=False because the
outer VLM model manually scatters the merged vision+text embeddings for sequence parallelism.
When the language model is run standalone (e.g. distilling only the LM, with no outer VLM), that
manual scatter is bypassed: under SP the embeddings stay unscattered, the decoder runs the full
sequence on every TP rank, and the output-side sequence gather then doubles the sequence length.
This surfaces downstream as a TP_size x seq_length vs seq_length shape mismatch (e.g. a KD
loss-mask error during language-model distillation).

Scatter the internally-embedded sequence in the main forward when no external decoder_input is
provided and sequence parallelism is enabled, mirroring the existing MTP-path handling. The normal
VLM forward (which passes pre-scattered decoder_input from the outer model) is unaffected.

Validated on nemo:26.06: Qwen3.5-VL language-model distillation at TP=2 + SP passes with this
change (previously failed at the loss-mask step).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@claude

claude Bot commented Jul 2, 2026

Copy link
Copy Markdown
Contributor

Light review - LGTM.

The fix is minimal and correct:

  • _embeddings_provided (line 147) is captured before the _preprocess unpack reassigns decoder_input (line 157), so it faithfully reflects the original caller argument.
  • The guard "not _embeddings_provided and self.pre_process and self.config.sequence_parallel" correctly isolates the standalone-LM path: the normal VLM forward passes pre-scattered decoder_input (model.py:554-558), so _embeddings_provided is True and the new scatter is skipped.
  • The scatter mirrors the existing MTP-path scatter (lines 196-204) and uses the same self.pg_collection.tp group, consistent with the outer model.

Minor notes (non-blocking):

  • Consider a short comment noting this matches the MTP scatter below, so the two SP-scatter sites stay in sync if one is ever refactored.
  • No automated regression test accompanies the fix. An SP-only path is hard to cover cheaply, but a 2-GPU (TP=2 + SP) standalone-LM-forward shape assertion would guard against regressions if it fits the 2-GPU budget.

Suggested test cases

  • No perf tests impacted. Only models/qwen_vl/modelling_qwen3_vl/text_model.py is changed; no scripts/performance/configs/ entries were touched.
  • If a regression test is added: a TP=2 + sequence-parallel standalone Qwen3VLTextModel.forward (no external decoder_input) asserting the output hidden_states sequence length equals seq_length rather than TP_size x seq_length.

@kevalmorabia97 kevalmorabia97 requested a review from yaoyu-33 July 2, 2026 18:15
@yaoyu-33 yaoyu-33 added area:model Model implementations and HF bridge logic bug Something isn't working needs-more-tests Requires additional L0 and L1 test coverage before merge needs-review PR is ready for code review and waiting on a reviewer labels Jul 2, 2026
@kevalmorabia97

Copy link
Copy Markdown
Contributor Author

Superseded by a general fix in Megatron-LM (mcore): NVIDIA/Megatron-LM#5628.

Investigation showed this is not Qwen3-VL-specific — every Megatron-Bridge VLM/omni/audio model builds its language tower with scatter_embedding_sequence_parallel=False (the outer multimodal model scatters the merged embeddings), so a standalone-LM forward under sequence parallelism hits the same TP×seq doubling. Confirmed for both Qwen3.5-VL and Gemma3-VL.

The Qwen3-VL text-model change here only covered Qwen3-VL/3.5-VL and would double-scatter if combined with the mcore fix. The mcore GPTModel._preprocess fix (#5628) covers all such models in one place and is a no-op for standard LMs, so closing this in favor of it.

@kevalmorabia97 kevalmorabia97 deleted the kmorabia/qwen3vl-lm-sp-embedding-scatter branch July 2, 2026 19:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:model Model implementations and HF bridge logic bug Something isn't working needs-more-tests Requires additional L0 and L1 test coverage before merge needs-review PR is ready for code review and waiting on a reviewer

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants