[MLX] Qwen3.5 MoE ondevice sampling #20545

Open
kiymetakdemir wants to merge 8 commits into pytorch:main from kiymetakdemir:qwen-moe-ondevice-sampling

@kiymetakdemir

Contributor

Summary
Lets the MLX-exported Qwen3.5 MoE model sample the next token on-device instead of returning logits for host-side sampling. Sampling is opt-in at export (--sample); temperature, top_p, and seed are runtime inputs, and the runner increments the seed per token. Measured ~17% higher decode throughput.

Changes

  • export.py: the --sample flag wraps the model in SamplingHead so that forward(tokens, input_pos, temperature, top_p, seed) → int64 token, and records a use_sampling constant-method flag. Non-sample export is unchanged.
  • qwen35_moe_engine.cpp: reads use_sampling from metadata; when it is set, the engine consumes the token id directly instead of calling logits_to_token, feeds the scalar inputs in prefill/decode, and manages the per-token seed schedule. top_k is still rejected; top_p/seed are rejected on non-sample models.
  • main.cpp: --top_p / --seed flags are wired into SamplingConfig.
  • run.py: detects use_sampling, mirrors the same seed schedule for parity, and rejects top_p/seed on non-sample models (matching the C++ runner). --top-p / --seed flags added.
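The rejection rules in the list above can be sketched in a few lines of Python. This is only an illustration, not the code in run.py or the C++ engine: `SamplingConfig` here is a hypothetical dataclass, `check_sampling_config` is an invented helper name, and the sketch assumes that "rejected" means "a non-default value raises an error" (with `seed=0` taken as the default).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingConfig:
    """Hypothetical mirror of the runner's sampling options."""
    temperature: float
    top_p: float = 1.0            # flag default on both runners
    seed: int = 0                 # assumed default for this sketch
    top_k: Optional[int] = None   # assumed "unset" representation


def check_sampling_config(cfg: SamplingConfig, use_sampling: bool) -> None:
    """Raise ValueError for option combinations the runners reject."""
    # top_k is rejected regardless of how the model was exported.
    if cfg.top_k is not None:
        raise ValueError("top_k is not supported")
    # top_p / seed only make sense when the model samples on-device.
    if not use_sampling and (cfg.top_p != 1.0 or cfg.seed != 0):
        raise ValueError("top_p/seed require a model exported with --sample")
```

With this shape, a sampling export accepts `SamplingConfig(temperature=0.8, top_p=0.9, seed=3)`, while the same config against a logits-only export fails before any inference runs.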

Verification

  • Real model (C++ runner): coherent, constraint-following generation; same --seed reproducible; --temperature 0 deterministic.
  • +17% decode throughput vs host sampling; prefill unchanged.

@pytorch-bot

pytorch-bot Bot commented Jun 26, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20545

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 Pending, 1 Unrelated Failure, 2 Unclassified Failures

As of commit 2ad867d with merge base 51729bb:

UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Jun 26, 2026
@github-actions

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@kiymetakdemir kiymetakdemir marked this pull request as ready for review June 26, 2026 16:29
model.forward = types.MethodType(_clean_forward, model)


class _MLXSampleWrapper(nn.Module):

Contributor

Can we just use SamplingHead directly? What does the wrapper do extra?

Contributor Author

That's because temperature, top_p, and seed are keyword-only arguments in SamplingHead; the wrapper maps them to positional args. I could pass them through export's kwargs instead, but that would also turn dynamic_shapes into a structured args/kwargs dict rather than a tuple.

if sample:
    # forward(tokens, input_pos, temperature, top_p, seed) -> token id.
    # Scalars are static (None in dynamic_shapes); only the seq dim is dynamic.
    model = _MLXSampleWrapper(model)

Contributor

Use SamplingHead directly?

Comment thread examples/models/qwen3_5_moe/main.cpp Outdated
    "with --sample (MLX on-device sampling).");
DEFINE_int64(
    seed,
    0,

Contributor

Can the default mean that we generate the base seed by drawing a random int from the C++ library?

Contributor Author

Added this.

@metascroy

Contributor

@claude review this code

@claude

claude Bot commented Jun 26, 2026

Claude finished @metascroy's task in 2m 7s.


Code review: Qwen3.5 MoE on-device sampling

Read all four changed files plus backends/mlx/llm/sampling.py (SamplingHead). Overall this is clean and the seed schedule parity between the C++ runner and run.py checks out. Below are answers to the inline questions and a few findings.


Answering the reviewer questions

1. export.py:736 / :780 — "Can we just use SamplingHead directly?"

Not directly with the current positional-export setup. SamplingHead.forward makes temperature, top_k, top_p, seed keyword-only:

def forward(self, *args, temperature, top_k=None, top_p=1.0, seed=None, **kwargs):

The export builds a positional example_args tuple and a positional dynamic_shapes tuple:

example_args = (tokens, input_pos, temperature, top_p, seed)
dynamic_shapes = ({1: seq_dim}, {0: seq_dim}, None, None, None)

You can't feed keyword-only params positionally, so the C++/run.py callers (which pass 5 positional inputs) wouldn't line up with a bare SamplingHead. _MLXSampleWrapper exists purely to flatten the signature to positional forward(tokens, input_pos, temperature, top_p, seed). That's the only thing it adds — no extra logic.
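To make the signature mismatch concrete, here is a minimal plain-Python sketch of the adapter pattern. `KeywordOnlyHead` and `PositionalAdapter` are hypothetical stand-ins (no torch, no real SamplingHead), chosen only to show why five positional inputs cannot reach keyword-only parameters without a thin wrapper.

```python
class KeywordOnlyHead:
    """Stand-in for a module whose sampling params are keyword-only."""

    def forward(self, *args, temperature, top_p=1.0, seed=None):
        # Everything passed positionally lands in *args.
        return {"args": args, "temperature": temperature,
                "top_p": top_p, "seed": seed}


class PositionalAdapter:
    """Flattens the keyword-only signature into five positional inputs."""

    def __init__(self, head):
        self.head = head

    def forward(self, tokens, input_pos, temperature, top_p, seed):
        # Re-route the trailing positionals to the keyword-only params.
        return self.head.forward(
            tokens, input_pos,
            temperature=temperature, top_p=top_p, seed=seed,
        )


head = KeywordOnlyHead()
adapted = PositionalAdapter(head).forward("toks", "pos", 0.8, 0.9, 7)

try:
    # Without the adapter, 0.8 / 0.9 / 7 are swallowed by *args and
    # the required keyword-only `temperature` is reported as missing.
    head.forward("toks", "pos", 0.8, 0.9, 7)
    direct_call_failed = False
except TypeError:
    direct_call_failed = True
```

The adapter adds no behavior of its own; it only changes how arguments are bound, which is the role described for `_MLXSampleWrapper`.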

Two ways to drop the wrapper if you prefer:

  • Export SamplingHead directly via export(head, args=(tokens, input_pos), kwargs={"temperature":..., "top_p":..., "seed":...}, dynamic_shapes={...dict...}). This changes the exported graph's input order/naming (kwargs come after args), so the C++ inputs.push_back order and run.py would need to match — worth confirming MLX lowering preserves the expected positional order before going this route.
  • Make temperature/top_p/seed positional-or-keyword in SamplingHead itself. That's the cleanest if SamplingHead is the canonical entry point, but it touches shared MLX code used elsewhere.

Given those trade-offs, keeping the thin wrapper local to export.py is defensible. A one-line comment on _MLXSampleWrapper saying "positional adapter — SamplingHead's sampling params are keyword-only and export needs them positional" would make the intent obvious and pre-empt this exact question.

2. main.cpp:47 — "Can default mean we generate base seed by calling a random int from the C++ library?"

Yes, and it would arguably be better UX. Today --seed defaults to 0, so every run is byte-for-byte identical even at temperature>0 — surprising for "sampling." A common pattern: treat seed < 0 (or a dedicated sentinel) as "draw a fresh base seed," e.g.

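// Needs <random> for std::random_device and <cstdint> for uint64_t.
DEFINE_int64(seed, -1, "Base RNG seed; <0 = random per run. Runner increments per token.");
...
// A non-negative --seed is used as-is; a negative one draws a fresh base seed.
uint64_t base_seed = FLAGS_seed >= 0
    ? static_cast<uint64_t>(FLAGS_seed)
    : std::random_device{}();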

Keep an explicit non-negative --seed reproducible (your verification relies on that). Mirror the same sentinel in run.py for parity. One caveat: if you go random-by-default, log the chosen seed so a run can still be reproduced.
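For the run.py side, the same sentinel could look roughly like the sketch below. `resolve_base_seed` is a hypothetical helper name, and drawing 32 random bits is an assumption made here to leave ample headroom for the per-token increments; this is not the code that was actually added to the PR.

```python
import logging
import secrets

logger = logging.getLogger(__name__)


def resolve_base_seed(seed: int) -> int:
    """Return the base seed to use for a run.

    A non-negative seed is returned unchanged so explicit seeds stay
    reproducible. A negative seed is treated as "pick one for me".
    """
    if seed >= 0:
        return seed
    # 32 random bits: always non-negative, far from the int64 limit
    # even after the runner adds one per generated token.
    chosen = secrets.randbits(32)
    # Log the value so the run can be repeated with an explicit seed.
    logger.info("Using random base seed %d", chosen)
    return chosen
```

Calling `resolve_base_seed(42)` returns 42 unchanged, while `resolve_base_seed(-1)` returns a fresh non-negative value on each call and logs it.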


Findings

Seed-schedule parity (verified, no bug). C++ prefill samples token 0 with seed_ == base, then seed_ += 1 so the first decode draws with base+1; each subsequent decode increments. run.py prefill samples with seed, then decode uses seed + len(generated_tokens) (== base+1 for the first decode since generated_tokens already holds the prefill token). These line up. Chunked prefill in C++ reuses seed_ across chunks but only the final chunk's token is kept, so it matches run.py's single-pass prefill. 👍
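The parity argument can be checked with a small simulation. Both functions below are simplified models written for this comment, not excerpts from the runners: `cpp_style_seeds` models a running counter that is reused across prefill chunks, and `py_style_seeds` models the `seed + len(generated_tokens)` rule.

```python
def cpp_style_seeds(base: int, num_decode_steps: int, prefill_chunks: int = 1):
    """Seeds seen by kept tokens when a counter is bumped after each one."""
    if prefill_chunks < 1:
        raise ValueError("prefill needs at least one chunk")
    seed = base
    used = []
    # Every prefill chunk samples with the same seed; only the token
    # from the final chunk is kept, so one seed is recorded.
    for _ in range(prefill_chunks):
        prefill_seed = seed
    used.append(prefill_seed)
    seed += 1
    # Each decode step draws with the current counter, then bumps it.
    for _ in range(num_decode_steps):
        used.append(seed)
        seed += 1
    return used


def py_style_seeds(base: int, num_decode_steps: int):
    """Seeds seen when decode uses base + len(generated_tokens)."""
    generated = ["prefill-token"]   # prefill sampled with `base`
    used = [base]
    for _ in range(num_decode_steps):
        used.append(base + len(generated))
        generated.append("decode-token")
    return used
```

Under this model, `cpp_style_seeds(5, 3, prefill_chunks=4)` and `py_style_seeds(5, 3)` both yield `[5, 6, 7, 8]`, which is the alignment described above.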

No validation of top_p range. Both runners and SamplingHead document top_p ∈ (0, 1] but nothing rejects 0, negative, or >1. top_p=0 in particular could produce a degenerate nucleus on-device. Consider a cheap guard in both main.cpp and run.py (and/or prefill_tokens/decode_one) alongside the existing temperature check. Low severity.
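A guard of the kind suggested here is small. The sketch below is one possible shape for the run.py side; `validate_top_p` is a hypothetical helper name and the error message wording is illustrative.

```python
import math


def validate_top_p(top_p: float) -> float:
    """Return top_p unchanged if it lies in (0, 1], else raise ValueError."""
    # NaN and infinities fail the finiteness check; 0, negatives and
    # values above 1 fail the range check.
    if not math.isfinite(top_p) or not (0.0 < top_p <= 1.0):
        raise ValueError(f"top_p must be in (0, 1], got {top_p}")
    return top_p
```

Running the check once at argument-parsing time keeps it out of the per-token path.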

run.py use_sampling detection swallows all exceptions.

try:
    meta_method = program.load_method("use_sampling")
    ...
except Exception:
    pass

This is fine for an old .pte lacking the method, but it also silently hides a genuine load/execute failure on a sampling model (you'd fall back to logits-path and then crash on the 2-input forward with a confusing error). Not blocking, but narrowing the catch or logging at debug would aid debugging.
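One way to narrow the catch is sketched below. It is deliberately generic: `MissingMethodError` is a placeholder for whatever exception the loader raises when a method is absent (not confirmed here), and calling the loaded method with `method()` is likewise a stand-in for the elided execution step.

```python
import logging

logger = logging.getLogger(__name__)


class MissingMethodError(Exception):
    """Placeholder for the loader's "no such method" exception."""


def detect_use_sampling(program) -> bool:
    """Return True only if the program exposes a truthy use_sampling flag."""
    try:
        method = program.load_method("use_sampling")
    except MissingMethodError:
        # Older exports have no such method: fall back to the logits path.
        logger.debug("use_sampling not found; assuming logits output")
        return False
    # Failures while executing the method are NOT swallowed, so a broken
    # sampling export surfaces here rather than at the first forward call.
    return bool(method())
```

The only failure that is silenced is the expected "method missing" case; anything else propagates with its original traceback.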

Minor — noqa: C901 added to two functions. _export_mlx and run_inference/main picked up complexity suppressions. Reasonable given the branching, but if run_inference keeps growing, the sampling vs. logits split (_sampling_scalars is already factored out) could move into a small helper. Non-blocking.

Defaults consistency. main.cpp --top_p default 1.0 and run.py --top-p default 1.0 match; rejection messages match. Good parity across the two runners.


Summary

No correctness bugs found — the seed schedule, chunking interaction, and cross-runner parity are sound. The wrapper is justified by SamplingHead's keyword-only signature (worth a clarifying comment). The two suggestions worth acting on: (a) optional top_p range validation, and (b) consider a random-by-default --seed per the reviewer, logging the chosen value for reproducibility.

@digantdesai digantdesai changed the title from "Qwen3.5 MoE ondevice sampling" to "[MLX] Qwen3.5 MoE ondevice sampling" Jun 26, 2026