Shortfin LLM Direct-to-batcher tests (#987)
These tests skip everything before the batcher and submit batches of
requests to it directly, to make sure that requests in a batch do not
interfere with each other.
renxida authored Feb 26, 2025
1 parent aa470cc commit 06f5b2a
Showing 5 changed files with 209 additions and 11 deletions.
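The commit message above describes the approach; below is a condensed sketch of the direct-to-batcher pattern the new test exercises. All names are taken from the diffs that follow; obtaining a running GenerateService and a token list is assumed to happen via the fixtures added in conftest.py.

# Condensed sketch of the direct-to-batcher pattern (names from the diffs below).
import numpy as np
import shortfin as sf

from shortfin_apps.llm.components.messages import InferencePhase, InferenceExecRequest


class DirectPrefillSketch(sf.Process):
    """Submits one prefill request straight to the batcher, bypassing the HTTP layer."""

    def __init__(self, service, tokens):
        # InferenceExecRequest uses shortfin.VoidFuture, which can only be created
        # on a process, hence the sf.Process wrapper around the service's main fiber.
        super().__init__(fiber=service.main_fiber)
        self.service = service
        self.tokens = tokens
        self.first_token = None

    async def run(self):
        req = InferenceExecRequest(
            phase=InferencePhase.PREFILL,
            input_token_ids=self.tokens,
            rid="direct-prefill-sketch",
        )
        req.return_host_array = True
        self.service.batcher.submit(req)  # hand the request directly to the batcher
        await req.done
        self.first_token = np.argmax(req.result_logits.items)
        req.free_cache_pages()


# Typical driving code, given a service from the generate_service fixture:
# proc = DirectPrefillSketch(generate_service, tokens=[1, 2, 3, 4])
# proc.launch()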
48 changes: 48 additions & 0 deletions .github/workflows/pkgci_shark_ai.yml
@@ -66,6 +66,54 @@ jobs:
name: smoke-test-${{ matrix.name }}
path: smoke-test-${{ matrix.name }}.xml


direct_to_batcher_test:
name: "Direct to Batcher Test (${{ matrix.name }})"
runs-on: ${{ matrix.runs-on }}
strategy:
fail-fast: false
matrix:
include:
- name: cpu
runs-on: azure-cpubuilder-linux-scale
test_device: cpu
python-version: 3.11
- name: amdgpu_rocm_mi300_gfx942
runs-on: linux-mi300-1gpu-ossci
test_device: gfx942
python-version: 3.11
defaults:
run:
shell: bash
env:
PACKAGE_DOWNLOAD_DIR: ${{ github.workspace }}/.packages
VENV_DIR: ${{ github.workspace }}/.venv
steps:
- name: Run rocminfo
if: contains(matrix.test_device, 'gfx')
run: rocminfo
- name: "Checkout Code"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: "Set up environment and install PkgCI Artifacts"
uses: ./.github/actions/pkgci-setup
with:
python-version: ${{matrix.python-version}}
artifact-run-id: ${{ inputs.artifact_run_id }}
- name: Run Direct-to-batcher Test
run: |
source ${VENV_DIR}/bin/activate
pytest -v -s --test_device=${{ matrix.test_device }} \
--junitxml=direct-to-batcher-test-${{ matrix.name }}.xml \
app_tests/integration_tests/llm/shortfin/direct_to_batcher_test.py \
--log-cli-level=INFO
- name: Upload Test Results
if: always()
uses: actions/upload-artifact@v4
with:
name: direct-to-batcher-test-${{ matrix.name }}
path: direct-to-batcher-test-${{ matrix.name }}.xml


integration_test:
name: "Integration Test (${{ matrix.name }})"
runs-on: ${{ matrix.runs-on }}
47 changes: 38 additions & 9 deletions app_tests/integration_tests/llm/server_management.py
@@ -10,6 +10,8 @@

from .device_settings import DeviceSettings
from .model_management import ModelArtifacts
from shortfin_apps.llm.components.service import GenerateService
from contextlib import contextmanager


@dataclass
@@ -58,6 +60,41 @@ def find_available_port() -> int:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]

def get_server_args(self) -> list[str]:
"""Returns the command line arguments to start the server."""
argv = [
f"--tokenizer_json={self.config.artifacts.tokenizer_path}",
f"--model_config={self.config.artifacts.config_path}",
f"--vmfb={self.config.artifacts.vmfb_path}",
f"--parameters={self.config.artifacts.weights_path}",
f"--port={self.port}",
f"--prefix_sharing_algorithm={self.config.prefix_sharing_algorithm}",
]
argv.extend(self.config.device_settings.server_flags)
return argv

@contextmanager
def start_service_only(self) -> GenerateService:
"""Starts a server with only the shortfin_apps.llm.components.serivce.GenerateService."""

argv = self.get_server_args()
from shortfin_apps.llm.server import parse_args

args = parse_args(argv)
if args.tokenizer_config_json is None:
# this is only used for the EOS token
inferred_tokenizer_config_path = args.tokenizer_json.with_name(
args.tokenizer_json.stem + "_config.json"
)
args.tokenizer_config_json = inferred_tokenizer_config_path

from shortfin_apps.llm.components.lifecycle import ShortfinLlmLifecycleManager

lifecycle_manager = ShortfinLlmLifecycleManager(args)

with lifecycle_manager:
yield lifecycle_manager.services["default"]

def start(self) -> None:
"""Starts the server process."""
if self.process is not None:
@@ -69,15 +106,7 @@ def start(self) -> None:
sys.executable,
"-m",
"shortfin_apps.llm.server",
f"--tokenizer_json={self.config.artifacts.tokenizer_path}",
f"--model_config={self.config.artifacts.config_path}",
f"--vmfb={self.config.artifacts.vmfb_path}",
f"--parameters={self.config.artifacts.weights_path}",
f"--port={self.port}",
f"--prefix_sharing_algorithm={self.config.prefix_sharing_algorithm}",
]
cmd.extend(self.config.device_settings.server_flags)

] + self.get_server_args()
self.process = subprocess.Popen(cmd)
self.wait_for_ready()

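The start_service_only context manager above yields a GenerateService directly, without launching the HTTP server. A minimal sketch of driving it outside of pytest, assuming model_artifacts has already been produced by the existing model_management helpers (the generate_service fixture in the next file does essentially this):

# Minimal sketch: obtain a GenerateService without the HTTP server.
from app_tests.integration_tests.llm.server_management import ServerConfig, ServerInstance

config = ServerConfig(
    artifacts=model_artifacts,
    device_settings=model_artifacts.model_config.device_settings,
    prefix_sharing_algorithm="none",
)
instance = ServerInstance(config)
instance.port = 0  # get_server_args() expects a port; the conftest fixture sets it to 0

with instance.start_service_only() as service:
    # `service` is the GenerateService; submit InferenceExecRequests to
    # service.batcher from an sf.Process, as the direct-to-batcher test does.
    ...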
17 changes: 17 additions & 0 deletions app_tests/integration_tests/llm/shortfin/conftest.py
@@ -78,6 +78,23 @@ def server(model_artifacts, request):
process.wait()


@pytest.fixture(scope="module")
def generate_service(model_artifacts, request):
"""Starts and manages the test server."""
model_config = model_artifacts.model_config

server_config = ServerConfig(
artifacts=model_artifacts,
device_settings=model_config.device_settings,
prefix_sharing_algorithm=request.param.get("prefix_sharing", "none"),
)

server_instance = ServerInstance(server_config)
server_instance.port = 0
with server_instance.start_service_only() as gs:
yield gs


@pytest.fixture(scope="module")
def encoded_prompt(model_artifacts: ModelArtifacts, request) -> list[int]:
tokenizer = Tokenizer.from_file(str(model_artifacts.tokenizer_path))
101 changes: 101 additions & 0 deletions app_tests/integration_tests/llm/shortfin/direct_to_batcher_test.py
@@ -0,0 +1,101 @@
import pytest
import numpy as np
import asyncio
import shortfin as sf

from app_tests.integration_tests.llm.server_management import (
ServerInstance,
ServerConfig,
)
from app_tests.integration_tests.llm.model_management import TEST_MODELS, ModelProcessor
from app_tests.integration_tests.llm.device_settings import CPU
from shortfin_apps.llm.components.messages import InferencePhase, InferenceExecRequest


pytestmark = pytest.mark.parametrize(
"model_artifacts,generate_service",
[
["tinystories_llama2_25m", {"prefix_sharing": "none"}],
],
indirect=True,
)


class BatchConsistencyTestProcess(sf.Process):
"""Process to test consistency of results across different batch sizes.
This is necessary because InferenceExecRequest uses shortfin.VoidFuture
which can only be created on a process (which belongs to a fiber that a worker works through).
"""

def __init__(self, service, input_tokens, batch_sizes, max_response_length):
super().__init__(fiber=service.main_fiber)
self.service = service
self.input_tokens = input_tokens
self.batch_sizes = batch_sizes
self.max_response_length = max_response_length
self.results = {} # Store results for each batch size
# TODO: modify the batcher to guarantee the batch we send isn't split by strobe messages

async def run(self):
for batch_size in self.batch_sizes:
batch_results = []
for _ in range(batch_size):
prefill_req = InferenceExecRequest(
phase=InferencePhase.PREFILL,
input_token_ids=self.input_tokens,
rid=f"test-{batch_size}",
)
prefill_req.return_host_array = True
self.service.batcher.submit(prefill_req)
await prefill_req.done
first_token = np.argmax(prefill_req.result_logits.items)
result_sequence = [first_token]

decode_req = prefill_req
for _ in range(self.max_response_length - 1):
decode_req.reset(InferencePhase.DECODE)
decode_req.input_token_ids.append(first_token)
decode_req.start_position += 1
self.service.batcher.submit(decode_req)
await decode_req.done
next_token = np.argmax(decode_req.result_logits.items)
result_sequence.append(next_token)
first_token = next_token

batch_results.append(result_sequence)
decode_req.free_cache_pages()

self.results[batch_size] = batch_results

first_result = batch_results[0]
for result in batch_results[1:]:
assert np.array_equal(
first_result, result
), f"Inconsistent results within batch size {batch_size}"

first_batch_result = self.results[self.batch_sizes[0]][0]
for batch_size in self.batch_sizes[1:]:
assert np.array_equal(
first_batch_result, self.results[batch_size][0]
), f"Inconsistent results between batch sizes {self.batch_sizes[0]} and {batch_size}"


def test_batch_and_nobatch_consistency(model_artifacts, generate_service):
"""
Test that requests produce identical results regardless of batch size.
If this test fails, it means that changing the batch size changes the generation results.
Look for kvcache corruption due to
- improper seq_len / current_position handling in service.py
- improper masking in sharktank
"""
# Create and run the test process
test_process = BatchConsistencyTestProcess(
generate_service,
input_tokens=[1, 2, 3, 4],
batch_sizes=[1, 2, 3, 4],
max_response_length=3,
)
test_process.launch()
7 changes: 5 additions & 2 deletions shortfin/python/shortfin_apps/llm/server.py
@@ -49,7 +49,7 @@
}


def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
def parse_args(argv):
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=8000)
@@ -136,8 +136,11 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
choices=["none", "trie"],
help="Algorithm to use for prefix sharing in KV cache",
)
args = parser.parse_args(argv)
return parser.parse_args(argv)


def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
args = parse_args(argv)
if args.tokenizer_config_json is None:
# this is only used for the EOS token
logging.info("Argument `--tokenizer_config_json` is not provided")
