
Commit 7f5af13

renxida authored and IanNod committed
Make config.json consistent between shortfin and sharktank (nod-ai#487)
And remove the adaptation layer in build_tools/integration_tests/llm/conftest.py
1 parent 8109f39 commit 7f5af13

File tree

5 files changed: +75 -64 lines changed

app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
app_tests/integration_tests/llm/sglang/conftest.py
app_tests/integration_tests/llm/shortfin/conftest.py
sharktank/sharktank/examples/export_paged_llm_v1.py
shortfin/python/shortfin_apps/llm/components/config_struct.py

app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py (+19 -19)

@@ -68,35 +68,35 @@ def write_config(request, pre_process_model):
     batch_sizes = request.param["batch_sizes"]
     prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]

-    logger.info("Writing config file..." + start_log_group("Writing config file"))
-
+    # Construct the new config filename
     config_path = (
         pre_process_model
         / f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json"
     )

-    config = {
-        "module_name": "module",
-        "module_abi_version": 1,
-        "max_seq_len": 131072,
-        "attn_head_count": 8,
-        "attn_head_dim": 128,
-        "prefill_batch_sizes": batch_sizes,
-        "decode_batch_sizes": batch_sizes,
-        "transformer_block_count": 32,
-        "paged_kv_cache": {
-            "block_seq_stride": 16,
-            "device_block_count": 256,
-            "prefix_sharing_algorithm": prefix_sharing_algorithm,
-        },
-    }
+    # Read the base config file
+    base_config_path = pre_process_model / "config.json"
+    with open(base_config_path, "r") as f:
+        config = json.load(f)
+
+    # Override specific fields
+    config.update(
+        {
+            "prefill_batch_sizes": batch_sizes,
+            "decode_batch_sizes": batch_sizes,
+            "paged_kv_cache": {
+                **config.get(
+                    "paged_kv_cache", {}
+                ),  # Preserve other paged_kv_cache settings
+                "prefix_sharing_algorithm": prefix_sharing_algorithm,
+            },
+        }
+    )

     logger.info(f"Saving edited config to: {config_path}\n")
     logger.info(f"Config: {json.dumps(config, indent=2)}")
     with open(config_path, "w") as f:
         json.dump(config, f)
-
-    logger.info("Config file successfully written" + end_log_group())
     yield config_path

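Note on the override pattern above: config.update replaces the whole "paged_kv_cache" value, so spreading config.get("paged_kv_cache", {}) into the new dict is what carries the exported base settings forward. A minimal standalone sketch of the same merge, with illustrative values rather than anything read from the repo:

    import json

    # Stand-in for the exporter-written config.json (illustrative values only).
    config = {
        "max_seq_len": 131072,
        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
    }

    # Same merge shape as the fixture: the dict literal is built before
    # update() runs, so config.get(...) still sees the original nested settings.
    config.update(
        {
            "prefill_batch_sizes": [1, 4],
            "decode_batch_sizes": [1, 4],
            "paged_kv_cache": {
                **config.get("paged_kv_cache", {}),  # keep stride / block count
                "prefix_sharing_algorithm": "trie",
            },
        }
    )

    print(json.dumps(config["paged_kv_cache"], indent=2))
    # -> block_seq_stride, device_block_count, and prefix_sharing_algorithm

Only prefix_sharing_algorithm is new; nothing from the base file is dropped.
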
app_tests/integration_tests/llm/sglang/conftest.py (-15)

@@ -64,21 +64,6 @@ def pre_process_model(request, tmp_path_factory):
         device_settings,
     )

-    config = {
-        "module_name": "module",
-        "module_abi_version": 1,
-        "max_seq_len": 131072,
-        "attn_head_count": 8,
-        "attn_head_dim": 128,
-        "prefill_batch_sizes": [1, 4],
-        "decode_batch_sizes": [1, 4],
-        "transformer_block_count": 32,
-        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
-    }
-    config_path = tmp_dir / "config.json"
-    with open(config_path, "w") as f:
-        json.dump(config, f)
-
     return tmp_dir

app_tests/integration_tests/llm/shortfin/conftest.py (+19 -15)

@@ -87,26 +87,30 @@ def write_config(request, model_test_dir):
     batch_sizes = request.param["batch_sizes"]
     prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]

+    # Construct the new config filename
     config_path = (
         model_test_dir
        / f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json"
     )

-    config = {
-        "module_name": "module",
-        "module_abi_version": 1,
-        "max_seq_len": 2048,
-        "attn_head_count": 32,
-        "attn_head_dim": 100,
-        "prefill_batch_sizes": batch_sizes,
-        "decode_batch_sizes": batch_sizes,
-        "transformer_block_count": 26,
-        "paged_kv_cache": {
-            "block_seq_stride": 16,
-            "device_block_count": 256,
-            "prefix_sharing_algorithm": prefix_sharing_algorithm,
-        },
-    }
+    # Read the base config file
+    base_config_path = model_test_dir / "config.json"
+    with open(base_config_path, "r") as f:
+        config = json.load(f)
+
+    # Override specific fields
+    config.update(
+        {
+            "prefill_batch_sizes": batch_sizes,
+            "decode_batch_sizes": batch_sizes,
+            "paged_kv_cache": {
+                **config.get(
+                    "paged_kv_cache", {}
+                ),  # Preserve other paged_kv_cache settings
+                "prefix_sharing_algorithm": prefix_sharing_algorithm,
+            },
+        }
+    )
     logger.info(f"Saving edited config to: {config_path}\n")
     logger.info(f"Config: {json.dumps(config, indent=2)}")
     with open(config_path, "w") as f:

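The fixture reads batch_sizes and prefix_sharing_algorithm from request.param, which suggests the tests drive it through pytest's indirect parametrization. A hypothetical sketch of such a caller (test name and parameter values are illustrative, not part of this commit):

    import pytest

    @pytest.mark.parametrize(
        "write_config",
        [
            {"batch_sizes": [1, 4], "prefix_sharing_algorithm": "none"},
            {"batch_sizes": [1, 4], "prefix_sharing_algorithm": "trie"},
        ],
        indirect=True,  # deliver each dict to the fixture via request.param
    )
    def test_server_config(write_config):
        # The fixture yields the path of the edited config file.
        assert write_config.exists()
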
sharktank/sharktank/examples/export_paged_llm_v1.py (+16 -3)

@@ -7,6 +7,7 @@
 """Export support for the PagedLLMV1 protocol of models."""

 import json
+from typing import Any, Dict
 import torch

 from iree.turbine.aot import *

@@ -86,17 +87,29 @@ def main():
     else:
         model = PagedLlamaModelV1(dataset.root_theta, llama_config)

-    def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
+    def generate_params_json(
+        hp: LlamaHParams, prefill_bs: list[int], decode_bs: list[int]
+    ) -> Dict[str, Any]:
+        """
+        Generate config.json for shortfin.
+
+
+        For shortfin, we only write attention_head_count_kv because that's all shortfin needs.
+        Note that this is different from hp.attn_head_count when grouped attention shares kvcache between heads.
+        """
         return {
             "module_name": "module",
             "module_abi_version": 1,
             "max_seq_len": hp.context_length,
-            "attn_head_count": hp.attention_head_count,
             "attn_head_dim": hp.attn_head_dim,
             "prefill_batch_sizes": prefill_bs,
             "decode_batch_sizes": decode_bs,
             "transformer_block_count": hp.block_count,
-            "block_seq_stride": llama_config.block_seq_stride,
+            "paged_kv_cache": {
+                "attention_head_count_kv": hp.attention_head_count_kv,
+                "block_seq_stride": llama_config.block_seq_stride,
+                "device_block_count": 256,  # so that this makes its way into the config file & can be edited.
+            },
         }

     # Unrolling cache updates by batch row makes dynamo sad without an

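With this change the exported config.json nests the cache parameters under paged_kv_cache and no longer writes a top-level attn_head_count. Roughly, the file now looks like the following for a hypothetical grouped-query model (numbers are made up; real values come from LlamaHParams and llama_config):

    import json

    # Illustrative only: a GQA model with 32 query heads sharing 8 KV heads.
    example_config = {
        "module_name": "module",
        "module_abi_version": 1,
        "max_seq_len": 131072,
        "attn_head_dim": 128,
        "prefill_batch_sizes": [1, 4],
        "decode_batch_sizes": [1, 4],
        "transformer_block_count": 32,
        "paged_kv_cache": {
            "attention_head_count_kv": 8,  # KV heads, not query heads
            "block_seq_stride": 16,
            "device_block_count": 256,  # written out so it can be hand-edited
        },
    }
    print(json.dumps(example_config, indent=2))

This is the same shape the test fixtures above now read back and patch, which is what removes the need for the hand-maintained configs in the conftest files.
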
shortfin/python/shortfin_apps/llm/components/config_struct.py (+21 -12)

@@ -11,20 +11,20 @@
 In a typical transformer model, the KV cache is organized similar to (mapped to
 our parameter names below):
     k = tensor.empty(transformer_block_count, batch_size, seq,
-                     attn_head_count, attn_head_dim)
+                     attn_head_count_kv, attn_head_dim)
     v = ...

 For context, a popular model has parameters of:
     attn_dtype_size = 2 # (fp16)
     max_seq_len = 2048
     transformer_block_count = 32
-    attn_head_count = 32
+    attn_head_count_kv = 32
     attn_head_dim = 128 # (dim / head_count)

 If paging, then we primarily care about the organization of a single block, where
 a block represents a single position in the sequence for a single item in the batch.
 Therefore, it will be organized like:
-    block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim)
+    block = torch.empty(transformer_block_count, 2, attn_head_count_kv, attn_head_dim)

 In this scenario, we declare that one block holds the KV cache for all transformer
 block layers because it reduces the accounting. As such, for the above example,

@@ -80,10 +80,15 @@ def _decode_dtype(name: str) -> sfnp.DType:
 class PagedKVCacheParams:
     """Parameters for the paged KV cache."""

-    # Position stride per attention block
+    # Tokens per page.
     block_seq_stride: int

+    # Number of attention heads per block. This can be different from the model's
+    # attention head count due to sharing.
+    attention_head_count_kv: int
+
     # Size of the cache on each device.
+    # Default: 256
     device_block_count: int

     prefix_sharing_algorithm: str = "none"  # currently supporting none and trie

@@ -92,19 +97,23 @@ class PagedKVCacheParams:
 @dataclass_json(undefined=Undefined.RAISE)
 @dataclass
 class ModelParams:
-    """Parameters for a specific compiled model, sufficient to do cache planning and
-    invocations."""
+    """
+    Parameters for a specific compiled model, sufficient to do cache planning and
+    invocations.
+
+    Compatibility should be maintained with function generate_params_json in
+
+    sharktank/sharktank/examples/export_paged_llm_v1.py
+    """

     # Maximum length of a sequence including prompt and output.
     max_seq_len: int

-    # Number of transformer blocks.
+    # Number of transformer layers (aka attention blocks / transformer blocks).
     transformer_block_count: int

-    # Number of attention heads per block.
-    attn_head_count: int
-
-    # Dimensionality of each attention head
+    # Dimensionality of each attention head. This is the dimensionality of the
+    # key and value vectors. AKA rope_dimension_count from the GGUF props.
     attn_head_dim: int

     # Batch sizes that the prefill stage is compiled for. These are expected to be

@@ -159,7 +168,7 @@ def paged_kv_unit_size_elements(self) -> int:
         size = 1
         size *= self.transformer_block_count
         size *= 2  # K and V cache line
-        size *= self.attn_head_count
+        size *= self.paged_kv_cache.attention_head_count_kv
         size *= self.attn_head_dim
         return size

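The updated paged_kv_unit_size_elements pulls the head count from paged_kv_cache.attention_head_count_kv. A quick back-of-the-envelope check with the example numbers from the module docstring (fp16, 32 transformer blocks, 32 KV heads, head dim 128); the final block_seq_stride scaling is an assumption about how a full page's size is derived elsewhere in the file, not something shown in this diff:

    # Per-position KV unit size, mirroring the updated property.
    transformer_block_count = 32
    attention_head_count_kv = 32
    attn_head_dim = 128
    attn_dtype_size = 2  # bytes per element at fp16

    unit_elements = transformer_block_count * 2 * attention_head_count_kv * attn_head_dim
    print(unit_elements)                    # 262144 elements per position
    print(unit_elements * attn_dtype_size)  # 524288 bytes = 512 KiB at fp16

    # Presumably a whole page scales this by block_seq_stride (tokens per page):
    block_seq_stride = 16
    print(unit_elements * attn_dtype_size * block_seq_stride)  # 8388608 bytes = 8 MiB per page

For a grouped-query model with fewer KV heads than query heads, the new field makes this estimate correspondingly smaller than the old attn_head_count-based one.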