[MLX] make SamplingHead directly exportable; drop sampler wrapper; wire runtime top-k#20612

Merged
kiymetakdemir merged 2 commits into pytorch:main from kiymetakdemir:qwen-sampling-positional
Jun 30, 2026
@kiymetakdemir

Contributor

Summary

SamplingHead.forward now takes the sampling params as trailing positional args (temperature, top_k, top_p, seed), so torch.export drives it directly and the per-model _MLXSampleWrapper is removed. As part of going positional, top_k becomes a runtime input threaded through export, run.py, and the C++ engine; qwen now supports per-request top-k, and the op-level "not implemented" rejections are dropped.

Changes

  • backends/mlx/llm/sampling.py — SamplingHead.forward(self, *args) unpacks *model_args, temperature, top_k, top_p, seed; samples (B) from (B, vocab). No wrapper needed (the params are positional, so it's directly exportable).
  • examples/models/qwen3_5_moe/export.py — delete _MLXSampleWrapper; model = SamplingHead(model); top_k added to example_args; dynamic_shapes nested to mirror the single *args parameter.
  • examples/models/qwen3_5_moe/run.py — _sampling_scalars feeds top_k (<=0 → keep-all maxint); --top-k flag; the non-sample guard rejects top_k/top_p/seed.
  • backends/mlx/test/test_ops.py, test_sample.py — fixtures call SamplingHead positionally, baking _KEEP_ALL_TOP_K/_TOP_P_OFF constants for params they don't expose as runtime inputs.
  • examples/models/qwen3_5_moe/main.cpp — --top_k flag → SamplingConfig.top_k.
  • examples/models/qwen3_5_moe/qwen35_moe_engine.cpp — top_k scalar wired through prefill/decode (fed between temp and top_p, matching the new forward order); 0→INT64_MAX keep-all mapping; the two "top_k is not implemented" rejections removed.
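The positional unpacking described for SamplingHead.forward can be sketched in plain Python. This is a minimal illustration, not the code in backends/mlx/llm/sampling.py: the class name PositionalSamplingHead, the toy model, and the returned dict are hypothetical, and the real head calls the MLX sample op rather than returning its inputs.

```python
class PositionalSamplingHead:
    """Hypothetical stand-in for SamplingHead; illustrates argument order only."""

    def __init__(self, model):
        self.model = model

    def forward(self, *args):
        # The last four positionals are the sampling params, in the order the
        # PR describes: temperature, top_k, top_p, seed. Everything before
        # them is forwarded to the wrapped model unchanged.
        *model_args, temperature, top_k, top_p, seed = args
        logits = self.model(*model_args)
        # The real head would sample (B) from (B, vocab) here; this sketch
        # just returns what it unpacked so the split is visible.
        return {
            "logits": logits,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
            "seed": seed,
        }


def toy_model(tokens, input_pos):
    # Hypothetical model: echoes its inputs in place of real logits.
    return (tokens, input_pos)


head = PositionalSamplingHead(toy_model)
result = head.forward("tokens", "input_pos", 0.8, 40, 0.95, 7)
```

Because every input is a plain positional, a caller (or an exporter) can pass one flat tuple of example arguments without a wrapper that renames or reorders them.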

Testing

  • Re-exported qwen, greedy output matches a pre-refactor baseline, seeded runs are reproducible, and --top-k restricts sampling.
  • MLX sample op tests pass; test_sample.py passes (incl. top-k end-to-end).
  • C++ MLX runner builds and runs the sample path; seeded reproducible and --top_k works on-device.

@pytorch-bot

pytorch-bot Bot commented Jun 29, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20612

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 67c9b62 with merge base db5f1ba (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 29, 2026
@github-actions


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@metascroy

Contributor

@claude review this code

@claude

claude Bot commented Jun 30, 2026


Claude finished @metascroy's task in 3m 23s


Code Review

I reviewed the diff against origin/main. This is a clean, well-scoped refactor — making SamplingHead directly exportable by going fully positional, removing the per-model _MLXSampleWrapper, and threading top_k end-to-end (export → run.py → C++ engine → custom op). Call sites are consistent and the param order (temperature, top_k, top_p, seed) matches across Python, the C++ engine, and torch.ops.mlx.sample. Nicely done; already approved by @metascroy.

A few observations, none blocking:

1. top_k keep-all mapping is consistent across layers ✅

  • run.py _sampling_scalars: top_k <= 0 → INT64_MAX
  • engine set_sampling_inputs: (top_k <= 0) ? INT64_MAX : top_k
  • custom op custom_ops.py:430: k = min(int(top_k.item()), scaled.shape[-1])

So the worst case feeds INT64_MAX into min(..., vocab), which gives k = vocab, and then s_scaled[..., k-1:k] is valid. Good. Worth a one-line confirmation that the on-device MLX kernel applies the same min(top_k, vocab) clip (only the reference torch op is visible in this diff).
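The keep-all mapping and the clip to vocab size can be traced with a small stdlib-only sketch. The function names keep_all_top_k and top_k_filter are hypothetical, and the masking rule (keep every value at or above the k-th largest) is an assumption about the reference op, which is only partly visible in this diff.

```python
INT64_MAX = 2**63 - 1


def keep_all_top_k(top_k: int) -> int:
    # Mirrors the mapping described for run.py and the engine:
    # a non-positive top_k means "keep all", encoded as INT64_MAX.
    return INT64_MAX if top_k <= 0 else top_k


def top_k_filter(scaled, top_k):
    # Clip k to the vocab size, as in k = min(int(top_k.item()), vocab).
    k = min(top_k, len(scaled))
    s_scaled = sorted(scaled, reverse=True)
    # The slice k-1:k is always in range because 1 <= k <= len(scaled).
    threshold = s_scaled[k - 1 : k][0]
    # Assumed masking rule: values below the k-th largest are dropped.
    # Ties at the threshold are all kept.
    return [x if x >= threshold else float("-inf") for x in scaled]


logits = [0.1, 2.0, 1.0, -3.0]
keep_all = top_k_filter(logits, keep_all_top_k(0))
top_two = top_k_filter(logits, keep_all_top_k(2))
oversized = top_k_filter(logits, keep_all_top_k(100))
```

With this rule, top_k = 0 and top_k = 100 both leave the four logits untouched, while top_k = 2 keeps only 2.0 and 1.0.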

2. No top_k range validation in the engine
top_p gets valid_top_p() but top_k is passed through unchecked aside from the <= 0 keep-all branch. That's fine given the clip-to-vocab behavior, but a positive top_k larger than the vocab silently degrades to keep-all rather than erroring — acceptable, just noting it's intentional.

3. dynamic_shapes nesting for variadic *args (export.py:773)

dynamic_shapes = (({1: seq_dim}, {0: seq_dim}, None, None, None, None),)

The extra nesting to mirror the single *args parameter is the subtle part of this change. The PR notes a re-export with matching greedy output, so this is verified — good. The inline comment explaining why it's nested is helpful for the next reader.
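Why the extra tuple level is needed can be seen from how Python binds a variadic signature. The sketch below uses only inspect and placeholder values. It assumes torch.export matches dynamic_shapes entries against the bound parameters of forward, which is my reading of its behavior rather than something shown in this diff, and the string "seq_dim" stands in for the real Dim object.

```python
import inspect


def forward(*args):
    # Stand-in for SamplingHead.forward as seen by a caller (self already bound).
    return args


sig = inspect.signature(forward)
# Two model inputs followed by temperature, top_k, top_p, seed.
example_args = ("tokens", "input_pos", 0.8, 40, 0.95, 7)
bound = sig.bind(*example_args)

# The signature has exactly one parameter, named "args", and all six
# example inputs are bound to it as a single tuple.
param_names = list(sig.parameters)
bound_args = bound.arguments["args"]

# So the top-level spec has one entry (for "args"), and that entry is a
# tuple with one spec per example input.
seq_dim = "seq_dim"  # placeholder for the real dynamic dimension object
dynamic_shapes = (({1: seq_dim}, {0: seq_dim}, None, None, None, None),)
```

The outer tuple length matches the number of parameters (one), and the inner tuple length matches the number of example inputs (six).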

4. SamplingConfig.top_k is int32_t, engine member top_k_ is int64_t (llm_session.h:32 vs qwen35_moe_engine.cpp:558)
The widening assignment is safe and the on-device tensor is correctly Long/int64. No action needed.

5. Docstring nit (sampling.py:30)
The docstring reads "seed: scalar int tensor (seeded) or None". seed is now a required trailing positional and the other three params dropped their None/default language, but seed still documents None. That's accurate (UnseededSampleModel passes None and the op signature is Optional), just slightly inconsistent in tone with the now-required top_k/top_p docs above it. Cosmetic.

6. Out of scope, FYI: examples/models/qwen3_5_moe/test_sampler.py:14 still carries the note "top-k / top-p tests are intentionally omitted — deferred to a follow-up PR." That file targets the separate host-side sampler.py, not SamplingHead, so it's correctly untouched — but the comment is now slightly stale relative to the MLX path gaining top-k. Not part of this PR's surface.

Overall: correct, consistent, and the tests/runtime were exercised per the PR description. LGTM.
· branch qwen-sampling-positional

@kiymetakdemir kiymetakdemir merged commit 20944fd into pytorch:main Jun 30, 2026
224 of 225 checks passed
@kiymetakdemir kiymetakdemir deleted the qwen-sampling-positional branch June 30, 2026 16:41