In a typical transformer model, the KV cache is organized similar to (mapped to
our parameter names below):
  k = tensor.empty(transformer_block_count, batch_size, seq,
-                  attn_head_count, attn_head_dim)
+                  attn_head_count_kv, attn_head_dim)
  v = ...

For context, a popular model has parameters of:
  attn_dtype_size = 2  # (fp16)
  max_seq_len = 2048
  transformer_block_count = 32
- attn_head_count = 32
+ attn_head_count_kv = 32
  attn_head_dim = 128  # (dim / head_count)

If paging, then we primarily care about the organization of a single block, where
a block represents a single position in the sequence for a single item in the batch.
Therefore, it will be organized like:
- block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim)
+ block = torch.empty(transformer_block_count, 2, attn_head_count_kv, attn_head_dim)

In this scenario, we declare that one block holds the KV cache for all transformer
block layers because it reduces the accounting. As such, for the above example,
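For a rough sense of scale, a minimal Python sketch using the example figures quoted above, mirroring the paged_kv_unit_size_elements computation updated later in this diff (illustrative values only):

    # Illustrative only: values taken from the example above.
    transformer_block_count = 32
    attn_head_count_kv = 32
    attn_head_dim = 128
    attn_dtype_size = 2  # fp16

    # One paged block holds K and V for every transformer layer at one position.
    unit_elements = transformer_block_count * 2 * attn_head_count_kv * attn_head_dim
    unit_bytes = unit_elements * attn_dtype_size
    print(unit_elements, unit_bytes)  # 262144 elements, 524288 bytes (512 KiB)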
@@ -80,10 +80,15 @@ def _decode_dtype(name: str) -> sfnp.DType:
class PagedKVCacheParams:
    """Parameters for the paged KV cache."""

-   # Position stride per attention block
+   # Tokens per page.
    block_seq_stride: int

+   # Number of attention heads per block. This can be different from the model's
+   # attention head count due to sharing.
+   attention_head_count_kv: int
+
    # Size of the cache on each device.
+   # Default: 256
    device_block_count: int

    prefix_sharing_algorithm: str = "none"  # currently supporting none and trie
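As a usage sketch, assuming the fields visible in this hunk are the only required ones, the new field would be populated alongside the existing ones (the block_seq_stride value is a placeholder, not taken from this change):

    cache_params = PagedKVCacheParams(
        block_seq_stride=16,          # placeholder: tokens per page
        attention_head_count_kv=32,   # matches the example model in the module docstring
        device_block_count=256,       # the default noted above
    )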
@@ -92,19 +97,23 @@ class PagedKVCacheParams:
@dataclass_json(undefined=Undefined.RAISE)
@dataclass
class ModelParams:
-   """Parameters for a specific compiled model, sufficient to do cache planning and
-   invocations."""
+   """
+   Parameters for a specific compiled model, sufficient to do cache planning and
+   invocations.
+
+   Compatibility should be maintained with function generate_params_json in
+
+   sharktank/sharktank/examples/export_paged_llm_v1.py
+   """

    # Maximum length of a sequence including prompt and output.
    max_seq_len: int

-   # Number of transformer blocks.
+   # Number of transformer layers (aka attention blocks / transformer blocks).
    transformer_block_count: int

-   # Number of attention heads per block.
-   attn_head_count: int
-
-   # Dimensionality of each attention head
+   # Dimensionality of each attention head. This is the dimensionality of the
+   # key and value vectors. AKA rope_dimension_count from the GGUF props.
    attn_head_dim: int

    # Batch sizes that the prefill stage is compiled for. These are expected to be
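Because of @dataclass_json(undefined=Undefined.RAISE), the exported JSON must carry exactly these field names; unknown keys raise rather than being silently dropped. A minimal loading sketch (the file name is a placeholder, and the real config carries more fields than this hunk shows):

    # from_json is provided by the dataclass_json decorator.
    with open("model_params.json") as f:          # placeholder path
        params = ModelParams.from_json(f.read())
    print(params.transformer_block_count, params.attn_head_dim)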
@@ -159,7 +168,7 @@ def paged_kv_unit_size_elements(self) -> int:
        size = 1
        size *= self.transformer_block_count
        size *= 2  # K and V cache line
-       size *= self.attn_head_count
+       size *= self.paged_kv_cache.attention_head_count_kv
        size *= self.attn_head_dim
        return size
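Extending the same arithmetic, the per-device cache footprint presumably scales this per-position unit by the tokens per page and the device block count; a back-of-the-envelope sketch with the example figures (the scaling and the block_seq_stride of 16 are assumptions, not shown in this hunk):

    unit_elements = 32 * 2 * 32 * 128      # transformer_block_count * K/V * attn_head_count_kv * attn_head_dim
    page_elements = unit_elements * 16     # assumed block_seq_stride (tokens per page)
    cache_bytes = page_elements * 256 * 2  # default device_block_count, fp16 bytes
    print(cache_bytes // (1 << 20))        # 2048 MiB for this configuration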