[MLX] Qwen3.5 MoE ondevice sampling #20545
model.forward = types.MethodType(_clean_forward, model)

class _MLXSampleWrapper(nn.Module):
Can we just use SamplingHead directly? What does the wrapper do extra?
This is because temperature, top_p, and seed are keyword-only arguments; the wrapper maps them to positional args. I could pass them through export's kwargs instead, but that would also make dynamic_shapes a structured args/kwargs dict rather than a tuple.
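To illustrate the point about keyword-only arguments, here is a minimal pure-Python sketch of the mapping. It does not use torch or the real SamplingHead; `SamplingHeadSketch` and `PositionalSampleWrapper` are hypothetical stand-ins, and the head simply echoes its inputs instead of running a model.

```python
class SamplingHeadSketch:
    """Hypothetical stand-in for a head whose sampling parameters are keyword-only."""

    def forward(self, *args, temperature, top_p=1.0, seed=None):
        # A real head would run the model and sample; this one echoes its inputs.
        return {"args": args, "temperature": temperature, "top_p": top_p, "seed": seed}


class PositionalSampleWrapper:
    """Hypothetical wrapper: maps five positional inputs onto the keyword-only signature."""

    def __init__(self, head):
        self.head = head

    def forward(self, tokens, input_pos, temperature, top_p, seed):
        return self.head.forward(
            tokens, input_pos, temperature=temperature, top_p=top_p, seed=seed
        )


head = SamplingHeadSketch()
wrapper = PositionalSampleWrapper(head)

# A flat positional tuple, in the same order as the export's example args.
example_args = ([1, 2, 3], [0, 1, 2], 0.8, 0.95, 42)
result = wrapper.forward(*example_args)
```

Calling `head.forward(*example_args)` directly raises a `TypeError`, because `*args` swallows all five values and `temperature` is never bound. That is the gap the wrapper closes when the caller can only supply a flat positional tuple.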
if sample:
    # forward(tokens, input_pos, temperature, top_p, seed) -> token id.
    # Scalars are static (None in dynamic_shapes); only the seq dim is dynamic.
    model = _MLXSampleWrapper(model)
Use SamplingHead directly?
"with --sample (MLX on-device sampling).");
DEFINE_int64(
    seed,
    0,
Could the default mean that we generate the base seed by drawing a random integer from the C++ standard library?
@claude review this code
Claude finished @metascroy's task in 2m 7s

Code review: Qwen3.5 MoE on-device sampling

Read all four changed files plus

Answering the reviewer questions

1. Not directly with the current positional-export setup.

def forward(self, *args, temperature, top_k=None, top_p=1.0, seed=None, **kwargs):

The export builds a positional

example_args = (tokens, input_pos, temperature, top_p, seed)
dynamic_shapes = ({1: seq_dim}, {0: seq_dim}, None, None, None)

You can't feed keyword-only params positionally, so the C++/

Two ways to drop the wrapper if you prefer:

Given those trade-offs, keeping the thin wrapper local to

2. Yes, and it would arguably be better UX. Today

DEFINE_int64(seed, -1, "Base RNG seed; <0 = random per run. Runner increments per token.");
...
uint64_t base_seed = FLAGS_seed >= 0
    ? static_cast<uint64_t>(FLAGS_seed)
    : std::random_device{}();

Keep an explicit non-negative

Findings

Seed-schedule parity (verified, no bug). C++ prefill samples token 0 with

No validation of

try:
    meta_method = program.load_method("use_sampling")
    ...
except Exception:
    pass

This is fine for an old

Minor: Defaults consistency.

Summary

No correctness bugs found: the seed schedule, chunking interaction, and cross-runner parity are sound. The wrapper is justified by
Summary
Lets the MLX-exported Qwen3.5 MoE model sample the next token on-device instead of returning logits for host-side sampling. Sampling is opt-in at export (--sample); temperature, top_p, and seed are runtime inputs, and the runner increments the seed per token. Measured ~17% higher decode throughput.
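As a rough host-side illustration of what sampling with temperature, top_p, and a per-token seed means, the sketch below uses only the Python standard library. It is not the MLX implementation from this PR. The function names are hypothetical, and two behaviors are assumptions of the sketch: a temperature of zero or below is treated as greedy decoding, and token i is sampled with seed `base_seed + i`.

```python
import math
import random


def sample_top_p(logits, temperature, top_p, seed):
    """Sample one token id using temperature scaling and nucleus (top-p) filtering."""
    if temperature <= 0.0:
        # Assumption of this sketch: non-positive temperature means greedy decoding.
        return max(range(len(logits)), key=lambda i: logits[i])

    # Numerically stable softmax over temperature-scaled logits.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest set of highest-probability tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # A fresh generator per call makes the draw a pure function of the seed.
    rng = random.Random(seed)
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


def decode(step_logits, temperature, top_p, base_seed):
    """Sample one token per step; the seed is incremented for each token (assumed schedule)."""
    return [
        sample_top_p(logits, temperature, top_p, base_seed + i)
        for i, logits in enumerate(step_logits)
    ]
```

Because every draw depends only on the inputs and the seed, two runs with the same base seed yield the same token sequence, which is the property a fixed seed is meant to give a runner.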
Changes
Verification