6
6
from typing import TYPE_CHECKING
7
7
if TYPE_CHECKING :
8
8
from exllamav2 .model import ExLlamaV2
9
-
9
+ from exllamav2 import ExLlamaV2Tokenizer
10
10
11
11
class ExLlamaV2CacheBase :
12
12
@@ -204,6 +204,10 @@ def all_tensors(self):
204
204
raise NotImplementedError ()
205
205
206
206
207
def reset(self):
    """
    Rewind the cache to an empty state.

    Only the logical sequence position is cleared; the underlying key/value
    tensors stay allocated and are simply overwritten on the next forward pass.
    """
    self.current_seq_len = 0
207
211
class ExLlamaV2Cache (ExLlamaV2CacheBase ):
208
212
"""
209
213
FP16 cache
@@ -348,27 +352,31 @@ def all_tensors(self):
348
352
return self .key_states + self .value_states
349
353
350
354
351
- class ExLlamaV2Cache_Q4 (ExLlamaV2CacheBase ):
355
+ class ExLlamaV2Cache_Q (ExLlamaV2CacheBase ):
352
356
"""
353
- Q4 cache. Uses grouped RTN quantization for keys/values
357
+ Q cache. Uses grouped RTN quantization for keys/values
354
358
"""
355
359
360
+ wbits : int
361
+
356
362
def __init__ (self ,
357
363
model : ExLlamaV2 ,
358
364
batch_size : int = 1 ,
359
365
max_seq_len : int = - 1 ,
360
366
copy_from : ExLlamaV2Cache_Q4 | None = None ,
361
- lazy : bool = False ):
367
+ lazy : bool = False ,
368
+ weights_per_byte : int = - 1 ):
362
369
363
- super ().__init__ (model , batch_size , max_seq_len , torch .uint8 , 2 , True )
370
+ super ().__init__ (model , batch_size , max_seq_len , torch .uint8 , weights_per_byte , True )
371
+ cfg = self .model .config
364
372
365
373
self .create_state_tensors (copy_from , lazy )
366
374
367
- # Models with odd key/value dims need to to quantize/dequantize in multi-token blocks. Make sure the quant
375
+ # Models with odd key/value dims need to quantize/dequantize in multi-token blocks. Make sure the quant
368
376
# blocksize aligns with a whole number of tokens
369
377
370
378
Q_CACHE_BLOCKSIZE_Q = 512
371
- kv_dim = model . config . num_key_value_heads * model . config .head_dim
379
+ kv_dim = cfg . num_key_value_heads * cfg .head_dim
372
380
self .q_block = 1
373
381
while (kv_dim * self .q_block ) % Q_CACHE_BLOCKSIZE_Q :
374
382
self .q_block += 1
@@ -380,6 +388,14 @@ def __init__(self,
380
388
if not lazy :
381
389
for device in self .model .get_cache_devices (): self .touch_device (device )
382
390
391
+ # Calibration mode
392
+
393
+ self .calibrated = False
394
+ self .calibrating = False
395
+ self .calibration_rows = [0 ] * cfg .num_hidden_layers
396
+ self .calibration_k = {}
397
+ self .calibration_v = {}
398
+
383
399
384
400
def touch_device (self , device ):
385
401
@@ -410,7 +426,7 @@ def get_kv_state(self,
410
426
offset = a
411
427
width = b - a
412
428
413
- ext_c .q4_to_fp16_kv (
429
+ ext_c .q_to_fp16_kv (
414
430
self .key_states [layer_idx ],
415
431
temp_key_state ,
416
432
self .key_scales [layer_idx ],
@@ -422,8 +438,18 @@ def get_kv_state(self,
422
438
width ,
423
439
page_size ,
424
440
cache_seqlens if cache_seqlens is not None else none_tensor ,
425
- block_table if block_table is not None else none_tensor
441
+ block_table if block_table is not None else none_tensor ,
442
+ # none_tensor,
443
+ # none_tensor
444
+ self .calibration_k [layer_idx ] if self .calibrated else none_tensor ,
445
+ self .calibration_v [layer_idx ] if self .calibrated else none_tensor ,
446
+ self .wbits
426
447
)
448
+
449
+ # if self.calibrated:
450
+ # temp_key_state *= self.calibration_k[layer_idx]
451
+ # temp_value_state *= self.calibration_v[layer_idx]
452
+
427
453
return temp_key_state , temp_value_state
428
454
429
455
@@ -448,7 +474,12 @@ def store_kv_state(self,
448
474
449
475
device = self .model .cache_map [layer_idx ]
450
476
temp_key_state , temp_value_state = self .temp_tensors [device ]
451
- ext_c .fp16_to_q4_kv (
477
+
478
+ # if self.calibrated:
479
+ # temp_key_state /= self.calibration_k[layer_idx]
480
+ # temp_value_state /= self.calibration_v[layer_idx]
481
+
482
+ ext_c .fp16_to_q_kv (
452
483
temp_key_state ,
453
484
self .key_states [layer_idx ],
454
485
self .key_scales [layer_idx ],
@@ -460,9 +491,43 @@ def store_kv_state(self,
460
491
width ,
461
492
page_size ,
462
493
cache_seqlens if cache_seqlens is not None else none_tensor ,
463
- block_table if block_table is not None else none_tensor
494
+ block_table if block_table is not None else none_tensor ,
495
+ # none_tensor,
496
+ # none_tensor
497
+ self .calibration_k [layer_idx ] if self .calibrated else none_tensor ,
498
+ self .calibration_v [layer_idx ] if self .calibrated else none_tensor ,
499
+ self .wbits
464
500
)
465
501
502
+ # Collect calibration data
503
+
504
+ if self .calibrating :
505
+
506
+ cfg = self .model .config
507
+
508
+ if layer_idx not in self .calibration_k :
509
+ self .calibration_k [layer_idx ] = torch .zeros (
510
+ (cfg .num_key_value_heads , cfg .head_dim ,),
511
+ dtype = torch .float ,
512
+ device = temp_key_state .device
513
+ )
514
+ self .calibration_v [layer_idx ] = torch .zeros (
515
+ (cfg .num_key_value_heads , cfg .head_dim ,),
516
+ dtype = torch .float ,
517
+ device = temp_key_state .device
518
+ )
519
+
520
+ b , l , h , d = temp_key_state .shape
521
+ cal_k = self .calibration_k [layer_idx ]
522
+ cal_v = self .calibration_v [layer_idx ]
523
+ cal_k_input = temp_key_state [:, offset :offset + width , :, :].view (b * width , h * d )
524
+ cal_v_input = temp_value_state [:, offset :offset + width , :, :].view (b * width , h * d )
525
+ cal_k_sum = torch .norm (cal_k_input , p = 1 , dim = 0 , dtype = torch .float )
526
+ cal_v_sum = torch .norm (cal_v_input , p = 1 , dim = 0 , dtype = torch .float )
527
+ cal_k .add_ (cal_k_sum .view (h , d ))
528
+ cal_v .add_ (cal_v_sum .view (h , d ))
529
+ self .calibration_rows [layer_idx ] += width
530
+
466
531
467
532
def footprint (self ) -> list [int ]:
468
533
@@ -491,3 +556,69 @@ def all_tensors(self):
491
556
return self .key_states + self .value_states + self .key_scales + self .value_scales
492
557
493
558
559
def calibrate(self,
              tokenizer: "ExLlamaV2Tokenizer",
              num_batches = 8,
              num_samples_per_batch = 256
):
    """
    Unfinished / experimental.

    Run a number of forward passes over uniformly random token IDs while collecting per-channel
    L1 statistics of the cached keys/values (accumulated in store_kv_state while
    self.calibrating is True), then reduce them into per-channel scale factors stored in
    self.calibration_k / self.calibration_v.

    self.calibrated is deliberately left False at the end, so the computed scales are not yet
    applied by get_kv_state / store_kv_state.

    :param tokenizer: used only for its vocab size when sampling random input IDs
    :param num_batches: number of calibration forward passes
    :param num_samples_per_batch: tokens per pass; cache max_seq_len must be at least this
    """

    assert self.max_seq_len >= num_samples_per_batch, \
        f"Cache max_seq_len must be at least {num_samples_per_batch} to calibrate."

    # Enable stat collection in store_kv_state(); fixed seed for reproducible calibration
    self.calibrating = True
    torch.manual_seed(123)

    for _ in range(num_batches):

        input_ids = torch.randint(
            low = 0,
            high = tokenizer.get_vocab_size() - 1,
            size = (1, num_samples_per_batch),
            dtype = torch.long
        )

        self.reset()
        self.model.forward(input_ids, preprocess_only = True, cache = self)

    self.calibrating = False

    # Reduce accumulated per-channel L1 sums to scale factors: mean magnitude per channel,
    # heavily dampened by an eighth root.
    for i in range(self.model.config.num_hidden_layers):
        cal_k = self.calibration_k[i] / self.calibration_rows[i]
        cal_v = self.calibration_v[i] / self.calibration_rows[i]
        cal_k = cal_k ** (1 / 8)
        cal_v = cal_v ** (1 / 8)
        # NOTE(review): scales are negated before storing — presumably a sentinel/convention
        # expected by the quant kernels; confirm against ext_c.q_to_fp16_kv
        cal_k = cal_k.half() * (-1)
        cal_v = cal_v.half() * (-1)
        self.calibration_k[i] = cal_k
        self.calibration_v[i] = cal_v
    # self.calibrated = True  # intentionally disabled while this feature is unfinished
599
+
600
+
601
class ExLlamaV2Cache_Q4(ExLlamaV2Cache_Q):
    """
    Q4 cache: grouped RTN-quantized keys/values at four bits per weight,
    i.e. two quantized weights packed into each byte of cache storage.
    """

    def __init__(self,
                 model: ExLlamaV2,
                 batch_size: int = 1,
                 max_seq_len: int = -1,
                 copy_from: ExLlamaV2Cache_Q4 | None = None,
                 lazy: bool = False):

        # Two weights per byte -> 4-bit quantization
        super().__init__(model, batch_size, max_seq_len, copy_from, lazy, weights_per_byte = 2)
        self.wbits = 4
612
+
613
+
614
class ExLlamaV2Cache_Q8(ExLlamaV2Cache_Q):
    """
    Q8 cache: grouped RTN-quantized keys/values at eight bits per weight,
    i.e. one quantized weight per byte of cache storage.
    """

    def __init__(self,
                 model: ExLlamaV2,
                 batch_size: int = 1,
                 max_seq_len: int = -1,
                 # Fixed annotation: was ExLlamaV2Cache_Q4 (copy-paste from the Q4 subclass);
                 # a Q8 cache is copied from another Q8 cache
                 copy_from: ExLlamaV2Cache_Q8 | None = None,
                 lazy: bool = False):

        # One weight per byte -> 8-bit quantization
        super().__init__(model, batch_size, max_seq_len, copy_from, lazy, 1)
        self.wbits = 8