diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index d2582a3f353..869a448eb2c 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -332,6 +332,14 @@ def _preprocess( f"input_ids shape {input_ids.shape}" ) decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + if self.config.sequence_parallel and not self.embedding.scatter_to_sequence_parallel: + # Some models build the embedding to not scatter for sequence parallelism (e.g. + # VLM language models, whose outer multimodal model scatters the merged + # vision/audio+text embeddings itself). When such a model is run standalone + # (input_ids, no external decoder_input) -- e.g. language-model-only + # distillation/PTQ -- scatter here so the decoder sees a sequence-parallel + # sharded sequence and the output-side gather does not double the sequence. + decoder_input = tensor_parallel.scatter_to_sequence_parallel_region(decoder_input) if padding_mask is not None and self.config.sequence_parallel: padding_mask = ( tensor_parallel.scatter_to_sequence_parallel_region(