
Commit 9ca79d5

kv cache slot search improvements (ggml-org#3493)
* kv cache slot search improvements
* Use n_ctx in kv find slot for consistency
* Ensure kv cache head points to a valid slot in llama_decode_internal
* Add some comments to prevent dumb people (like me) from getting confused.
1 parent 0c731ca commit 9ca79d5

File tree: 1 file changed (+35, -6)

llama.cpp (+35, -6)
@@ -1082,6 +1082,9 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
 
+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_internal also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
 
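The comment added here documents that head doubles as the slot offset consumed by llama_decode_internal, not just a hint for the free-slot search. For orientation, here is a trimmed-down sketch of the structures this diff touches, reconstructed only from the members that appear in the hunks below (the typedefs are assumptions, and this is not the complete llama.cpp definition):

#include <cstdint>
#include <set>
#include <vector>

typedef int32_t llama_pos;      // assumed typedef, as in llama.h
typedef int32_t llama_seq_id;   // assumed typedef, as in llama.h

struct llama_kv_cell {
    llama_pos pos   = -1;              // pos < 0 marks a free cell
    llama_pos delta = 0;               // pending shift applied by seq_shift

    std::set<llama_seq_id> seq_id;     // sequences that reference this cell

    bool has_seq_id(const llama_seq_id & id) const {
        return seq_id.find(id) != seq_id.end();
    }
};

struct llama_kv_cache {
    bool has_shift = false;

    // head is both the starting point for llama_kv_cache_find_slot and the
    // slot offset used by llama_decode_internal, hence the new comment above.
    uint32_t head = 0;
    uint32_t size = 0;

    std::vector<llama_kv_cell> cells;  // one entry per KV slot
    // ... tensors and other members omitted ...
};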

@@ -1339,6 +1342,8 @@ static bool llama_kv_cache_init(
 
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
+// Note: On success, it's important that cache.head points
+// to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
         const struct llama_batch & batch) {
@@ -1354,8 +1359,8 @@ static bool llama_kv_cache_find_slot(
 
     while (true) {
         if (cache.head + n_tokens > n_ctx) {
+            n_tested += n_ctx - cache.head;
             cache.head = 0;
-            n_tested += n_ctx - cache.head;
             continue;
         }
 
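The reordering above fixes the wrap accounting: n_tested has to be bumped by the number of cells skipped at the end of the buffer before cache.head is reset, otherwise the expression n_ctx - cache.head evaluates to n_ctx and the search can report a full cache prematurely. A self-contained sketch of the whole search loop with the corrected ordering (toy types and names, not the llama.cpp implementation itself):

#include <cstdint>
#include <vector>

// Toy stand-ins, illustrative only.
struct toy_cell  { int32_t pos = -1; };                      // pos < 0 == free
struct toy_cache { uint32_t head = 0; std::vector<toy_cell> cells; };

// Look for n_tokens contiguous free cells, starting at cache.head and
// wrapping around when the end of the buffer is reached. On success, head is
// left at the first cell of the slot (the property the new comment in the
// previous hunk insists on).
static bool toy_find_slot(toy_cache & cache, uint32_t n_tokens) {
    const uint32_t n_ctx = (uint32_t) cache.cells.size();
    if (n_tokens > n_ctx) {
        return false;                        // can never fit
    }

    uint32_t n_tested = 0;

    while (true) {
        if (cache.head + n_tokens > n_ctx) {
            // Count the cells skipped at the end of the buffer *before*
            // resetting head; doing it after would always add n_ctx.
            n_tested += n_ctx - cache.head;
            cache.head = 0;
            continue;
        }

        bool found = true;
        for (uint32_t i = 0; i < n_tokens; i++) {
            if (cache.cells[cache.head + i].pos >= 0) {
                // Cell is occupied: restart the scan just past it.
                found = false;
                cache.head += i + 1;
                n_tested   += i + 1;
                break;
            }
        }

        if (found) {
            return true;                     // cache.head == first cell of the slot
        }
        if (n_tested >= n_ctx) {
            return false;                    // every cell has been examined
        }
    }
}

In llama_decode_internal the resulting head is then used directly as the offset of the new slot and is only advanced past the batch after the graph has been evaluated (last hunk below).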

@@ -1406,13 +1411,18 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
+
+    // Searching for a free slot can start here since we know it will be empty.
+    cache.head = uint32_t(c0);
 }
 
 static void llama_kv_cache_seq_rm(
         struct llama_kv_cache & cache,
                  llama_seq_id   seq_id,
                     llama_pos   p0,
                     llama_pos   p1) {
+    uint32_t new_head = cache.size;
+
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
@@ -1421,9 +1431,13 @@ static void llama_kv_cache_seq_rm(
             cache.cells[i].seq_id.erase(seq_id);
             if (cache.cells[i].seq_id.empty()) {
                 cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_cp(
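This hunk introduces the bookkeeping pattern that the seq_keep and seq_shift hunks further down repeat: cache.size acts as a sentinel meaning "nothing freed yet", the index of the first cell that becomes empty is recorded, and head is only moved when at least one cell was actually freed, so an already-allocated slot is never clobbered. A minimal sketch of the pattern on its own (toy types, illustrative names):

#include <cstdint>
#include <set>
#include <vector>

struct toy_cell  { int32_t pos = -1; std::set<int32_t> seq_id; };
struct toy_cache {
    uint32_t head = 0;
    uint32_t size = 0;               // invariant: size == cells.size()
    std::vector<toy_cell> cells;
};

// Drop sequence seq_id from cells with positions in [p0, p1) and, if that
// freed anything, point head at the first freed cell so the next slot search
// starts there.
static void toy_seq_rm(toy_cache & cache, int32_t seq_id, int32_t p0, int32_t p1) {
    uint32_t new_head = cache.size;       // sentinel: "no cell freed yet"

    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
            cache.cells[i].seq_id.erase(seq_id);
            if (cache.cells[i].seq_id.empty()) {
                cache.cells[i].pos = -1;                  // cell is free now
                if (new_head == cache.size) new_head = i; // remember the first one
            }
        }
    }

    // Leave head alone if nothing was freed; the comment added to the struct
    // above explains why it cannot be reset while a slot is allocated.
    if (new_head != cache.size) cache.head = new_head;
}

llama_kv_cache_seq_keep uses exactly this shape, while llama_kv_cache_seq_shift (two hunks below) keeps the same sentinel but falls back to head = 0 when nothing was freed, per its added comment about restarting the next search from the beginning.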
@@ -1435,6 +1449,8 @@ static void llama_kv_cache_seq_cp(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
+    cache.head = 0;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
@@ -1443,12 +1459,18 @@
 }
 
 static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    uint32_t new_head = cache.size;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
+            if (new_head == cache.size) new_head = i;
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_shift(
@@ -1457,6 +1479,8 @@ static void llama_kv_cache_seq_shift(
                     llama_pos   p0,
                     llama_pos   p1,
                     llama_pos   delta) {
+    uint32_t new_head = cache.size;
+
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
@@ -1466,12 +1490,17 @@
             if (cache.cells[i].pos < 0) {
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
+                if (new_head == cache.size) new_head = i;
             } else {
                 cache.has_shift = true;
                 cache.cells[i].delta = delta;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    // Otherwise we just start the next search from the beginning.
+    cache.head = new_head != cache.size ? new_head : 0;
 }
 
 //
@@ -4492,10 +4521,6 @@ static int llama_decode_internal(
         batch.seq_id = seq_id.data();
     }
 
-    // we always start to search for a free slot from the start of the cache
-    // TODO: better strategies can be implemented
-    kv_self.head = 0;
-
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
@@ -4581,8 +4606,12 @@
 #endif
 
     // update the kv ring buffer
-    lctx.kv_self.head += n_tokens;
     lctx.kv_self.has_shift = false;
+    lctx.kv_self.head += n_tokens;
+    // Ensure kv cache head points to a valid index.
+    if (lctx.kv_self.head >= lctx.kv_self.size) {
+        lctx.kv_self.head = 0;
+    }
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
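The last two hunks are the decode-side half of the change: the unconditional kv_self.head = 0 before the slot search is gone, so the search resumes from wherever the previous batch or cache operation left head, and once the batch has been processed head is advanced past the freshly written cells and wrapped so it stays a valid index. For example, in a 512-cell cache a slot found at head 504 for an 8-token batch leaves head at 512 after the advance, which the new check wraps back to 0. A compact sketch of that post-decode update, with toy names matching the earlier sketches:

#include <cstdint>

struct toy_kv_self { uint32_t head = 0; uint32_t size = 0; };  // illustrative

// Advance the ring-buffer head past the n_tokens cells just written and keep
// it a valid index, mirroring the end of llama_decode_internal above.
static void toy_advance_head(toy_kv_self & kv, uint32_t n_tokens) {
    kv.head += n_tokens;           // next slot search starts right after this batch

    // In the find-slot sketch earlier, a successful search guarantees
    // head + n_tokens <= size, so head == size is the only wrap case here;
    // the >= comparison keeps the check defensive.
    if (kv.head >= kv.size) {
        kv.head = 0;
    }
}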
