Use Trie and Base Caching in SGL Benchmark (#658)
# Description

Previously, we only benchmarked against the `base` caching algorithm when running
our SGLang benchmark tests. This PR allows us to benchmark against both the
`base` and `trie` caching algorithms.
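
For context, the algorithm is selected through a new `prefix_sharing_algorithm` field in the `paged_kv_cache` section of the server config (see the `write_config` fixture below); in the test parameters, `"none"` corresponds to base caching and `"trie"` to trie caching. A minimal sketch of that fragment, with values taken from the fixture and the wrapping helper purely illustrative:

```python
# Sketch of the config fragment that selects the caching algorithm.
# Field names mirror the write_config fixture in this PR; the helper itself is illustrative.
def paged_kv_cache_config(prefix_sharing_algorithm: str) -> dict:
    assert prefix_sharing_algorithm in ("none", "trie")  # "none" == base caching
    return {
        "block_seq_stride": 16,
        "device_block_count": 256,
        "prefix_sharing_algorithm": prefix_sharing_algorithm,
    }
```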

Along with that, it applies the same trick from #655 to reuse model artifacts
instead of generating new ones each time, which drastically reduces the overall
time required to run the benchmark test.
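
The reuse works by keying a module-level cache on a hash of the fixture parameters, so identical parametrizations share one artifact directory. A minimal sketch of the pattern (the expensive download/export/compile steps are elided; names mirror the `conftest.py` changes below):

```python
import hashlib

# Module-level cache: hash of the fixture params -> prepared artifact directory.
MODEL_DIR_CACHE = {}


def get_or_prepare_artifacts(params: dict, tmp_path_factory):
    param_key = hashlib.md5(str(params).encode()).hexdigest()
    if (directory := MODEL_DIR_CACHE.get(param_key)) is not None:
        return directory  # reuse artifacts generated by an earlier parametrization
    tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")
    # ... download the model, export MLIR, and compile the VMFB into tmp_dir ...
    MODEL_DIR_CACHE[param_key] = tmp_dir
    return tmp_dir
```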

It also includes a patch needed to run the benchmark script after today's sync
of the SGLang repo, and ensures that an error is raised when one occurs.
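
The error-handling change boils down to re-raising after logging, so a failed benchmark run fails the test rather than passing silently (sketch; `run_benchmark` is a placeholder):

```python
try:
    run_benchmark()  # placeholder for the actual benchmark invocation
except Exception as e:
    logger.error(e)
    raise
```
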
stbaione authored Dec 9, 2024
1 parent 7f6de06 commit 2f5bfab
Showing 4 changed files with 83 additions and 15 deletions.
54 changes: 49 additions & 5 deletions app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
@@ -4,7 +4,9 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import hashlib
import json
import logging
import os
import pytest
import sys
@@ -14,15 +16,32 @@
)
from integration_tests.llm.utils import (
compile_model,
end_log_group,
export_paged_llm_v1,
download_with_hf_datasets,
start_log_group,
)

logger = logging.getLogger(__name__)

MODEL_DIR_CACHE = {}


@pytest.fixture(scope="module")
def pre_process_model(request, tmp_path_factory):
tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")

logger.info(
"Preparing model artifacts..." + start_log_group("Preparing model artifacts")
)

param_key = hashlib.md5(str(request.param).encode()).hexdigest()
if (directory := MODEL_DIR_CACHE.get(param_key)) is not None:
logger.info(
f"Reusing existing model artifacts directory: {directory}" + end_log_group()
)
return MODEL_DIR_CACHE[param_key]

model_name = request.param["model_name"]
model_param_file_name = request.param["model_param_file_name"]
settings = request.param["settings"]
@@ -37,6 +56,25 @@ def pre_process_model(request, tmp_path_factory):

export_paged_llm_v1(mlir_path, config_path, model_path, batch_sizes)

compile_model(mlir_path, vmfb_path, settings)

logger.info("Model artifacts setup successfully" + end_log_group())
MODEL_DIR_CACHE[param_key] = tmp_dir
return tmp_dir


@pytest.fixture(scope="module")
def write_config(request, pre_process_model):
batch_sizes = request.param["batch_sizes"]
prefix_sharing_algorithm = request.param["prefix_sharing_algorithm"]

logger.info("Writing config file..." + start_log_group("Writing config file"))

config_path = (
pre_process_model
/ f"{'_'.join(str(bs) for bs in batch_sizes)}_{prefix_sharing_algorithm}.json"
)

config = {
"module_name": "module",
"module_abi_version": 1,
@@ -46,14 +84,20 @@ def pre_process_model(request, tmp_path_factory):
"prefill_batch_sizes": batch_sizes,
"decode_batch_sizes": batch_sizes,
"transformer_block_count": 32,
"paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
"paged_kv_cache": {
"block_seq_stride": 16,
"device_block_count": 256,
"prefix_sharing_algorithm": prefix_sharing_algorithm,
},
}
with open(config_path, "w") as file:
json.dump(config, file)

compile_model(mlir_path, vmfb_path, settings)
logger.info(f"Saving edited config to: {config_path}\n")
logger.info(f"Config: {json.dumps(config, indent=2)}")
with open(config_path, "w") as f:
json.dump(config, f)

return tmp_dir
logger.info("Config file successfully written" + end_log_group())
yield config_path


def pytest_addoption(parser):
@@ -65,3 +65,4 @@ def test_sglang_benchmark(request_rate, tokenizer_id, sglang_args, tmp_path_fact
log_jsonl_result(benchmark_args.output_file)
except Exception as e:
logger.error(e)
raise e
@@ -21,8 +21,10 @@
)

from integration_tests.llm.utils import (
end_log_group,
find_available_port,
start_llm_server,
start_log_group,
)

logger = logging.getLogger(__name__)
@@ -44,26 +46,39 @@
],
)
@pytest.mark.parametrize(
"pre_process_model",
"pre_process_model,write_config",
[
(
pytest.param(
{
"model_name": "llama3_8B_fp16",
"model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
"settings": device_settings,
"batch_sizes": [1, 4],
}
)
},
{"batch_sizes": [1, 4], "prefix_sharing_algorithm": "none"},
),
pytest.param(
{
"model_name": "llama3_8B_fp16",
"model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
"settings": device_settings,
"batch_sizes": [1, 4],
},
{"batch_sizes": [1, 4], "prefix_sharing_algorithm": "trie"},
),
],
indirect=True,
)
def test_shortfin_benchmark(request_rate, model_param_file_name, pre_process_model):
def test_shortfin_benchmark(
request_rate, model_param_file_name, pre_process_model, write_config
):
# TODO: Remove when multi-device is fixed
os.environ["ROCR_VISIBLE_DEVICES"] = "1"

tmp_dir = pre_process_model

config_path = tmp_dir / "config.json"
config_path = write_config
prefix_sharing_algorithm = config_path.stem.split("_")[-1]
vmfb_path = tmp_dir / "model.vmfb"
tokenizer_path = tmp_dir / "tokenizer.json"
model_path = tmp_dir / model_param_file_name
@@ -90,12 +105,17 @@ def test_shortfin_benchmark(request_rate, model_param_file_name, pre_process_mod
)
output_file = (
tmp_dir
/ f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl"
/ f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}_{prefix_sharing_algorithm}.jsonl"
)
benchmark_args.output_file = output_file

logger.info("Running SGLang Benchmark with the following args:")
logger.info(benchmark_args)
logger.info(
f"Starting benchmark run with prefix sharing algorith {prefix_sharing_algorithm}..."
+ start_log_group(f"Benchmark run with {prefix_sharing_algorithm} algorithm")
)
logger.info("Running SGLang Benchmark with the following settings:")
logger.info(f"Prefix sharing algorith: {prefix_sharing_algorithm}")
logger.info(f"Benchmark Args: {benchmark_args}")
try:
start = time.time()
with patch.object(bench_serving, "print", side_effect=logger.info):
@@ -107,8 +127,9 @@ def test_shortfin_benchmark(request_rate, model_param_file_name, pre_process_mod
benchmark_process.join()

logger.info(f"Benchmark run completed in {str(time.time() - start)} seconds")
logger.info("======== RESULTS ========")
logger.info("\n\n======== RESULTS ========")
log_jsonl_result(benchmark_args.output_file)
logger.info("Benchmark run successful" + end_log_group())
except Exception as e:
logger.error(e)

2 changes: 2 additions & 0 deletions app_tests/benchmark_tests/llm/sglang_benchmarks/utils.py
@@ -48,6 +48,8 @@ def as_namespace(self) -> Namespace:
disable_tqdm=False,
disable_stream=False,
disable_ignore_eos=False,
lora_name=None,
profile=False,
)

def __repr__(self):
