
Commit fb61a81

Add Q8 cache mode

1 parent cd75438 commit fb61a81

File tree

11 files changed: +568 −243 lines changed


exllamav2/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 from exllamav2.cache import ExLlamaV2CacheBase
 from exllamav2.cache import ExLlamaV2Cache
 from exllamav2.cache import ExLlamaV2Cache_Q4
+from exllamav2.cache import ExLlamaV2Cache_Q8
 from exllamav2.cache import ExLlamaV2Cache_8bit
 from exllamav2.config import ExLlamaV2Config
 from exllamav2.tokenizer.tokenizer import ExLlamaV2Tokenizer
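
Aside (not part of the diff): a minimal sketch of how the newly exported ExLlamaV2Cache_Q8 might be constructed, following the same pattern used for ExLlamaV2Cache_Q4 in the existing examples. The model directory is a placeholder.

# Hypothetical usage sketch, not part of this commit.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache_Q8, ExLlamaV2Tokenizer

config = ExLlamaV2Config()
config.model_dir = "/path/to/exl2/model"      # placeholder path
config.prepare()

model = ExLlamaV2(config)
cache = ExLlamaV2Cache_Q8(model, lazy = True) # Q8 K/V cache, allocated during autosplit load
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)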

exllamav2/cache.py

Lines changed: 142 additions & 11 deletions
@@ -6,7 +6,7 @@
 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from exllamav2.model import ExLlamaV2
-
+    from exllamav2 import ExLlamaV2Tokenizer


 class ExLlamaV2CacheBase:

@@ -204,6 +204,10 @@ def all_tensors(self):
         raise NotImplementedError()


+    def reset(self):
+        self.current_seq_len = 0
+
+
 class ExLlamaV2Cache(ExLlamaV2CacheBase):
     """
     FP16 cache
@@ -348,27 +352,31 @@ def all_tensors(self):
         return self.key_states + self.value_states


-class ExLlamaV2Cache_Q4(ExLlamaV2CacheBase):
+class ExLlamaV2Cache_Q(ExLlamaV2CacheBase):
     """
-    Q4 cache. Uses grouped RTN quantization for keys/values
+    Q cache. Uses grouped RTN quantization for keys/values
     """

+    wbits: int
+
     def __init__(self,
                  model: ExLlamaV2,
                  batch_size: int = 1,
                  max_seq_len: int = -1,
                  copy_from: ExLlamaV2Cache_Q4 | None = None,
-                 lazy: bool = False):
+                 lazy: bool = False,
+                 weights_per_byte: int = -1):

-        super().__init__(model, batch_size, max_seq_len, torch.uint8, 2, True)
+        super().__init__(model, batch_size, max_seq_len, torch.uint8, weights_per_byte, True)
+        cfg = self.model.config

         self.create_state_tensors(copy_from, lazy)

-        # Models with odd key/value dims need to to quantize/dequantize in multi-token blocks. Make sure the quant
+        # Models with odd key/value dims need to quantize/dequantize in multi-token blocks. Make sure the quant
         # blocksize aligns with a whole number of tokens

         Q_CACHE_BLOCKSIZE_Q = 512
-        kv_dim = model.config.num_key_value_heads * model.config.head_dim
+        kv_dim = cfg.num_key_value_heads * cfg.head_dim
         self.q_block = 1
         while (kv_dim * self.q_block) % Q_CACHE_BLOCKSIZE_Q:
             self.q_block += 1
@@ -380,6 +388,14 @@ def __init__(self,
         if not lazy:
             for device in self.model.get_cache_devices(): self.touch_device(device)

+        # Calibration mode
+
+        self.calibrated = False
+        self.calibrating = False
+        self.calibration_rows = [0] * cfg.num_hidden_layers
+        self.calibration_k = {}
+        self.calibration_v = {}
+

     def touch_device(self, device):

@@ -410,7 +426,7 @@ def get_kv_state(self,
             offset = a
             width = b - a

-        ext_c.q4_to_fp16_kv(
+        ext_c.q_to_fp16_kv(
            self.key_states[layer_idx],
            temp_key_state,
            self.key_scales[layer_idx],
@@ -422,8 +438,18 @@
            width,
            page_size,
            cache_seqlens if cache_seqlens is not None else none_tensor,
-           block_table if block_table is not None else none_tensor
+           block_table if block_table is not None else none_tensor,
+           # none_tensor,
+           # none_tensor
+           self.calibration_k[layer_idx] if self.calibrated else none_tensor,
+           self.calibration_v[layer_idx] if self.calibrated else none_tensor,
+           self.wbits
         )
+
+        # if self.calibrated:
+        #     temp_key_state *= self.calibration_k[layer_idx]
+        #     temp_value_state *= self.calibration_v[layer_idx]
+
         return temp_key_state, temp_value_state


@@ -448,7 +474,12 @@ def store_kv_state(self,

         device = self.model.cache_map[layer_idx]
         temp_key_state, temp_value_state = self.temp_tensors[device]
-        ext_c.fp16_to_q4_kv(
+
+        # if self.calibrated:
+        #     temp_key_state /= self.calibration_k[layer_idx]
+        #     temp_value_state /= self.calibration_v[layer_idx]
+
+        ext_c.fp16_to_q_kv(
             temp_key_state,
             self.key_states[layer_idx],
             self.key_scales[layer_idx],
@@ -460,9 +491,43 @@
             width,
             page_size,
             cache_seqlens if cache_seqlens is not None else none_tensor,
-            block_table if block_table is not None else none_tensor
+            block_table if block_table is not None else none_tensor,
+            # none_tensor,
+            # none_tensor
+            self.calibration_k[layer_idx] if self.calibrated else none_tensor,
+            self.calibration_v[layer_idx] if self.calibrated else none_tensor,
+            self.wbits
         )

+        # Collect calibration data
+
+        if self.calibrating:
+
+            cfg = self.model.config
+
+            if layer_idx not in self.calibration_k:
+                self.calibration_k[layer_idx] = torch.zeros(
+                    (cfg.num_key_value_heads, cfg.head_dim,),
+                    dtype = torch.float,
+                    device = temp_key_state.device
+                )
+                self.calibration_v[layer_idx] = torch.zeros(
+                    (cfg.num_key_value_heads, cfg.head_dim,),
+                    dtype = torch.float,
+                    device = temp_key_state.device
+                )
+
+            b, l, h, d = temp_key_state.shape
+            cal_k = self.calibration_k[layer_idx]
+            cal_v = self.calibration_v[layer_idx]
+            cal_k_input = temp_key_state[:, offset:offset+width, :, :].view(b * width, h * d)
+            cal_v_input = temp_value_state[:, offset:offset+width, :, :].view(b * width, h * d)
+            cal_k_sum = torch.norm(cal_k_input, p = 1, dim = 0, dtype = torch.float)
+            cal_v_sum = torch.norm(cal_v_input, p = 1, dim = 0, dtype = torch.float)
+            cal_k.add_(cal_k_sum.view(h, d))
+            cal_v.add_(cal_v_sum.view(h, d))
+            self.calibration_rows[layer_idx] += width
+

     def footprint(self) -> list[int]:

@@ -491,3 +556,69 @@ def all_tensors(self):
         return self.key_states + self.value_states + self.key_scales + self.value_scales


+
+    def calibrate(self,
+                  tokenizer: ExLlamaV2Tokenizer,
+                  num_batches = 8,
+                  num_samples_per_batch = 256
+                  ):
+        """
+        Unfinished
+        """
+
+        assert self.max_seq_len >= num_samples_per_batch, \
+            f"Cache max_seq_len must be at least {num_samples_per_batch} to calibrate."
+
+        self.calibrating = True
+        torch.manual_seed(123)
+
+        for _ in range(num_batches):
+
+            input_ids = torch.randint(
+                low = 0,
+                high = tokenizer.get_vocab_size() - 1,
+                size = (1, num_samples_per_batch),
+                dtype = torch.long
+            )
+
+            self.reset()
+            self.model.forward(input_ids, preprocess_only = True, cache = self)
+
+        self.calibrating = False
+
+        for i in range(self.model.config.num_hidden_layers):
+            cal_k = self.calibration_k[i] / self.calibration_rows[i]  # self.calibration_k[i].mean()
+            cal_v = self.calibration_v[i] / self.calibration_rows[i]  # self.calibration_v[i].mean()
+            cal_k = cal_k ** (1/8)
+            cal_v = cal_v ** (1/8)
+            cal_k = cal_k.half() * (-1)
+            cal_v = cal_v.half() * (-1)
+            self.calibration_k[i] = cal_k
+            self.calibration_v[i] = cal_v
+        self.calibrating = False
+        # self.calibrated = True
+
+
+class ExLlamaV2Cache_Q4(ExLlamaV2Cache_Q):
+
+    def __init__(self,
+                 model: ExLlamaV2,
+                 batch_size: int = 1,
+                 max_seq_len: int = -1,
+                 copy_from: ExLlamaV2Cache_Q4 | None = None,
+                 lazy: bool = False):
+
+        super().__init__(model, batch_size, max_seq_len, copy_from, lazy, 2)
+        self.wbits = 4
+
+
+class ExLlamaV2Cache_Q8(ExLlamaV2Cache_Q):
+
+    def __init__(self,
+                 model: ExLlamaV2,
+                 batch_size: int = 1,
+                 max_seq_len: int = -1,
+                 copy_from: ExLlamaV2Cache_Q4 | None = None,
+                 lazy: bool = False):
+
+        super().__init__(model, batch_size, max_seq_len, copy_from, lazy, 1)
+        self.wbits = 8
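
Aside (not part of the diff): the ExLlamaV2Cache_Q docstring refers to grouped RTN (round-to-nearest) quantization, but the actual kernels live in the C++/CUDA extension (ext_c.fp16_to_q_kv / ext_c.q_to_fp16_kv) and are not visible here. The standalone PyTorch sketch below only illustrates the general idea of per-group RTN with a 512-element block; the real scale format, bit packing (two 4-bit weights per byte for Q4, one per byte for Q8) and rounding details may differ.

# Illustrative only: grouped round-to-nearest quantization with one scale and
# offset per block of 512 values. Not the extension's actual implementation.
import torch

def rtn_quantize(x: torch.Tensor, bits: int, block: int = 512):
    qmax = (1 << bits) - 1
    g = x.float().reshape(-1, block)                   # assumes numel % block == 0
    lo = g.min(dim = 1, keepdim = True).values
    hi = g.max(dim = 1, keepdim = True).values
    scale = (hi - lo).clamp(min = 1e-8) / qmax
    q = ((g - lo) / scale).round().clamp(0, qmax).to(torch.uint8)
    return q, scale, lo

def rtn_dequantize(q, scale, lo, shape):
    return (q.float() * scale + lo).half().reshape(shape)

# Example: a fake key tensor shaped (batch, tokens, kv_heads, head_dim)
k = torch.randn(1, 16, 8, 128, dtype = torch.half)
q, scale, lo = rtn_quantize(k, bits = 8)
k_deq = rtn_dequantize(q, scale, lo, k.shape)
print((k - k_deq).abs().max())                         # small reconstruction error

The q_block loop in the constructor exists to make one quant block cover a whole number of tokens: for example, a model with kv_dim = 640 (5 KV heads of head_dim 128) gets q_block = 4, since 640 * 4 = 2560 is the first multiple of 640 divisible by 512.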
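
Aside (not part of the diff): the new calibrate() method is explicitly marked "Unfinished", and self.calibrated is left commented out, so the per-channel scales are computed but never passed to the kernels yet (they receive none_tensor). A hypothetical invocation, assuming a model and tokenizer loaded as in the earlier sketch, might look like this:

# Hypothetical sketch of the unfinished calibration flow.
cache = ExLlamaV2Cache_Q8(model, max_seq_len = 4096)
cache.calibrate(tokenizer)   # runs num_batches random-token forward passes,
                             # accumulating per-channel L1 stats in store_kv_state()
cache.reset()                # then reuse the same cache for actual generation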
