Showing 29 changed files with 2,047 additions and 10 deletions.
@@ -0,0 +1,10 @@
# LLM Server and CLI

This directory contains an LLM inference server, CLI and support components.


## Quick start

```
python -m shortfin_apps.llm.server --help
```
@@ -0,0 +1,7 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from . import _deps
@@ -0,0 +1,17 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from shortfin.support.deps import ShortfinDepNotFoundError

try:
    import tokenizers
except ModuleNotFoundError as e:
    raise ShortfinDepNotFoundError(__name__, "tokenizers") from e

try:
    import dataclasses_json
except ModuleNotFoundError as e:
    raise ShortfinDepNotFoundError(__name__, "dataclasses-json") from e
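
The guarded-import pattern above surfaces missing optional dependencies eagerly, when the package itself is imported (via the `from . import _deps` in `__init__.py`). A minimal sketch of how a caller might observe this; the handling code is illustrative and not part of this commit:

```
from shortfin.support.deps import ShortfinDepNotFoundError

try:
    # Importing the package runs `from . import _deps`, which re-raises
    # ModuleNotFoundError as ShortfinDepNotFoundError for known extras.
    import shortfin_apps.llm
except ShortfinDepNotFoundError as e:
    # Illustrative handling only: point the user at the missing package.
    raise SystemExit(f"Missing optional dependency: {e}") from e
```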
libshortfin/python/shortfin_apps/llm/components/cache.py (111 additions, 0 deletions)
@@ -0,0 +1,111 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Sequence

import logging
import math
import threading

import shortfin as sf

from .config_struct import ModelParams, human_size

logger = logging.getLogger(__name__)


class AttnPageEntry:
    __slots__ = [
        "cache",
        "index",
        "in_use",
    ]

    def __init__(self, cache: "AttnPageCache", index: int):
        self.cache = cache
        self.index = index
        self.in_use = False

    def __repr__(self):
        return f"Block({self.index}, {'FREE' if not self.in_use else 'BUSY'})"


class AttnPageCache:
    """Page table based attention cache.

    While internal to a model the cache is organized with additional structure
    per page, outside of the model it is just a list of pages of a certain
    element type and number of elements (all inner dims are flattened).

    One page table is allocated per device in a fiber. Currently, this is a
    dense allocation with committed memory, but in the future we may just
    allocate the address space and lazily populate it with committed memory.

    The cache is unique because usage of it can span fibers, and concurrency
    is implicitly managed at the block level (i.e. freshly acquired blocks
    are assumed to be uninitialized and available immediately for use).

    It is initialized with a discrete list of fibered devices from a fiber, but
    cache usage can be done from any fiber which includes those devices.
    """

    def __init__(
        self, *, devices: Sequence[sf.ScopedDevice], model_params: ModelParams
    ):
        self._lock = threading.Lock()
        self.devices = list(devices)
        self.model_params = model_params
        self.page_tables: list[sf.array.device_array] = []
        cache_params = model_params.paged_kv_cache
        assert cache_params is not None, "Model does not have a paged kv cache"
        alloc_page_count = cache_params.device_block_count

        # Set up accounting structs.
        self.attn_page_entries = [
            AttnPageEntry(self, i) for i in range(alloc_page_count)
        ]
        self.attn_page_free = list(self.attn_page_entries)

        # Initialize a page table on each device.
        page_table_shape = [
            alloc_page_count,
            model_params.paged_kv_block_size_elements,
        ]
        for device in devices:
            logger.info(
                "Allocating page table (shape=%r, dtype=%r, size=%s) on %r",
                page_table_shape,
                model_params.attn_dtype,
                human_size(
                    math.prod(page_table_shape)
                    * model_params.attn_dtype.dense_byte_count
                ),
                device,
            )
            page_table = sf.array.device_array.for_device(
                device, page_table_shape, model_params.attn_dtype
            )
            self.page_tables.append(page_table)

    def acquire_free_pages(self, count: int) -> list[AttnPageEntry] | None:
        with self._lock:
            available = len(self.attn_page_free)
            if count > available:
                return None
            return [self.attn_page_free.pop() for _ in range(count)]

    def release_pages(self, pages: list[AttnPageEntry]):
        with self._lock:
            self.attn_page_free.extend(pages)

    def __repr__(self):
        # No need to lock for repr (reading the list length is atomic under the GIL).
        free_pages = len(self.attn_page_free)
        total_pages = len(self.attn_page_entries)
        return (
            f"AttnPageCache({total_pages - free_pages}/{total_pages} pages in use: "
            f"{100.0 * free_pages / total_pages:.1f}% free)"
        )
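
For orientation, a minimal usage sketch of `AttnPageCache` follows. It assumes `devices` and `model_params` have been constructed elsewhere; the actual KV write step is elided since it depends on the invocation machinery:

```
cache = AttnPageCache(devices=devices, model_params=model_params)

pages = cache.acquire_free_pages(4)
if pages is None:
    # Fewer than 4 pages free: callers must back off, retry, or reject work.
    ...
else:
    try:
        ...  # write K/V data into the reserved page table slots
    finally:
        # Returning pages makes them immediately reusable; fresh acquirers
        # treat them as uninitialized.
        cache.release_pages(pages)
```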
libshortfin/python/shortfin_apps/llm/components/config_struct.py (185 additions, 0 deletions)
@@ -0,0 +1,185 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Configuration objects.

Parameters that are intrinsic to a specific model.

In a typical transformer model, the KV cache is organized similar to (mapped to
our parameter names below):

    k = tensor.empty(transformer_block_count, batch_size, seq,
                     attn_head_count, attn_head_dim)
    v = ...

For context, a popular model has parameters of:

    attn_dtype_size = 2  # (fp16)
    max_seq_len = 2048
    transformer_block_count = 32
    attn_head_count = 32
    attn_head_dim = 128  # (dim / head_count)

If paging, then we primarily care about the organization of a single block, where
a block represents a single position in the sequence for a single item in the batch.
Therefore, it will be organized like:

    block = torch.empty(transformer_block_count, 2, attn_head_count, attn_head_dim)

In this scenario, we declare that one block holds the KV cache for all transformer
block layers because it reduces the accounting. As such, for the above example,
a single position in the sequence will be 524,288 bytes, assuming a 2-byte element
type. If we choose to block by block_stride=16 positions, each block will be 8MiB.
Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536
blocks for a total number of sequence positions of 24,576.

These are well-known numbers but are derived above to give a sense of scale.

In order to indirect through to the block cache, we have to provide the index map
to specific invocations:

* Prefill: Prefill is only writing to the blocks from [0:prompt_len], so it will
  need write indices of [batch_size, prompt_len // block_stride + 1].
* Decode step: Decode is auto-regressive, and needs to first compute the new kv
  row and then attend over all rows in the cache up to this point in the sequence.

If wanting to avoid dynamic allocation of transients, we can also pool the index
tables based on the maximum batch size and maximum sequence length. Since all
block cache sizes are well within the range of an i16, we will use that for storage.
Therefore, each batch invocation would need a block lookup table of:

    byte_size = max_batch_size * (max_seq_len // block_stride) * sizeof(int16_t)

For a max_batch_size of 16, this is 4KiB of block index table lookups per
invocation. We don't have to statically allocate this, but the system is more
predictable if we just reserve what we need. Again, numbers are given to give a
sense of scale only: real workloads will vary.
"""

from dataclasses import dataclass
from pathlib import Path

import dataclasses_json
from dataclasses_json import dataclass_json, Undefined

import shortfin.array as sfnp


def _decode_dtype(name: str) -> sfnp.DType:
    obj = getattr(sfnp, name, None)
    if not isinstance(obj, sfnp.DType):
        raise ValueError(f"{name} is not a recognized dtype")
    return obj


dataclasses_json.cfg.global_config.encoders[sfnp.DType] = lambda dt: dt.name
dataclasses_json.cfg.global_config.decoders[sfnp.DType] = _decode_dtype


@dataclass_json(undefined=Undefined.RAISE)
@dataclass
class PagedKVCacheParams:
    """Parameters for the paged KV cache."""

    # Position stride per attention block.
    block_seq_stride: int

    # Size of the cache on each device.
    device_block_count: int


@dataclass_json(undefined=Undefined.RAISE)
@dataclass
class ModelParams:
    """Parameters for a specific compiled model, sufficient to do cache planning and
    invocations."""

    # Maximum length of a sequence including prompt and output.
    max_seq_len: int

    # Number of transformer blocks.
    transformer_block_count: int

    # Number of attention heads per block.
    attn_head_count: int

    # Dimensionality of each attention head.
    attn_head_dim: int

    # Batch sizes that the prefill stage is compiled for. These are expected to be
    # functions exported from the model with suffixes of "_bs{batch_size}". Must
    # be in ascending order.
    prefill_batch_sizes: list[int]

    # Similarly, batch sizes that the decode stage is compiled for.
    decode_batch_sizes: list[int]

    # Name of the IREE module implementing the model.
    module_name: str = "module"

    # ABI of the module.
    module_abi_version: int = 1

    # The element type of the attention caches.
    attn_dtype: sfnp.DType = sfnp.float16

    # Cache parameters.
    paged_kv_cache: PagedKVCacheParams | None = None

    # Size in bytes of the KV cache dtype.
    @property
    def attn_dtype_size(self) -> int:
        assert self.attn_dtype.is_byte_aligned()
        return self.attn_dtype.dense_byte_count

    @property
    def max_prefill_batch_size(self) -> int:
        return self.prefill_batch_sizes[-1]

    @property
    def max_decode_batch_size(self) -> int:
        return self.decode_batch_sizes[-1]

    @property
    def max_batch_size(self):
        return max(self.max_prefill_batch_size, self.max_decode_batch_size)

    @property
    def has_paged_kv_cache(self):
        return self.paged_kv_cache is not None

    @property
    def paged_kv_unit_size_elements(self) -> int:
        """Size in elements of each cache line in the attention cache.

        Each cache line can store a unit position stride.
        """
        assert self.has_paged_kv_cache
        size = 1
        size *= self.transformer_block_count
        size *= 2  # K and V cache line
        size *= self.attn_head_count
        size *= self.attn_head_dim
        return size

    @property
    def paged_kv_block_size_elements(self) -> int:
        """Size in elements of each attention block of {block_position_stride}
        positions.
        """
        assert self.paged_kv_cache is not None
        return self.paged_kv_unit_size_elements * self.paged_kv_cache.block_seq_stride

    @staticmethod
    def load_json(path: Path | str):
        with open(path, "rt") as f:
            json_text = f.read()
        return ModelParams.from_json(json_text)


# From: https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size
def human_size(num, suffix="B"):
    for unit in ("", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"):
        if abs(num) < 1024.0:
            return f"{num:3.1f}{unit}{suffix}"
        num /= 1024.0
    return f"{num:.1f}Yi{suffix}"
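
As a cross-check of the sizing math in the config_struct docstring, here is the arithmetic for the quoted example model written out. The numbers come directly from the docstring; this sketch is illustrative and not part of the commit:

```
transformer_block_count = 32
attn_head_count = 32
attn_head_dim = 128
attn_dtype_size = 2  # fp16
block_seq_stride = 16

# One sequence position across all layers: K and V per transformer block.
unit_bytes = (
    transformer_block_count * 2 * attn_head_count * attn_head_dim * attn_dtype_size
)
assert unit_bytes == 524_288  # 512 KiB per position, as stated above

# One paged block holds block_seq_stride positions.
block_bytes = unit_bytes * block_seq_stride
assert block_bytes == 8 * 1024 * 1024  # 8 MiB per block

# Dedicating 12 GiB to the block cache:
device_block_count = (12 * 1024**3) // block_bytes
assert device_block_count == 1536
assert device_block_count * block_seq_stride == 24_576  # total cached positions
```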