[MLX] Qwen return last-token logits from forward; make SamplingHead operate on (B, vocab)#20604
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20604
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 5e23a68 with merge base 0cef6de ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
c05ca24 to
26fa891
Compare
26fa891 to
5e23a68
Compare
|
@claude review this code |
|
Claude finished @metascroy's task in 2m 10s —— View job Code Review
I reviewed the diff and traced every consumer of qwen's Correctness ✅The central claim — "the runtime sampler already handles 2D; chunked prefill result is identical" — checks out:
Consumers all updated consistently ✅
Things I checked that are correctly left alone
Minor / non-blocking
No bugs, no missing consumers, no C++ changes needed. LGTM. 👍 |
Summary
Qwen's MLX forward now returns last-token logits (B, vocab) instead of (B, S, vocab), so lm_head runs on one position per prefill instead of the whole sequence. SamplingHead correspondingly drops its internal last-token slice and samples directly from (B, vocab), making it model-agnostic.
Changes
backends/mlx/llm/sampling.py — SamplingHead drops its internal logits[:, -1, :] slice; samples a token (B) directly from (B, vocab), removing the (B, S, vocab) path.
examples/models/qwen3_5_moe/export.py — _clean_forward returns lm_head(x[:, -1, :]) → (B, vocab) instead of all S positions, so lm_head runs once per prefill (chunk) instead of over the whole sequence.
run.py + test_chunked_prefill.py — consumers of qwen's forward output, both dropping the [0, -1, :] indexing.
test_ops.py + test_sample.py — op-level tests of SamplingHead/sample shape + node counts (upstream's top-k PR).
No C++ changes — the runtime sampler already handles 2D; chunked prefill result is identical.
Testing