@@ -1082,6 +1082,9 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
 
+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_internal also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
 
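The note added above matters because head does double duty: llama_kv_cache_find_slot leaves it at the first cell of the slot it found, and llama_decode_internal writes n_tokens cells starting at that same index before advancing it. A minimal standalone sketch of that interaction (not the actual llama.cpp code; the struct and names are illustrative):

    #include <cstdint>
    #include <cstdio>

    // Toy stand-in for llama_kv_cache; only the head bookkeeping is kept.
    struct kv_cache_sketch {
        uint32_t head = 0;  // first cell of the most recently found slot
        uint32_t size = 16; // total number of cells (toy value)
    };

    int main() {
        kv_cache_sketch cache;
        const uint32_t n_tokens = 4;
        // find_slot would leave cache.head at the start of a free slot...
        const uint32_t write_offset = cache.head;
        // ...and the decode step then writes K/V data at exactly that
        // offset, which is why head must not move between the two steps.
        printf("writing %u cells at offset %u\n", n_tokens, write_offset);
        cache.head += n_tokens;                       // advance past the slot
        if (cache.head >= cache.size) cache.head = 0; // wrap, as in the diff
        return 0;
    }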
@@ -1339,6 +1342,8 @@ static bool llama_kv_cache_init(
 
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
+// Note: On success, it's important that cache.head points
+// to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
         const struct llama_batch & batch) {
@@ -1354,8 +1359,8 @@ static bool llama_kv_cache_find_slot(
 
     while (true) {
         if (cache.head + n_tokens > n_ctx) {
+            n_tested += n_ctx - cache.head;
             cache.head = 0;
-            n_tested += n_ctx - cache.head;
             continue;
         }
 
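The swap above fixes the wrap-around accounting: the skipped distance must be measured against the old head before the reset, otherwise n_ctx - cache.head evaluates against the freshly zeroed head, adds the full n_ctx, and the inflated n_tested can make the search report a full cache prematurely. A self-contained sketch of the corrected loop (simplified; cell_is_free is a hypothetical stand-in for the real per-cell occupancy checks):

    #include <cstdint>

    // Hypothetical stand-in for the real per-cell occupancy check.
    static bool cell_is_free(uint32_t /*cell*/) { return false; }

    // Simplified version of the corrected search: count the cells skipped
    // at the end of the buffer *before* resetting head to 0.
    static bool find_slot_sketch(uint32_t & head, uint32_t n_ctx, uint32_t n_tokens) {
        if (n_tokens > n_ctx) return false; // the slot cannot fit at all
        uint32_t n_tested = 0;
        while (true) {
            if (head + n_tokens > n_ctx) {
                n_tested += n_ctx - head; // distance from the old head to the end
                head = 0;                 // only now is it safe to wrap
                continue;
            }
            bool found = true;
            for (uint32_t i = 0; i < n_tokens; ++i) {
                if (!cell_is_free(head + i)) {
                    found    = false;
                    head     += i + 1;
                    n_tested += i + 1;
                    break;
                }
            }
            if (found) return true;
            if (n_tested >= n_ctx) return false; // every cell was tested
        }
    }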
@@ -1406,13 +1411,18 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
+
+    // Searching for a free slot can start here since we know it will be empty.
+    cache.head = uint32_t(c0);
 }
 
 static void llama_kv_cache_seq_rm(
         struct llama_kv_cache & cache,
                    llama_seq_id seq_id,
                       llama_pos p0,
                       llama_pos p1) {
+    uint32_t new_head = cache.size;
+
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
@@ -1421,9 +1431,13 @@ static void llama_kv_cache_seq_rm(
             cache.cells[i].seq_id.erase(seq_id);
             if (cache.cells[i].seq_id.empty()) {
                 cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_cp(
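The new_head bookkeeping introduced here is a sentinel idiom that the hunks below reuse in llama_kv_cache_seq_keep and llama_kv_cache_seq_shift: cache.size is an impossible cell index, so new_head == cache.size encodes "nothing freed during this sweep" without a separate flag. A compact illustration of the idiom (simplified, with a plain occupancy vector standing in for the real cells):

    #include <cstdint>
    #include <vector>

    // Sketch of the "first freed cell" idiom: size acts as an impossible
    // index, so new_head != size exactly when the sweep freed a cell.
    static void free_all_sketch(std::vector<bool> & occupied, uint32_t & head) {
        const uint32_t size = (uint32_t) occupied.size();
        uint32_t new_head = size; // sentinel: nothing freed yet
        for (uint32_t i = 0; i < size; ++i) {
            if (occupied[i]) {
                occupied[i] = false;                // free the cell
                if (new_head == size) new_head = i; // remember the first one
            }
        }
        // If we freed up a slot, set head to it so searching can start there.
        if (new_head != size) head = new_head;
    }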
@@ -1435,6 +1449,8 @@ static void llama_kv_cache_seq_cp(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
+    cache.head = 0;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
@@ -1443,12 +1459,18 @@ static void llama_kv_cache_seq_cp(
 }
 
 static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    uint32_t new_head = cache.size;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
+            if (new_head == cache.size) new_head = i;
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_shift(
@@ -1457,6 +1479,8 @@ static void llama_kv_cache_seq_shift(
         llama_pos p0,
         llama_pos p1,
         llama_pos delta) {
+    uint32_t new_head = cache.size;
+
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
@@ -1466,12 +1490,17 @@ static void llama_kv_cache_seq_shift(
             if (cache.cells[i].pos < 0) {
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
+                if (new_head == cache.size) new_head = i;
             } else {
                 cache.has_shift = true;
                 cache.cells[i].delta = delta;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    // Otherwise we just start the next search from the beginning.
+    cache.head = new_head != cache.size ? new_head : 0;
 }
 
 //
@@ -4492,10 +4521,6 @@ static int llama_decode_internal(
         batch.seq_id = seq_id.data();
     }
 
-    // we always start to search for a free slot from the start of the cache
-    // TODO: better strategies can be implemented
-    kv_self.head = 0;
-
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
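Dropping this reset is the payoff of the head bookkeeping above: the next llama_kv_cache_find_slot call resumes from wherever head was left, instead of rescanning every occupied cell from the start on each decode. A toy calculation of the difference (illustrative only, assuming back-to-back decodes that fill the cache one slot at a time):

    #include <cstdint>
    #include <cstdio>

    // Count how many cells each policy inspects while filling the cache
    // with consecutive slots of n_tokens cells each.
    int main() {
        const uint32_t size = 1024, n_tokens = 32;
        uint64_t tested_restart = 0; // old: search always starts at cell 0
        uint64_t tested_resume  = 0; // new: search resumes from head
        for (uint32_t head = 0; head + n_tokens <= size; head += n_tokens) {
            tested_restart += head + n_tokens; // rescans all filled cells
            tested_resume  += n_tokens;        // touches only the fresh cells
        }
        // restart grows quadratically with the number of decode calls,
        // resume stays linear.
        printf("restart: %llu cells, resume: %llu cells\n",
               (unsigned long long) tested_restart,
               (unsigned long long) tested_resume);
        return 0;
    }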
@@ -4581,8 +4606,12 @@
 #endif
 
     // update the kv ring buffer
-    lctx.kv_self.head += n_tokens;
     lctx.kv_self.has_shift = false;
+    lctx.kv_self.head += n_tokens;
+    // Ensure kv cache head points to a valid index.
+    if (lctx.kv_self.head >= lctx.kv_self.size) {
+        lctx.kv_self.head = 0;
+    }
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)
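Note that llama_kv_cache_find_slot only returns contiguous runs of cells, so head + n_tokens cannot exceed the cache size here; the new check fires only when the slot ends exactly at the last cell, where resetting to 0 matches a modulo wrap. A one-function sketch of that invariant (an assumption drawn from the slot search above, not code from the patch):

    #include <cassert>
    #include <cstdint>

    // Advance the ring-buffer head past a freshly written slot. Given a
    // contiguous slot (head + n_tokens <= size), the reset-to-zero check
    // behaves exactly like head = (head + n_tokens) % size.
    static uint32_t advance_head(uint32_t head, uint32_t n_tokens, uint32_t size) {
        assert(head + n_tokens <= size); // guaranteed by the slot search
        head += n_tokens;
        if (head >= size) {
            head = 0; // the slot ended at the last cell: wrap to the start
        }
        return head;
    }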