perf(cuda): port sample_logits chain to a fused GPU sampler kernel#478
Open
geometric[bot] wants to merge 17 commits into
Open
perf(cuda): port sample_logits chain to a fused GPU sampler kernel#478geometric[bot] wants to merge 17 commits into
geometric[bot] wants to merge 17 commits into
Conversation
The draft top-K + logsumexp kernel launched only n_positions (~15) blocks,
leaving most of the GPU's SMs idle, and kept its per-thread top-K in a
data-dependent insertion index that the compiler spilled to local memory and
re-read on every vocab element. Both made the ~9 MB vocab scan run at a small
fraction of peak DRAM bandwidth.
Rework into a split-K two-pass design:
- pass 1 (draft_topk_partial) splits each position's vocab scan across many
blocks (2D grid n_positions x split) so all SMs stay busy;
- pass 2 (draft_topk_combine) merges the per-split partials per position.
Template both kernels on K (compile-time) so the top-K stays register-resident
via a branchless unrolled bubble instead of spilling, and read logits as float4
(one coalesced 16-byte transaction per 4 logits) with a scalar fallback when a
row base is not 16-byte aligned (vocab % 4 != 0), preserving any-vocab
correctness. split is auto-tuned (env override DFLASH_TOPK_SPLIT).
Measured on an RTX 3090 (n=15, vocab=151936, K=8):
- GPU kernel time: 392 us -> 36.3 us (30.6 partial + 5.75 combine), 10.8x
- full call (kernel+sync+D2H): 0.407 ms -> 0.053 ms, 7.7x
Full-call speedup is 5.9-8.4x across n in {7,15,31,63}. Output is bit-for-bit
equivalent to the CPU reference (id_mismatches=0) across K in {1,2,4,8} and
both aligned and odd vocab; compute-sanitizer memcheck clean on both paths.
Adds bench_topk.cu, a standalone microbenchmark + CPU-reference correctness
harness (not wired into the build) used to profile and A/B this change.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01T55hNb5cgyCwNYnNAE1hun
Contributor
|
@pramodith @DeanoC it seems that the PR has some conflict, would be useful for us to be able to visualize them |
Contributor
|
@davide221 sorry the merge conflict was because a parallel worktree messed up the |
Contributor
|
New results after latest commit for coalescing the cuda reads and writes. End-to-End Baseline CPU only path: GPU Path: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Ports the CPU
sample_logitschain (repetition/frequency/presence penalty, softmax(temp), top_p nucleus, multinomial draw) to CUDA so the Qwen3.5 AR decode loop can sample straight off the device logits tensorinstead of paying a full vocab-wide (151,936-float) D2H copy every token.
Penalty application, the softmax reductions, and the draw are now one fused kernel (mode-selected: greedy / sample / emit-probs) using warp-shuffle block reductions, replacing what used to be a separate penalty kernel plus a shared-memory tree reduction.
top_ktruncation stays on the CPU but we usepartial_sortinstead of a full sort, and puretop_pisGPU-assisted: the GPU computes penalties+softmax and hands back the probability vector for the CPU to truncate.
CPU-side
top_palso changed independently of the GPU work: nucleus truncation no longer does a fullO(vocab log vocab)sort. It now usesnucleus_cutoff, anstd::nth_element-based recursive bisection thatfinds the cutoff index in
O(vocab)total work regardless of where the nucleus lands.The GPU path is on by default for CUDA builds; opt out with
DFLASH_GPU_SAMPLE=0.Impact
Kernel Level
DFLASH_SAMPLER_BENCH=1 ./test_gpu_sampler_cudaEnd-to-End
Baseline
Compute on CPU
Run via:
DFLASH_SAMP=0.8,1.0,0,1.1,42 DFLASH_GPU_SAMPLE=0 python -m server.scripts.bench_llm --bench HumanEvalNote: In the above command temp=0.8, top-p=1.0, top-k=0, rep_pen=1.1,r andom_seed=42
New Kernel
Run via:
DFLASH_SAMP=0.8,1.0,0,1.1,42 DFLASH_GPU_SAMPLE=1 python -m server.scripts.bench_llm --bench HumanEvalResults: An approx ~32% increase in tok/s with similar Acceptance Lengths.
Implementation
server/src/common/sampler.cpp— addednucleus_cutoff(nth_element-based O(vocab) bisection, replaces a full sort for top_p) anddraw_from_weights(deduplicates the final weighted CDF draw); wired the two GPU dispatch points (full-GPU for greedy/temp, GPU-assisted for pure top_p) intosample_logits.server/src/common/geometric_sampler_cuda.cu/h— new/rewritten: single fusedgeometric_sample_kernel(mode-selected greedy / sample / emit-probs) doing penalty application, softmax reductions, and the multinomial draw in one launch, with warp-shuffle block reductions and a per-devicepick_block_size.server/src/qwen35/qwen35_backend.cpp— AR decode now callsgeometric_sample_logits_cudadirectly on the device logits tensor when the sampler config is GPU-supported, skipping the vocab-wide D2H copy the CPU chain otherwise needs.server/CMakeLists.txt— added theDFLASH_GPU_SAMPLERbuild option (defaultON) that compilesgeometric_sampler_cuda.cuintodflash_common, and registered thetest_gpu_sampler_cudactest target.server/test/test_gpu_sampler_cuda.cpp— new correctness test: GPU vs CPU agreement for greedy, greedy+penalties, temperature-sample distribution, and top_k/top_p CPU-fallback signaling.server/test/test_dflash.cpp— added--samp=temp,top_p,top_k,rep_pen,seed[,freq,pres]to the positional (non-daemon) harness so benchmarks can exercise the sampler chain.server/scripts/bench_llm.py— addedDFLASH_SAMP(forwards the sampler tail totest_dflash --samp=) andDFLASH_N_SAMPLE(overrides prompts-per-dataset) env vars.README.md— documented GPU sampler coverage, runtime/build flags, and the benchmark table below.Runtime Flags / Configuration
Default-on paths:
DFLASH_GPU_SAMPLE— on by default on CUDA builds; handles greedy and plain temperature/penalty sampling entirely on GPU, and assists puretop_p.Disable path:
DFLASH_GPU_SAMPLE=0— opt out at runtime; every call falls back to the CPU chain.-DDFLASH_GPU_SAMPLER=OFF(CMake option, defaultON) — dropgeometric_sampler_cuda.cufrom the build entirely.Debug/profiling-only flags:
--samp=temp,top_p,top_k,rep_pen,seed[,freq,pres](test_dflashpositional harness) — exercise the sampler chain instead of greedy decode.DFLASH_SAMP=temp,top_p,top_k,rep_pen,seed[,freq,pres]/DFLASH_N_SAMPLE=N(bench_llm.py) — forward the same sampler tail to every DFlash bench call, and override the per-dataset prompt count.top_k(with or withouttop_p) is intentionally never routed to the GPU— its CPU
partial_sortcost scales withk, not vocab, and a GPU roundtrip (kernel launch + D2H copy) measured as a net regression, not just a
non-win. This is a deliberate, measurement-driven exclusion, not a gap.
Notes
top_psupport directly on the GPU kernel (rather than GPU-assisted) isdeliberately out of scope for this PR;