
Conversation

@zyongye (Member) commented on Nov 18, 2025

Purpose

DeepSeek recently found an error in their official implementation: RoPE in the indexer should not be interleaved.
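For context, here is a minimal sketch of the two pairing conventions; this is not vLLM's actual kernel, and the function names and shapes are purely illustrative. The indexer should use the non-interleaved (NeoX / rotate-half) variant.

```python
# Illustrative sketch only; not vLLM's kernel. cos/sin have shape [..., rope_dim // 2].
import torch

def rope_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # GPT-J style: rotate adjacent element pairs (x0, x1), (x2, x3), ...
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)

def rope_neox(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # NeoX / rotate-half style: pair element i with element i + rope_dim // 2
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
```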

Test Plan

gsm8k, 20-shot

Test Result

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|    20|exact_match|↑  |0.9568|±  |0.0056|
|     |       |strict-match    |    20|exact_match|↑  |0.9553|±  |0.0057|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Yongye Zhu <[email protected]>
@mergify bot added the deepseek (Related to DeepSeek models) label on Nov 18, 2025
@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request aims to fix an issue with the Rotary Positional Embedding (RoPE) in the DeepSeek V3.2 indexer, ensuring it uses a non-interleaved (NeoX-style) implementation as per the official fix. The changes correctly introduce a separate indexer_rope_emb with the appropriate is_neox_style=True setting and pass it to the indexer.

However, I've identified a critical issue in Indexer.forward where the tensor shapes are incorrectly manipulated after applying the rotary embeddings. The use of squeeze(0) will cause runtime errors due to shape mismatches during both prefill and decoding. I have provided a detailed comment and a code suggestion to fix this bug.
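As an illustration only, an indexer-specific rotary embedding of this kind could be built with vLLM's get_rope helper by passing is_neox_style=True; the argument values below are placeholders and are not taken from this diff.

```python
# Illustrative sketch; head_size, max_position, and base are placeholder values,
# not the ones used by DeepSeek V3.2 or this PR.
from vllm.model_executor.layers.rotary_embedding import get_rope

indexer_rope_emb = get_rope(
    head_size=64,
    rotary_dim=64,
    max_position=163840,
    base=10000,
    is_neox_style=True,  # non-interleaved (rotate-half) RoPE for the indexer
)
```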

Comment on lines 848 to 850
q_pe, k_pe = rotary_emb(positions, q_pe, k_pe)
q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1)
k = torch.cat([k_pe.squeeze(0), k_nope], dim=-1)

critical

The use of squeeze(0) on q_pe and k_pe is incorrect and will lead to shape mismatches and runtime errors during both prefill (num_tokens > 1) and decoding (num_tokens == 1) phases.

  • For q_pe with shape [num_tokens, n_head, rope_dim], squeeze(0) will fail during decoding when num_tokens is 1, as it would try to concatenate a 2D tensor with a 3D tensor.
  • For k_pe with shape [num_tokens, 1, rope_dim], squeeze(0) will fail during prefill when num_tokens > 1, as it would try to concatenate a 3D tensor with a 2D tensor.

The correct approach is to not squeeze q_pe and to squeeze k_pe on dimension 1.

Suggested change
-  q_pe, k_pe = rotary_emb(positions, q_pe, k_pe)
-  q = torch.cat([q_pe.squeeze(0), q_nope], dim=-1)
-  k = torch.cat([k_pe.squeeze(0), k_nope], dim=-1)
+  q_pe, k_pe = rotary_emb(positions, q_pe, k_pe)
+  q = torch.cat([q_pe, q_nope], dim=-1)
+  k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
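A standalone check of the shape reasoning above, using made-up dimensions (the real head counts and dims come from the model config):

```python
# Made-up dimensions; shapes follow the reviewer's description above.
import torch

n_head, rope_dim, nope_dim = 16, 64, 128

for num_tokens in (1, 7):  # decode-like and prefill-like cases
    q_pe = torch.randn(num_tokens, n_head, rope_dim)
    q_nope = torch.randn(num_tokens, n_head, nope_dim)
    k_pe = torch.randn(num_tokens, 1, rope_dim)
    k_nope = torch.randn(num_tokens, nope_dim)

    q = torch.cat([q_pe, q_nope], dim=-1)             # no squeeze needed on q_pe
    k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)  # drop only the singleton head dim

    assert q.shape == (num_tokens, n_head, rope_dim + nope_dim)
    assert k.shape == (num_tokens, rope_dim + nope_dim)
```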

@chatgpt-codex-connector bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 24 to 28
q_b_proj: torch.nn.Module | None
q_proj: torch.nn.Module | None
indexer: torch.nn.Module | None
indexer_rotary_emb: torch.nn.Module
is_sparse: bool


P1: Provide indexer_rotary_emb for all MLAModules callers

MLAModules now requires a non-default indexer_rotary_emb, but other MLA users still instantiate MLAModules without that argument (e.g., OpenPangu at vllm/model_executor/models/openpangu.py:361-378 and Kimi at vllm/model_executor/models/kimi_linear.py:252-265). Constructing those models will now raise TypeError: MLAModules.__init__() missing 1 required positional argument: 'indexer_rotary_emb' before attention ever runs, breaking non-DeepSeek MLA model initialization.
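The failure mode, reduced to a minimal hypothetical dataclass (the real MLAModules has many more fields), along with the defaulted-field fix suggested in the reply below:

```python
# Hypothetical minimal reproduction; the real MLAModules has many more fields.
from dataclasses import dataclass

import torch


@dataclass
class ModulesSketch:
    indexer: torch.nn.Module | None
    # Giving the new field a default keeps call sites that predate it working;
    # without "= None", ModulesSketch(indexer=None) raises a TypeError.
    indexer_rotary_emb: torch.nn.Module | None = None


m = ModulesSketch(indexer=None)      # old-style call site still constructs
assert m.indexer_rotary_emb is None  # falls back to the default
```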


Collaborator

@zyongye I think this is a fair concern; can we add defaults?

Member Author

You're right. Changed.
