[MLX] make SamplingHead directly exportable; drop sampler wrapper; wire runtime top-k#20612
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20612
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 67c9b62 with merge base db5f1ba ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
|
@claude review this code |
|
Claude finished @metascroy's task in 3m 23s —— View job Code ReviewI reviewed the diff against A few observations, none blocking: 1.
So the worst case feeds 2. No 3. dynamic_shapes = (({1: seq_dim}, {0: seq_dim}, None, None, None, None),)The extra nesting to mirror the single 4. 5. Docstring nit ( 6. Out of scope, FYI: Overall: correct, consistent, and the tests/runtime were exercised per the PR description. LGTM. |
Summary
SamplingHead.forward now takes the sampling params as trailing positional args (temperature, top_k, top_p, seed), so torch.export drives it directly and the per-model _MLXSampleWrapper is removed. As part of going positional, top_k becomes a runtime input threaded through export, run.py, and the C++ engine; qwen now supports per-request top-k, and the op-level "not implemented" rejections are dropped.
Changes
Testing