Commit ac17f86

SGLang Integration + Accuracy Tests, Restructure app_tests/integration_tests (#570)
# Description

This PR implements integration tests for the Shortfin LLM Server with the SGLang integration. It uses llama3-8b-instruct on GPU, downloaded using sharktank's `hf_datasets` script. The tests serve two purposes:

1. Test that the SGLang integration works properly at a functional level.
2. Test that the accuracy of responses from the Shortfin LLM server stays consistent.
   - We have a batch of candidate questions with expected answers.
   - We have temperature set to `1.0`, so the responses should be deterministic.

This test is intended to run every 4 hours, which allows us to detect degradations in Shortfin LLM output accuracy early. If we do get a failure due to an accuracy degradation, only a small set of shark-ai/iree commits could be responsible.
1 parent a7feae8 commit ac17f86
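The accuracy comparison itself is not part of this diff, but the commit installs `sentence_transformers` and the conftest below loads the `all-MiniLM-L6-v2` model, which points at an embedding-similarity check. A minimal sketch of that idea, with a hypothetical question, expected answer, and pass threshold (none of these values come from the commit):

```python
from sentence_transformers import SentenceTransformer, util

# Embed the expected and served answers, then compare with cosine similarity.
# A score near 1.0 means the served answer is semantically consistent with
# the expectation, even if the wording differs slightly.
model = SentenceTransformer("all-MiniLM-L6-v2")
expected = "The capital of the United States is Washington, D.C."
served = "Washington, D.C. is the capital of the United States."

embeddings = model.encode([expected, served])
score = util.cos_sim(embeddings[0], embeddings[1]).item()
assert score > 0.7  # hypothetical pass threshold
```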

File tree

13 files changed: +575 −3 lines changed

Lines changed: 84 additions & 0 deletions
```yaml
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

name: SGLang Llama Integration Tests

on:
  workflow_dispatch:
  schedule:
    # Run periodically, every 4 hours. This is run with the intent of
    # catching regressions early, and allowing for those regressions to
    # be easily triaged to a small subset of commits.
    - cron: '0 */4 * * *'

concurrency:
  # A PR number if a pull request and otherwise the commit hash. This cancels
  # queued and in-progress runs for the same PR (presubmit) or commit
  # (postsubmit). The workflow name is prepended to avoid conflicts between
  # different workflows.
  group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
  cancel-in-progress: true

jobs:
  sglang_bench_serve:
    name: "SGLang Integration Tests"
    strategy:
      matrix:
        version: [3.11]
      fail-fast: false
    runs-on: llama-mi300x-3
    defaults:
      run:
        shell: bash
    env:
      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
    steps:
      - name: Get Current Date
        id: date
        run: echo "::set-output name=date::$(date +'%Y-%m-%d')"

      - name: "Setting up Python"
        id: setup_python
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{matrix.version}}

      - name: "Checkout Code"
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Cache Pip Packages
        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
        id: cache-pip
        with:
          path: ${{ env.PIP_CACHE_DIR }}
          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}

      - name: Install pip deps
        run: |
          python -m pip install --no-compile --upgrade pip
          # Note: We install in three steps in order to satisfy requirements
          # from non-default locations first. Installing the PyTorch CPU
          # wheels saves multiple minutes and a lot of bandwidth on runner setup.
          pip install --no-compile -r pytorch-cpu-requirements.txt
          pip install --no-compile -f https://iree.dev/pip-release-links.html --src deps \
            -e "git+https://github.com/iree-org/iree-turbine.git#egg=iree-turbine"
          pip install --no-compile -r requirements.txt -e sharktank/ shortfin/

          # Use newest possible releases to be able to track commits that may
          # cause errors.
          pip install -f https://iree.dev/pip-release-links.html --upgrade \
            iree-base-compiler \
            iree-base-runtime \
            "numpy<2.0"

      - name: Install SGLang
        run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"

      - name: Install sentence_transformers
        run: pip install sentence_transformers

      - name: Run Integration Tests
        run: pytest -v app_tests/integration_tests/llm/sglang --log-cli-level=INFO
```

.github/workflows/ci-shark-ai.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -72,4 +72,4 @@ jobs:
           iree-base-runtime

       - name: Run LLM Integration Tests
-        run: pytest -v app_tests/integration_tests/llm --log-cli-level=INFO
+        run: pytest -v app_tests/integration_tests/llm/shortfin --log-cli-level=INFO
```

app_tests/__init__.py

Lines changed: 5 additions & 0 deletions
```python
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
```

app_tests/benchmark_tests/__init__.py

Lines changed: 5 additions & 0 deletions
```python
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
```
Lines changed: 5 additions & 0 deletions
```python
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
```
Lines changed: 5 additions & 0 deletions
```python
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
```
Lines changed: 5 additions & 0 deletions
```python
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
```
Lines changed: 123 additions & 0 deletions
```python
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import json
import logging
import os
import pytest

from ..utils import (
    find_available_port,
    start_llm_server,
    download_with_hf_datasets,
    export_paged_llm_v1,
    compile_model,
)

pytest.importorskip("sglang")
import sglang as sgl
from sglang.lang.chat_template import get_chat_template

pytest.importorskip("sentence_transformers")
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


@pytest.fixture(scope="module")
def register_shortfin_backend(available_port):
    backend = sgl.Shortfin(
        chat_template=get_chat_template("llama-3-instruct"),
        base_url=f"http://localhost:{available_port}",
    )
    sgl.set_default_backend(backend)


@pytest.fixture(scope="module")
def pre_process_model(request, tmp_path_factory):
    device_settings = request.param["device_settings"]
    tmp_dir = tmp_path_factory.mktemp("sglang_integration_tests")

    # Download model
    model_params_path = tmp_dir / "meta-llama-3.1-8b-instruct.f16.gguf"
    download_with_hf_datasets(tmp_dir, "llama3_8B_fp16")

    # Export to mlir
    mlir_path = tmp_dir / "model.mlir"
    config_path = tmp_dir / "config.json"
    batch_sizes = [1, 4]
    export_paged_llm_v1(
        mlir_path,
        config_path,
        model_params_path,
        batch_sizes,
    )

    # Compile Model
    vmfb_path = tmp_dir / "model.vmfb"
    compile_model(
        mlir_path,
        vmfb_path,
        device_settings,
    )

    config = {
        "module_name": "module",
        "module_abi_version": 1,
        "max_seq_len": 131072,
        "attn_head_count": 8,
        "attn_head_dim": 128,
        "prefill_batch_sizes": [1, 4],
        "decode_batch_sizes": [1, 4],
        "transformer_block_count": 32,
        "paged_kv_cache": {"block_seq_stride": 16, "device_block_count": 256},
    }
    config_path = tmp_dir / "config.json"
    with open(config_path, "w") as f:
        json.dump(config, f)

    return tmp_dir


@pytest.fixture(scope="module")
def available_port():
    return find_available_port()


@pytest.fixture(scope="module")
def start_server(request, pre_process_model, available_port):
    os.environ["ROCR_VISIBLE_DEVICES"] = "1"
    device_settings = request.param["device_settings"]

    export_dir = pre_process_model

    tokenizer_path = export_dir / "tokenizer.json"
    model_params_path = export_dir / "meta-llama-3.1-8b-instruct.f16.gguf"
    vmfb_path = export_dir / "model.vmfb"
    config_path = export_dir / "config.json"

    logger.info("Starting server...")
    server_process = start_llm_server(
        available_port,
        tokenizer_path,
        config_path,
        vmfb_path,
        model_params_path,
        device_settings,
        timeout=30,
    )
    logger.info("Server started")

    yield server_process

    server_process.terminate()
    server_process.wait()


@pytest.fixture(scope="module")
def load_comparison_model():
    model = SentenceTransformer("all-MiniLM-L6-v2")
    return model
```
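Since `pre_process_model` and `start_server` both read `request.param`, they are designed to be driven through pytest's indirect parametrization. A hypothetical sketch of how a test might consume these fixtures with the SGLang frontend — the device settings, question, expected answer, and threshold are made up for illustration; only the fixture names and the SGLang/sentence_transformers APIs come from this commit:

```python
import pytest
import sglang as sgl
from sentence_transformers import util

# Hypothetical device settings; the real values live in the suite's shared
# utilities and target the MI300x runner used by the workflow above.
GPU_SETTINGS = {"device_settings": {"device_flags": ["--iree-hal-target-device=hip"]}}


@sgl.function
def ask(s, question):
    # Single-turn chat against the default backend (Shortfin, registered by
    # the register_shortfin_backend fixture).
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=50, temperature=1.0))


@pytest.mark.parametrize(
    "pre_process_model,start_server",
    [(GPU_SETTINGS, GPU_SETTINGS)],
    indirect=True,
)
def test_accuracy(register_shortfin_backend, start_server, load_comparison_model):
    state = ask.run(question="What is the capital of the United States?")
    answer = state["answer"]

    # Score the served answer against the expected one by cosine similarity
    # of their sentence embeddings.
    embeddings = load_comparison_model.encode(["Washington, D.C.", answer])
    score = util.cos_sim(embeddings[0], embeddings[1]).item()
    assert score > 0.7  # hypothetical threshold
```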
