diff --git a/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
index 1e1c64b24..7f822b0e0 100644
--- a/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
+++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
@@ -37,20 +37,6 @@ def pre_process_model(request, tmp_path_factory):
 
     export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)
 
-    config = {
-        "module_name": "module",
-        "module_abi_version": 1,
-        "max_seq_len": 131072,
-        "attn_head_count": 8,
-        "attn_head_dim": 128,
-        "prefill_batch_sizes": batch_sizes,
-        "decode_batch_sizes": batch_sizes,
-        "transformer_block_count": 32,
-        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
-    }
-    with open(config_path, "w") as file:
-        json.dump(config, file)
-
     compile_model(mlir_path, vmfb_path, settings)
 
     return tmp_dir
diff --git a/app_tests/integration_tests/llm/sglang/conftest.py b/app_tests/integration_tests/llm/sglang/conftest.py
index 8543708da..cc79fc365 100644
--- a/app_tests/integration_tests/llm/sglang/conftest.py
+++ b/app_tests/integration_tests/llm/sglang/conftest.py
@@ -64,21 +64,6 @@ def pre_process_model(request, tmp_path_factory):
         device_settings,
     )
 
-    config = {
-        "module_name": "module",
-        "module_abi_version": 1,
-        "max_seq_len": 131072,
-        "attn_head_count": 8,
-        "attn_head_dim": 128,
-        "prefill_batch_sizes": [1, 4],
-        "decode_batch_sizes": [1, 4],
-        "transformer_block_count": 32,
-        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
-    }
-    config_path = tmp_dir / "config.json"
-    with open(config_path, "w") as f:
-        json.dump(config, f)
-
     return tmp_dir
 
 
diff --git a/app_tests/integration_tests/llm/shortfin/conftest.py b/app_tests/integration_tests/llm/shortfin/conftest.py
index 0d40119c7..7619d3805 100644
--- a/app_tests/integration_tests/llm/shortfin/conftest.py
+++ b/app_tests/integration_tests/llm/shortfin/conftest.py
@@ -72,23 +72,6 @@ def model_test_dir(request, tmp_path_factory):
         vmfb_path = tmp_dir / "model.vmfb"
         compile_model(mlir_path, vmfb_path, settings)
 
-        # Write config
-        edited_config_path = tmp_dir / "edited_config.json"
-        config = {
-            "module_name": "module",
-            "module_abi_version": 1,
-            "max_seq_len": 2048,
-            "attn_head_count": 32,
-            "attn_head_dim": 100,
-            "prefill_batch_sizes": batch_sizes,
-            "decode_batch_sizes": batch_sizes,
-            "transformer_block_count": 26,
-            "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
-        }
-        logger.info(f"Saving edited config to: {edited_config_path}\n")
-        logger.info(f"Config: {json.dumps(config, indent=2)}")
-        with open(edited_config_path, "w") as f:
-            json.dump(config, f)
         logger.info("Model artifacts setup successfully" + end_log_group())
         yield hf_home, tmp_dir
     finally:
@@ -120,7 +103,7 @@ def llm_server(request, model_test_dir, available_port):
     settings = request.param["settings"]
 
     tokenizer_path = hf_home / "tokenizer.json"
-    config_path = tmp_dir / "edited_config.json"
+    config_path = tmp_dir / "config.json"
     vmfb_path = tmp_dir / "model.vmfb"
     parameters_path = hf_home / model_file
 
diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 6dd9785c3..7f35387ca 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -96,7 +96,11 @@ def generate_params_json(hp, prefill_bs: list[int], decode_bs: list[int]):
         "prefill_batch_sizes": prefill_bs,
         "decode_batch_sizes": decode_bs,
         "transformer_block_count": hp.block_count,
-        "block_seq_stride": llama_config.block_seq_stride,
+        "paged_kv_cache": {
+            "attention_head_count_kv": hp.attention_head_count_kv,
+            "block_seq_stride": llama_config.block_seq_stride,
+            "device_block_count": 256,  # so that this makes its way into the config file and can be edited
+        },
     }
 
 # Unrolling cache updates by batch row makes dynamo sad without an
diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py
index 141c7a7eb..34a3386e9 100644
--- a/shortfin/python/shortfin_apps/llm/components/config_struct.py
+++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -11,20 +11,20 @@
 In a typical transformer model, the KV cache is organized similar to (mapped to
 our parameter names below):
     k = tensor.empty(transformer_block_count, batch_size, seq,
-                     attn_head_count, attn_head_dim)
+                     attn_head_count_kv, attn_head_dim)
     v = ...
 
 For context, a popular model has parameters of:
     attn_dtype_size = 2  # (fp16)
     max_seq_len = 2048
     transformer_block_count = 32
-    attn_head_count = 32
+    attn_head_count_kv = 32
     attn_head_dim = 128  # (dim / head_count)
 
 If paging, then we primarily care about the organization of a single block, where
 a block represents a single position in the sequence for a single item in the
 batch. Therefore, it will be organized like:
 
-    block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim)
+    block = torch.empty(transformer_block_count, 2, attn_head_count_kv, attn_head_dim)
 In this scenario, we declare that one block holds the KV cache for all transformer
 block layers because it reduces the accounting. As such, for the above example,
@@ -80,6 +80,7 @@ def _decode_dtype(name: str) -> sfnp.DType:
 class PagedKVCacheParams:
     """Parameters for the paged KV cache."""
 
+    attention_head_count_kv: int
     # Position stride per attention block
     block_seq_stride: int
 
@@ -90,8 +91,13 @@
 @dataclass_json(undefined=Undefined.RAISE)
 @dataclass
 class ModelParams:
-    """Parameters for a specific compiled model, sufficient to do cache planning and
-    invocations."""
+    """
+    Parameters for a specific compiled model, sufficient to do cache planning and
+    invocations.
+
+    Compatibility should be maintained with the generate_params_json function in
+    sharktank/sharktank/examples/export_paged_llm_v1.py.
+    """
 
     # Maximum length of a sequence including prompt and output.
     max_seq_len: int
@@ -157,7 +163,7 @@ def paged_kv_unit_size_elements(self) -> int:
         size = 1
         size *= self.transformer_block_count
         size *= 2  # K and V cache line
-        size *= self.attn_head_count
+        size *= self.paged_kv_cache.attention_head_count_kv
         size *= self.attn_head_dim
         return size
 