Commit 2f5bfab
Use Trie and Base Caching in SGL Benchmark (#658)
# Description

Previously, our SGLang benchmark tests only benchmarked against the `base` caching algorithm. This PR lets us benchmark against both the `base` and `trie` caching algorithms. It also applies the same trick from #655 to reuse model artifacts instead of regenerating them on every run, which drastically reduces the overall time required to run the benchmark test. Finally, it includes a patch so the benchmark script still runs after today's sync of the SGLang repo, and ensures that an error is raised when one occurs.
1 parent 7f6de06 commit 2f5bfab
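
For context, the artifact-reuse trick borrowed from #655 amounts to memoizing a module-scoped pytest fixture on a key derived from its parameters, so repeated parametrizations share one set of compiled artifacts. A minimal sketch of the pattern (the names here are illustrative, not the benchmark's actual helpers):

import hashlib

import pytest

_ARTIFACT_CACHE = {}  # param-key -> directory built by an earlier parametrization


@pytest.fixture(scope="module")
def model_artifacts(request, tmp_path_factory):
    # Stable key: hash the stringified fixture parameters.
    key = hashlib.md5(str(request.param).encode()).hexdigest()
    if (cached := _ARTIFACT_CACHE.get(key)) is not None:
        return cached  # skip the expensive export/compile steps
    out_dir = tmp_path_factory.mktemp("artifacts")
    # ... export and compile model artifacts into out_dir ...
    _ARTIFACT_CACHE[key] = out_dir
    return out_dir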

4 files changed: +83 −15 lines changed

app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py

Lines changed: 49 additions & 5 deletions
@@ -4,7 +4,9 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+import hashlib
 import json
+import logging
 import os
 import pytest
 import sys
@@ -14,15 +16,32 @@
 )
 from integration_tests.llm.utils import (
     compile_model,
+    end_log_group,
     export_paged_llm_v1,
     download_with_hf_datasets,
+    start_log_group,
 )
 
+logger = logging.getLogger(__name__)
+
+MODEL_DIR_CACHE = {}
+
 
 @pytest.fixture(scope="module")
 def pre_process_model(request, tmp_path_factory):
     tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")
 
+    logger.info(
+        "Preparing model artifacts..." + start_log_group("Preparing model artifacts")
+    )
+
+    param_key = hashlib.md5(str(request.param).encode()).hexdigest()
+    if (directory := MODEL_DIR_CACHE.get(param_key)) is not None:
+        logger.info(
+            f"Reusing existing model artifacts directory: {directory}" + end_log_group()
+        )
+        return MODEL_DIR_CACHE[param_key]
+
     model_name = request.param["model_name"]
     model_param_file_name = request.param["model_param_file_name"]
     settings = request.param["settings"]
@@ -37,6 +56,25 @@ def pre_process_model(request, tmp_path_factory):
 
     export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)
 
+    compile_model(mlir_path, vmfb_path, settings)
+
+    logger.info("Model artifacts setup successfully" + end_log_group())
+    MODEL_DIR_CACHE[param_key] = tmp_dir
+    return tmp_dir
+
+
+@pytest.fixture(scope="module")
+def write_config(request, pre_process_model):
+    batch_sizes = request.param["batch_sizes"]
+    prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]
+
+    logger.info("Writing config file..." + start_log_group("Writing config file"))
+
+    config_path = (
+        pre_process_model
+        / f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json"
+    )
+
     config = {
         "module_name": "module",
         "module_abi_version": 1,
@@ -46,14 +84,20 @@ def pre_process_model(request, tmp_path_factory):
         "prefill_batch_sizes": batch_sizes,
         "decode_batch_sizes": batch_sizes,
         "transformer_block_count": 32,
-        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
+        "paged_kv_cache": {
+            "block_seq_stride": 16,
+            "device_block_count": 256,
+            "prefix_sharing_algorithm": prefix_sharing_algorithm,
+        },
     }
-    with open(config_path, "w") as file:
-        json.dump(config, file)
 
-    compile_model(mlir_path, vmfb_path, settings)
+    logger.info(f"Saving edited config to: {config_path}\n")
+    logger.info(f"Config: {json.dumps(config, indent=2)}")
+    with open(config_path, "w") as f:
+        json.dump(config, f)
 
-    return tmp_dir
+    logger.info("Config file successfully written" + end_log_group())
+    yield config_path
 
 
 def pytest_addoption(parser):
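
Note how the config filename doubles as metadata: `write_config` encodes the batch sizes and prefix-sharing algorithm into the name, and the benchmark test later recovers the algorithm from the path stem. A small sketch of that roundtrip (the temp path is illustrative):

from pathlib import Path

batch_sizes = [1, 4]
prefix_sharing_algorithm = "trie"

# write_config builds e.g. "1_4_trie.json" under the model artifacts directory
config_path = Path("/tmp/sglang_benchmark_test0") / (
    f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json"
)
print(config_path.name)                 # 1_4_trie.json
print(config_path.stem.split("_")[-1])  # trie  (recovered by the benchmark test)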

app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py

Lines changed: 1 addition & 0 deletions
@@ -65,3 +65,4 @@ def test_sglang_benchmark(request_rate, tokenizer_id, sglang_args, tmp_path_fact
         log_jsonl_result(benchmark_args.output_file)
     except Exception as e:
         logger.error(e)
+        raise e
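
This one-line change matters: without the re-raise, a failing benchmark was logged but the test still passed. A minimal sketch of the log-and-reraise pattern, noting that a bare `raise` is the more idiomatic alternative:

import logging

logging.basicConfig()
logger = logging.getLogger(__name__)


def run():
    raise ValueError("benchmark failed")


try:
    run()
except Exception as e:
    logger.error(e)
    # Bare `raise` re-raises the active exception with its original traceback;
    # `raise e` (as in the diff) also works but appends the re-raise site as an
    # extra traceback entry.
    raise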

app_tests/benchmark_tests/llm/sglang_benchmarks/shortfin_benchmark_test.py

Lines changed: 31 additions & 10 deletions
@@ -21,8 +21,10 @@
 )
 
 from integration_tests.llm.utils import (
+    end_log_group,
     find_available_port,
     start_llm_server,
+    start_log_group,
 )
 
 logger = logging.getLogger(__name__)
@@ -44,26 +46,39 @@
     ],
 )
 @pytest.mark.parametrize(
-    "pre_process_model",
+    "pre_process_model,write_config",
     [
-        (
+        pytest.param(
             {
                 "model_name": "llama3_8B_fp16",
                 "model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
                 "settings": device_settings,
                 "batch_sizes": [1, 4],
-            }
-        )
+            },
+            {"batch_sizes": [1, 4], "prefix_sharing_algorithm": "none"},
+        ),
+        pytest.param(
+            {
+                "model_name": "llama3_8B_fp16",
+                "model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
+                "settings": device_settings,
+                "batch_sizes": [1, 4],
+            },
+            {"batch_sizes": [1, 4], "prefix_sharing_algorithm": "trie"},
+        ),
     ],
     indirect=True,
 )
-def test_shortfin_benchmark(request_rate, model_param_file_name, pre_process_model):
+def test_shortfin_benchmark(
+    request_rate, model_param_file_name, pre_process_model, write_config
+):
     # TODO: Remove when multi-device is fixed
     os.environ["ROCR_VISIBLE_DEVICES"] = "1"
 
     tmp_dir = pre_process_model
 
-    config_path = tmp_dir / "config.json"
+    config_path = write_config
+    prefix_sharing_algorithm = config_path.stem.split("_")[-1]
     vmfb_path = tmp_dir / "model.vmfb"
     tokenizer_path = tmp_dir / "tokenizer.json"
     model_path = tmp_dir / model_param_file_name
@@ -90,12 +105,17 @@ def test_shortfin_benchmark(request_rate, model_param_file_name, pre_process_mod
     )
     output_file = (
         tmp_dir
-        / f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl"
+        / f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}_{prefix_sharing_algorithm}.jsonl"
     )
     benchmark_args.output_file = output_file
 
-    logger.info("Running SGLang Benchmark with the following args:")
-    logger.info(benchmark_args)
+    logger.info(
+        f"Starting benchmark run with prefix sharing algorithm {prefix_sharing_algorithm}..."
+        + start_log_group(f"Benchmark run with {prefix_sharing_algorithm} algorithm")
+    )
+    logger.info("Running SGLang Benchmark with the following settings:")
+    logger.info(f"Prefix sharing algorithm: {prefix_sharing_algorithm}")
+    logger.info(f"Benchmark Args: {benchmark_args}")
     try:
         start = time.time()
         with patch.object(bench_serving, "print", side_effect=logger.info):
@@ -107,8 +127,9 @@ def test_shortfin_benchmark(request_rate, model_param_file_name, pre_process_mod
         benchmark_process.join()
 
         logger.info(f"Benchmark run completed in {str(time.time() - start)} seconds")
-        logger.info("======== RESULTS ========")
+        logger.info("\n\n======== RESULTS ========")
         log_jsonl_result(benchmark_args.output_file)
+        logger.info("Benchmark run successful" + end_log_group())
     except Exception as e:
         logger.error(e)
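
The pairing above works through pytest's indirect parametrization: naming two fixtures in one `parametrize` string routes each element of a `pytest.param(...)` tuple to the matching fixture's `request.param`. A self-contained sketch of the mechanism:

import pytest


@pytest.fixture(scope="module")
def build(request):
    return f"built:{request.param['model']}"


@pytest.fixture(scope="module")
def config(request):
    return f"config:{request.param['algo']}"


@pytest.mark.parametrize(
    "build,config",
    [
        pytest.param({"model": "llama3"}, {"algo": "none"}),
        pytest.param({"model": "llama3"}, {"algo": "trie"}),
    ],
    indirect=True,  # each tuple element becomes request.param of its fixture
)
def test_pairing(build, config):
    assert build == "built:llama3"
    assert config in ("config:none", "config:trie")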

app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@ def as_namespace(self) -> Namespace:
             disable_tqdm=False,
             disable_stream=False,
             disable_ignore_eos=False,
+            lora_name=None,
+            profile=False,
         )
 
     def __repr__(self):
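
These two defaults keep the wrapper's argument namespace in step with the freshly synced SGLang `bench_serving` script, which now reads `lora_name` and `profile` from its args. A minimal sketch of the idea (attribute values are the new defaults; the surrounding setup is illustrative):

from argparse import Namespace

# The benchmark wrapper hands bench_serving a Namespace that mimics its CLI
# arguments; after the SGLang sync, two more attributes are expected.
args = Namespace(
    disable_tqdm=False,
    disable_stream=False,
    disable_ignore_eos=False,
    lora_name=None,   # no LoRA adapter selected
    profile=False,    # profiling disabled
)

assert args.lora_name is None and args.profile is False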
