[MLX] Qwen return last-token logits from forward; make SamplingHead operate on (B, vocab) #20604

Merged
kiymetakdemir merged 1 commit into pytorch:main from kiymetakdemir:qwen-samplinghead-2d
Jun 29, 2026

@kiymetakdemir kiymetakdemir commented Jun 29, 2026

Contributor

Summary

Qwen's MLX forward now returns last-token logits (B, vocab) instead of (B, S, vocab), so lm_head runs on one position per prefill instead of the whole sequence. SamplingHead correspondingly drops its internal last-token slice and samples directly from (B, vocab), making it model-agnostic.
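To make the new contract concrete, here is a minimal pure-Python sketch of a sampler that consumes `(B, vocab)` logits and returns one token per batch row. It is not the real `SamplingHead`: the function name `sample_tokens`, the greedy fallback when `temperature` is unset, and the temperature-scaled softmax are all assumptions made for illustration.

```python
import math
import random


def sample_tokens(logits, temperature=None, rng=None):
    """Pick one token id per batch row from (B, vocab) logits.

    Hypothetical helper, not the real SamplingHead: greedy when
    temperature is None or 0, otherwise temperature-scaled softmax sampling.
    """
    rng = rng or random.Random()
    tokens = []
    for row in logits:  # each row holds `vocab` logits for one batch item
        if not temperature:
            # Greedy: index of the largest logit.
            tokens.append(max(range(len(row)), key=row.__getitem__))
            continue
        scaled = [v / temperature for v in row]
        peak = max(scaled)  # subtract the max for numerical stability
        weights = [math.exp(v - peak) for v in scaled]
        tokens.append(rng.choices(range(len(row)), weights=weights, k=1)[0])
    return tokens  # one token id per batch row, i.e. shape (B,)


# (B=2, vocab=4) last-token logits, the shape the MLX forward now returns.
logits = [[0.1, 2.0, -1.0, 0.5], [3.0, 0.0, 0.2, -0.3]]
greedy = sample_tokens(logits)
sampled = sample_tokens(logits, temperature=0.8, rng=random.Random(0))
```

Because the input carries no sequence axis, the sampler never needs to know how many positions the model processed, which is what lets it stay model-agnostic.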

Changes

  • backends/mlx/llm/sampling.py — SamplingHead drops its internal logits[:, -1, :] slice; samples a token (B) directly from (B, vocab), removing the (B, S, vocab) path.

  • examples/models/qwen3_5_moe/export.py — _clean_forward returns lm_head(x[:, -1, :]) → (B, vocab) instead of all S positions, so lm_head runs once per prefill (chunk) instead of over the whole sequence.

  • run.py + test_chunked_prefill.py — consumers of qwen's forward output, both dropping the [0, -1, :] indexing.

  • test_ops.py + test_sample.py — op-level tests of SamplingHead/sample shape + node counts (upstream's top-k PR).

  • No C++ changes — the runtime sampler already handles 2D; chunked prefill result is identical.
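The saving in the `export.py` change can be illustrated with a toy position-wise head. This is a sketch under assumptions: `lm_head` here is a plain bias-free matrix product over Python lists, and the sizes (S=3, D=2, vocab=4) are made up; it is not the Qwen export code.

```python
def lm_head(hidden, weight):
    """Position-wise linear head: (S, D) hidden states -> (S, vocab) logits.

    `weight` is laid out as (vocab, D), one row per vocabulary entry.
    """
    return [[sum(h * w for h, w in zip(pos, row)) for row in weight]
            for pos in hidden]


# Toy sizes for a single batch item: S=3 positions, D=2, vocab=4.
hidden = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.25]]
weight = [[0.1, 0.2], [-0.3, 0.4], [0.5, 0.6], [1.0, -1.0]]

# Old path: project every position, then keep the last one.
old_last = lm_head(hidden, weight)[-1]

# New path: keep the last position, then project only that one.
new_last = lm_head(hidden[-1:], weight)[0]

# Number of D-length dot products each path performs.
dot_products_old = len(hidden) * len(weight)
dot_products_new = 1 * len(weight)
```

Both paths yield the same last-token logits, while the projection work drops from S × vocab dot products to vocab. On real hardware the two orderings may differ in the last floating-point bits if the kernel changes its reduction order; the toy code above performs identical arithmetic in both paths.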

Testing

  • MLX sample op tests pass on-device.
  • Prefill ~19% faster (1505-token prompt): ~1352 → ~1612 tok/s; decode unchanged (~75 tok/s).

@pytorch-bot

pytorch-bot Bot commented Jun 29, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20604

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 5e23a68 with merge base 0cef6de (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 29, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@kiymetakdemir kiymetakdemir force-pushed the qwen-samplinghead-2d branch from c05ca24 to 26fa891 Compare June 29, 2026 20:52
@kiymetakdemir kiymetakdemir force-pushed the qwen-samplinghead-2d branch from 26fa891 to 5e23a68 Compare June 29, 2026 21:08
@kiymetakdemir kiymetakdemir marked this pull request as ready for review June 29, 2026 21:21
@metascroy

Contributor

@claude review this code

@claude

claude Bot commented Jun 29, 2026

Claude finished @metascroy's task in 2m 10s


Code Review

  • Read PR diff against origin/main
  • Trace forward-output consumers (Python + C++)
  • Verify runtime sampler handles 2D logits

I reviewed the diff and traced every consumer of qwen's forward output. This looks correct and clean — I agree with the approval.

Correctness ✅

The central claim — "the runtime sampler already handles 2D; chunked prefill result is identical" — checks out:

  • extension/llm/sampler/util.h:47-51: sample_from_logits reads vocab_size from the last dim and only applies the (num_tokens-1)*vocab_size last-token offset when dim() == 3. A 2D (B, vocab) tensor falls straight through to sampler.sample(logits), so the host-side non-sampling path (logits_to_token in qwen35_moe_engine.cpp:93) is unaffected.
  • The on-device sampling path (qwen35_moe_engine.cpp:90-91) reads output.const_data_ptr<int64_t>()[0], which is correct whether the token is shaped (B) or (B,1).
  • Slicing x[:, -1, :] before lm_head is numerically equivalent to slicing the last position after lm_head (the head is position-wise), so prefill output is identical — only the wasted compute over non-final positions is removed. The ~19% prefill speedup is consistent with that.
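The offset behavior described for sample_from_logits can be mirrored in a few lines of Python. This is only a sketch of the logic as summarized in the review, not a port of util.h: the name `last_token_row`, the flat-list buffer, and the batch size of 1 are assumptions.

```python
def last_token_row(flat_logits, shape):
    """Return the vocab-sized slice a sampler would read from a flat buffer.

    Hypothetical mirror of the described logic; assumes batch size 1.
    """
    vocab_size = shape[-1]  # vocab is always the last dimension
    offset = 0
    if len(shape) == 3:
        # (1, num_tokens, vocab): skip ahead to the last token's logits.
        num_tokens = shape[1]
        offset = (num_tokens - 1) * vocab_size
    # A 2D (1, vocab) buffer needs no offset and is read from the start.
    return flat_logits[offset:offset + vocab_size]


# Old 3D output: 3 tokens x 4 vocab entries, flattened row-major.
from_3d = last_token_row(list(range(12)), (1, 3, 4))

# New 2D output: only the last token's 4 logits are present.
from_2d = last_token_row([8, 9, 10, 11], (1, 4))
```

Both calls select the same four values, which is why a sampler written this way needs no change when the model switches from 3D to 2D output.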

Consumers all updated consistently ✅

  • run.py:54: outputs[0][0] now indexes (B, vocab) → (vocab,). Correct.
  • test_chunked_prefill.py:56 and test_ops.py / test_sample.py shapes all moved from (1, S, vocab) → (1, vocab), and the TakeNode: 2 → 1 count drop correctly reflects the removed last-token slice.

Things I checked that are correctly left alone

  • model.py:672: Qwen35MoE.forward still returns [B, T, V] on the temperature is None (CUDA/eager) path. This is intentional: _strip_sampler_from_forward only rebinds forward for MLX/Metal export, and the CUDA runtime path handles 3D. No staleness here.
  • test_turboquant.py:93,100 — uses logits[0, -1] against the un-stripped CUDA forward (3D), so it correctly was not touched.

Minor / non-blocking

  • The SamplingHead docstring now says it wraps "a model that returns last-token logits (B, vocab)". That makes the head model-specific in its contract — fine for the qwen wrapper, just worth keeping in mind if SamplingHead is ever reused with a model whose forward returns the full sequence. Not a change request.

No bugs, no missing consumers, no C++ changes needed. LGTM. 👍
· branch qwen-samplinghead-2d

@kiymetakdemir kiymetakdemir merged commit 6386cef into pytorch:main Jun 29, 2026
213 checks passed
@kiymetakdemir kiymetakdemir deleted the qwen-samplinghead-2d branch June 29, 2026 23:29