Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions server/src/common/dflash_spec_decode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ bool run_dflash_spec_decode(
}

int verify_last_tok = -1;
if (!target.verify_batch(draft_tok, committed, verify_last_tok, &target_tok)) {
if (!target.verify_batch(draft_tok, committed, verify_last_tok, &target_tok,
/*capture_ssm_intermediates=*/true)) {
std::fprintf(stderr, "dflash-spec verify failed\n");
// Roll the snapshot back so we don't leak the speculative KV
// mutations into the caller's target cache.
Expand Down Expand Up @@ -210,22 +211,41 @@ bool run_dflash_spec_decode(
if (commit_n <= accept_n) bonus_tok = -1;
}

// ── Replay pass: roll back KV and re-run only the accepted tokens.
if (!target.restore_kv()) {
std::fprintf(stderr, "dflash-spec restore_kv failed\n");
return false;
}
// ── Commit accepted tokens to KV state ──────────────────────────
// Adaptive: use fast-rollback when acceptance is high enough to benefit.
constexpr int kFastRollbackThreshold = 5;
const bool use_fast_rollback =
target.supports_fast_rollback() && (accept_n >= kFastRollbackThreshold);

std::vector<int32_t> replay_tok((size_t)commit_n);
for (int i = 0; i < commit_n; i++) {
replay_tok[i] = (i < accept_n) ? draft_tok[i] : bonus_tok;
}
int replay_last_tok = -1;
if (!target.verify_batch(replay_tok, committed, replay_last_tok, nullptr)) {
std::fprintf(stderr, "dflash-spec replay failed\n");
return false;

if (use_fast_rollback) {
// Fast rollback: restore SSM from intermediates, skip replay.
// Implicit bonus: deferred to next step as draft_tok[0].
bonus_tok = -1;
commit_n = accept_n;
replay_tok.resize(commit_n);
if (!target.rollback_to(committed, accept_n)) {
std::fprintf(stderr, "dflash-spec rollback_to failed\n");
return false;
}
last_tok = target_tok[accept_n - 1];
} else {
// Legacy path: restore SSM snapshot and replay accepted + bonus tokens.
if (!target.restore_kv()) {
std::fprintf(stderr, "dflash-spec restore_kv failed\n");
return false;
}
int replay_last_tok = -1;
if (!target.verify_batch(replay_tok, committed, replay_last_tok, nullptr)) {
std::fprintf(stderr, "dflash-spec replay failed\n");
return false;
}
last_tok = replay_last_tok;
}
last_tok = replay_last_tok;

bool hit_eos = false;
int emitted = 0;
Expand Down
16 changes: 15 additions & 1 deletion server/src/common/dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ struct DFlashTarget {
virtual bool verify_batch(const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr) = 0;
std::vector<int32_t> * all_argmax = nullptr,
bool capture_ssm_intermediates = false) = 0;

// ── KV state management ─────────────────────────────────────────

Expand All @@ -42,6 +43,19 @@ struct DFlashTarget {
// Restore KV cache to the last snapshot (undo speculative forward).
virtual bool restore_kv() = 0;

// Whether fast rollback is supported — uses per-step SSM intermediate
// states captured during verify to restore recurrent state without replay.
// When true, verify_batch captures intermediates and rollback_to() works.
virtual bool supports_fast_rollback() const { return false; }

// Roll back recurrent state to position `commit_n` within the last
// verify batch (0-indexed). Uses SSM intermediate states captured during
// verify. Also truncates KV to `base_pos + commit_n`. No replay needed.
// Only valid when supports_fast_rollback() returns true.
virtual bool rollback_to(int base_pos, int commit_n) {
(void)base_pos; (void)commit_n; return false;
}

// ── Token utilities ─────────────────────────────────────────────

// Check if a token is end-of-sequence for this model.
Expand Down
4 changes: 3 additions & 1 deletion server/src/gemma4/gemma4_dflash_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ bool Gemma4DFlashTarget::verify_batch(
const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax) {
std::vector<int32_t> * all_argmax,
bool capture_ssm_intermediates) {
(void)capture_ssm_intermediates; // Gemma4 is pure-attention, no SSM state
const int n_tokens = (int)tokens.size();
if (n_tokens <= 0) return false;

Expand Down
3 changes: 2 additions & 1 deletion server/src/gemma4/gemma4_dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class Gemma4DFlashTarget : public DFlashTarget {
bool verify_batch(const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr) override;
std::vector<int32_t> * all_argmax = nullptr,
bool capture_ssm_intermediates = false) override;

// kvflash: route verify writes through the pool (slots allocated here,
// slot-space mask inside gemma4_verify_batch). Non-owning.
Expand Down
56 changes: 43 additions & 13 deletions server/src/qwen35/qwen35_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,10 +676,11 @@ DFlashTarget * Qwen35Backend::dflash_target() {
dflash_target_ = std::make_unique<Qwen35DFlashTarget>(
w_, cache_, target_backend_, sg_,
cfg_.kq_stride_pad, cfg_.fa_window);
auto * qt = static_cast<Qwen35DFlashTarget *>(dflash_target_.get());
if (kvflash_active()) {
static_cast<Qwen35DFlashTarget *>(dflash_target_.get())
->set_kvflash_pager(&kvflash_pager_);
qt->set_kvflash_pager(&kvflash_pager_);
}
qt->set_fast_rollback(true);
}
return dflash_target_.get();
}
Expand Down Expand Up @@ -1850,7 +1851,8 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
}

int verify_last_tok = -1;
if (!target->verify_batch(draft_tok, committed, verify_last_tok, &target_tok)) {
if (!target->verify_batch(draft_tok, committed, verify_last_tok, &target_tok,
/*capture_ssm_intermediates=*/true)) {
std::fprintf(stderr, "spec-decode: verify failed\n");
target->restore_kv();
step_graph_destroy(draft_sg);
Expand All @@ -1875,22 +1877,50 @@ bool Qwen35Backend::do_spec_decode(int committed, int n_gen,
if (commit_n <= accept_n) bonus_tok = -1;
}

// 6. Replay: roll back KV and re-run only accepted tokens
if (!target->restore_kv()) {
step_graph_destroy(draft_sg);
return false;
// 6. Fix state: adaptive fast-rollback vs legacy replay.
// Fast-rollback (implicit bonus, skip replay) is profitable when
// accept_n is large enough that skipping the replay saves more compute
// than the cost of deferring the bonus to the next step. Breakeven
// is around accept_n ≈ 5. Below that, legacy replay is cheaper.
constexpr int kFastRollbackThreshold = 5;
const bool use_fast_rollback =
target->supports_fast_rollback() && (accept_n >= kFastRollbackThreshold);

int replay_last_tok = -1;
if (use_fast_rollback) {
// Fast rollback: restore SSM from captured intermediates, skip replay.
// Implicit bonus: target_tok[accept_n-1] seeds next draft as draft_tok[0],
// always accepted on next step.
bonus_tok = -1;
commit_n = accept_n;
if (!target->rollback_to(committed, accept_n)) {
std::fprintf(stderr, "spec-decode: rollback_to failed\n");
step_graph_destroy(draft_sg);
return false;
}
replay_last_tok = target_tok[accept_n - 1];
} else {
// Legacy replay: restore SSM snapshot, replay accepted + bonus tokens.
if (!target->restore_kv()) {
step_graph_destroy(draft_sg);
return false;
}
std::vector<int32_t> replay_batch((size_t)commit_n);
for (int i = 0; i < commit_n; i++) {
replay_batch[i] = (i < accept_n) ? draft_tok[i] : bonus_tok;
}
if (!target->verify_batch(replay_batch, committed, replay_last_tok, nullptr)) {
std::fprintf(stderr, "spec-decode: replay failed\n");
step_graph_destroy(draft_sg);
return false;
}
}

// Build replay_tok for emitting committed tokens.
std::vector<int32_t> replay_tok((size_t)commit_n);
for (int i = 0; i < commit_n; i++) {
replay_tok[i] = (i < accept_n) ? draft_tok[i] : bonus_tok;
}
int replay_last_tok = -1;
if (!target->verify_batch(replay_tok, committed, replay_last_tok, nullptr)) {
std::fprintf(stderr, "spec-decode: replay failed\n");
step_graph_destroy(draft_sg);
return false;
}

// 7. Sync features for replayed range to mirror (needed for next draft step)
if (use_remote_draft && cache_.target_feat) {
Expand Down
79 changes: 77 additions & 2 deletions server/src/qwen35/qwen35_dflash_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

#include <cstring>

// ggml_get_to_fp32_cuda is not in any public header — it lives in
// ggml-cuda/convert.cuh. Declare here so we can link against it.
using to_fp32_cuda_t = void (*)(const void *, float *, int64_t, cudaStream_t);
extern "C++" to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);

namespace dflash::common {

Qwen35DFlashTarget::~Qwen35DFlashTarget() {
Expand All @@ -30,7 +35,8 @@ bool Qwen35DFlashTarget::verify_batch(
const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax) {
std::vector<int32_t> * all_argmax,
bool capture_ssm_intermediates) {
const int n_tokens = (int)tokens.size();
if (n_tokens <= 0) return false;

Expand All @@ -52,10 +58,12 @@ bool Qwen35DFlashTarget::verify_batch(
}
}

const bool do_capture = fast_rollback_ && capture_ssm_intermediates;

if (!build_target_step(sg_, w_, cache_, backend_,
/*kv_start=*/base_pos, n_tokens,
need_mask, /*capture=*/true,
/*capture_delta_intermediate=*/false,
/*capture_delta_intermediate=*/do_capture,
pool ? 0 : fa_window_,
/*last_token_logits_only=*/false,
kq_stride_pad_,
Expand Down Expand Up @@ -173,6 +181,73 @@ bool Qwen35DFlashTarget::restore_kv() {
return true;
}

bool Qwen35DFlashTarget::supports_fast_rollback() const {
return fast_rollback_;
}

bool Qwen35DFlashTarget::rollback_to(int base_pos, int commit_n) {
if (!fast_rollback_) return false;

const int n_delta = (int)sg_.delta_captures.size();
if (n_delta == 0) return false;

// If all tokens accepted, the SSM state after processing all q_len tokens
// is exactly what we want — no rollback needed, just fix cur_pos.
const int q_len = cache_.cur_pos - base_pos;
if (commit_n >= q_len) {
cache_.cur_pos = base_pos + commit_n;
return true;
}

const int rollback_idx = commit_n - 1; // index into per-step intermediates
cudaStream_t stream = nullptr;

for (int il = 0; il < n_delta; il++) {
const DeltaNetCapture & cap = sg_.delta_captures[il];
if (!cap.ssm_intermediate_states || !cap.conv_input) {
std::fprintf(stderr, "rollback_to: missing capture at layer %d\n", il);
return false;
}

// SSM rollback: copy intermediate[rollback_idx] → cache.ssm_state[il]
const size_t ssm_elems =
(size_t)cache_.ssm_state[il]->ne[0] *
(size_t)cache_.ssm_state[il]->ne[1] *
(size_t)cache_.ssm_state[il]->ne[2];
const size_t ssm_src_offset =
(size_t)rollback_idx * cap.ssm_intermediate_states->nb[3];
const void * ssm_src =
(const char *)cap.ssm_intermediate_states->data + ssm_src_offset;
ggml_get_to_fp32_cuda(cap.ssm_intermediate_states->type)(
ssm_src, (float *)cache_.ssm_state[il]->data,
(int64_t)ssm_elems, stream);

// Conv rollback: copy conv_input[commit_n..commit_n+K-2, :, :]
// into cache.conv_state[il].
const int K_conv = 4;
const int row_cnt = (int)cap.conv_input->ne[1];
const size_t elt = ggml_element_size(cap.conv_input);
const size_t dpitch = (K_conv - 1) * elt;
const size_t spitch = cap.conv_input->nb[1];
const size_t width = (K_conv - 1) * elt;
const void * conv_src =
(const char *)cap.conv_input->data + commit_n * elt;
cudaError_t ce = cudaMemcpy2DAsync(cache_.conv_state[il]->data, dpitch,
conv_src, spitch,
width, row_cnt,
cudaMemcpyDeviceToDevice, stream);
if (ce != cudaSuccess) {
std::fprintf(stderr, "rollback_to: cudaMemcpy2D conv il=%d: %s\n",
il, cudaGetErrorString(ce));
return false;
}
}
cudaStreamSynchronize(stream);

cache_.cur_pos = base_pos + commit_n;
return true;
}

bool Qwen35DFlashTarget::is_eos(int token) const {
return is_eos_tok(token, w_);
}
Expand Down
10 changes: 9 additions & 1 deletion server/src/qwen35/qwen35_dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ class Qwen35DFlashTarget : public DFlashTarget {
bool verify_batch(const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr) override;
std::vector<int32_t> * all_argmax = nullptr,
bool capture_ssm_intermediates = false) override;

bool snapshot_kv() override;
bool restore_kv() override;
bool supports_fast_rollback() const override;
bool rollback_to(int base_pos, int commit_n) override;

bool is_eos(int token) const override;

Expand All @@ -62,6 +65,10 @@ class Qwen35DFlashTarget : public DFlashTarget {
// Forces fa_window = 0 (logical windowing is meaningless in slot space).
void set_kvflash_pager(KvFlashPager * pager) { pager_ = pager; }

// Enable fast-rollback mode: verify will capture per-step SSM intermediates
// so rollback_to() can restore recurrent state without replay.
void set_fast_rollback(bool enabled) { fast_rollback_ = enabled; }

private:
TargetWeights & w_;
TargetCache & cache_;
Expand All @@ -70,6 +77,7 @@ class Qwen35DFlashTarget : public DFlashTarget {
int kq_stride_pad_;
int fa_window_;
KvFlashPager * pager_ = nullptr;
bool fast_rollback_ = false;

// Cached vector form of capture layer IDs (built once in constructor).
std::vector<int> capture_ids_;
Expand Down
3 changes: 2 additions & 1 deletion server/src/qwen35/qwen35_layer_split_dflash_target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ bool Qwen35LayerSplitDFlashTarget::verify_batch(
const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax) {
std::vector<int32_t> * all_argmax,
bool capture_ssm_intermediates) {
if (shards_.empty()) return false;
if (remote_target_shard_ && remote_target_shard_->active()) {
return run_qwen35_mixed_layer_split_forward(
Expand Down
3 changes: 2 additions & 1 deletion server/src/qwen35/qwen35_layer_split_dflash_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class Qwen35LayerSplitDFlashTarget : public DFlashTarget {
bool verify_batch(const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok,
std::vector<int32_t> * all_argmax = nullptr) override;
std::vector<int32_t> * all_argmax = nullptr,
bool capture_ssm_intermediates = false) override;

bool snapshot_kv() override;
bool restore_kv() override;
Expand Down
12 changes: 11 additions & 1 deletion server/src/qwen35/qwen35_target_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -959,8 +959,18 @@ static ggml_tensor * build_delta_net_block(
S_v * S_v * r_elt,
S_v * S_v * H_v * r_elt,
inter_offset);
ggml_tensor * dst;
if (n_seq_tokens == (int)cap->ssm_intermediate_states->ne[3]) {
dst = cap->ssm_intermediate_states;
} else {
dst = ggml_view_4d(ctx, cap->ssm_intermediate_states,
S_v, S_v, H_v, n_seq_tokens,
cap->ssm_intermediate_states->nb[1],
cap->ssm_intermediate_states->nb[2],
cap->ssm_intermediate_states->nb[3], 0);
}
ggml_build_forward_expand(gf,
ggml_cpy(ctx, inter_view, cap->ssm_intermediate_states));
ggml_cpy(ctx, inter_view, dst));
}
} // end of block started at `{` before `const int64_t S_v = head_v_dim;`

Expand Down