Shortfin llm beam search #1011

Draft · wants to merge 13 commits into main
7 changes: 7 additions & 0 deletions shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -141,6 +141,10 @@ class ModelParams:
# Cache parameters.
paged_kv_cache: PagedKVCacheParams | None = None

# Beam size for beam search.
# This is currently just a placeholder so that the dataclass doesn't complain.
n_beams: int | None = None

# Size in bytes of the KV cache dtype.
@property
def attn_dtype_size(self) -> int:
@@ -218,6 +222,9 @@ class ServerParams:
# Program isolation configuration
program_isolation: str = "per_call"

# Decode Strategy configuration
n_beams: int = 1

# Device configuration
device_ids: list[str] = field(default_factory=list)
amdgpu_async_allocations: bool = False
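For context, here is a hedged sketch of how a server might map the new `n_beams` setting onto the decode strategies added in the files below. This wiring is not part of the diff: the module path, the `build_decode_strategy` helper, and the rule "n_beams == 1 selects greedy, anything larger selects beam search" are all assumptions for illustration; the actual integration in the generate pipeline may differ.

# Hypothetical wiring (not in this PR). The package path is assumed from the
# new files below; only the class names come from the diff.
from shortfin_apps.llm.components.decode_strategy import (
    BeamSearchDecodeStrategy,
    BeamSearchDecodeStrategyConfig,
    DecodeStrategyConfig,
    GreedyDecodeStrategy,
)

def build_decode_strategy(server_params, batcher_callback, streaming_callback,
                          eos_token_id, max_completion_tokens):
    # Assumed rule: more than one beam means beam search, otherwise greedy.
    if server_params.n_beams > 1:
        return BeamSearchDecodeStrategy(
            BeamSearchDecodeStrategyConfig(
                batcher_callback=batcher_callback,
                streaming_callback=streaming_callback,
                eos_token_id=eos_token_id,
                max_completion_tokens=max_completion_tokens,
                n_beams=server_params.n_beams,
                temperature=1.0,
            )
        )
    return GreedyDecodeStrategy(
        DecodeStrategyConfig(
            batcher_callback=batcher_callback,
            streaming_callback=streaming_callback,
            eos_token_id=eos_token_id,
            max_completion_tokens=max_completion_tokens,
        )
    )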
@@ -0,0 +1,15 @@
from .base_decode_strategy import DecodeStrategy, DecodeStrategyConfig
from .beam_search_decode_strategy import (
BeamSearchDecodeStrategy,
BeamSearchDecodeStrategyConfig,
)
from .greedy_decode_strategy import GreedyDecodeStrategy


__all__ = [
"DecodeStrategy",
"DecodeStrategyConfig",
"BeamSearchDecodeStrategy",
"BeamSearchDecodeStrategyConfig",
"GreedyDecodeStrategy",
]
@@ -0,0 +1,26 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Callable, Union

from ..messages import InferenceExecRequest


@dataclass
class DecodeStrategyConfig:
batcher_callback: Callable[[InferenceExecRequest], None]
streaming_callback: Callable[[Union[int, List[int]]], None]
eos_token_id: int
max_completion_tokens: int


class DecodeStrategy(ABC):
"""Abstract class for implementing decode strategies."""

@property
@abstractmethod
def decode_strategy_config(self) -> DecodeStrategyConfig:
pass

@abstractmethod
async def decode(self) -> List[int]:
pass
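To illustrate the contract this abstract base defines, here is a minimal, hypothetical strategy that simply emits the EOS token. It is not part of the PR; it only shows that a concrete subclass has to expose `decode_strategy_config` and implement an async `decode` (the real strategies follow in the next files), and it assumes it would live alongside `base_decode_strategy.py` in the same package.

# Minimal illustrative subclass (not in this PR). It bypasses the batcher
# entirely and just streams the EOS token, to show the required surface area.
from typing import List

from .base_decode_strategy import DecodeStrategy, DecodeStrategyConfig


class EchoEosDecodeStrategy(DecodeStrategy):
    def __init__(self, decode_strategy_config: DecodeStrategyConfig):
        self._decode_strategy_config = decode_strategy_config

    @property
    def decode_strategy_config(self) -> DecodeStrategyConfig:
        return self._decode_strategy_config

    async def decode(self) -> List[int]:
        # A real strategy would call config.batcher_callback(exec_req),
        # await exec_req.done, and read exec_req.result_logits in a loop.
        eos = self.decode_strategy_config.eos_token_id
        self.decode_strategy_config.streaming_callback(eos)
        return [eos]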
@@ -0,0 +1,258 @@
from .base_decode_strategy import DecodeStrategy, DecodeStrategyConfig


from asyncio import gather
from dataclasses import dataclass
from typing import Dict, List, Tuple
from uuid import uuid4

import numpy as np

from ..messages import InferenceExecRequest, InferencePhase


@dataclass
class ExecRequestSelection:
"""Helper class top make `BeamGroup.evaluate_top_k` return cleaner."""

log_prob: float
exec_req: InferenceExecRequest
token: int
min_log_prob: float


class BeamGroup:
def __init__(
self,
beam_group_id: str,
n_beams: int,
temperature: float,
exec_reqs: list[InferenceExecRequest],
):
self.beam_group_id = beam_group_id
self.n_beams = n_beams
self.temperature = temperature
self.exec_reqs = exec_reqs
self.completed_reqs: set[InferenceExecRequest] = set()

async def wait(self):
done_signals = [
req.done for req in self.exec_reqs if req not in self.completed_reqs
]
return await gather(*done_signals)

def _apply_temperature(self, logits: np.ndarray) -> np.ndarray:
if self.temperature != 1.0:
return logits / self.temperature

return logits

def log_softmax(self, logits: np.ndarray) -> np.ndarray:
# TODO: Move this to sfnp.array
c = logits.max()
logsumexp = np.log(np.exp(logits - c).sum())
return logits - c - logsumexp

def topk(
self, logits: np.ndarray, k: int, axis: int
) -> Tuple[List[float], List[int]]:
# TODO: Move this to sfnp.array
indices = np.argpartition(logits, -k, axis=axis)
topk_indices = indices[axis][-k:]
topk_values = logits[axis][topk_indices]

return topk_values, topk_indices

def _get_exec_req_selections(
self,
log_prob_map: Dict[float, tuple[InferenceExecRequest, int]],
min_log_prob: float,
):
# Find the topk tokens across all exec_reqs
sorted_keys = sorted(log_prob_map.keys(), reverse=True)
exec_req_selections: List[ExecRequestSelection] = []
for key in sorted_keys[: self.n_beams - len(self.completed_reqs)]:
exec_req, token = log_prob_map[key]
exec_req_selections.append(
ExecRequestSelection(
# Shift log_probs up by the most negative candidate from
# this step to avoid accumulating large negative numbers
log_prob=key - min_log_prob,
exec_req=exec_req,
token=token,
min_log_prob=min_log_prob,
)
)

return exec_req_selections

def evaluate_topk(self) -> List[ExecRequestSelection]:
# TODO: Use temperature when processing logits for better diversity of
# outputs.
exec_reqs = self.exec_reqs

log_prob_map: Dict[float, tuple[InferenceExecRequest, int]] = {}
global_min_log_prob = 0.0
# Find the topk tokens for each req in our beam group
for exec_req in exec_reqs:
if exec_req in self.completed_reqs:
continue
# NOTE: This copy is slow, and part of why this needs to be moved to
# `shortfin.array`
logits = np.array(exec_req.result_logits)
logits = self._apply_temperature(logits)
# Take log_softmax. This is to avoid a req's cumulative probability
# becoming too small, which can lead to precision issues.
# This allows us to obtain cumulative probability by summing
# the log_probabilities, instead of multiplying the probabilities.
log_logits = self.log_softmax(logits)
log_logits = np.squeeze(log_logits, 1)
values, tokens = self.topk(log_logits, self.n_beams, -1)
min_log_prob = 0.0
for value, token in zip(values, tokens):
if value < min_log_prob:
min_log_prob = value
cumulative_log_prob = exec_req.cumulative_log_prob + value
log_prob_map[cumulative_log_prob] = (exec_req, token)

if min_log_prob < global_min_log_prob:
global_min_log_prob = min_log_prob

return self._get_exec_req_selections(log_prob_map, global_min_log_prob)

def process_beams(self, eos_token_id):
exec_reqs_selections = self.evaluate_topk()
visited_reqs: Dict[str, InferenceExecRequest] = {}
new_reqs = set()

for selection in exec_reqs_selections:
new_req = selection.exec_req
token = selection.token
if new_req.instance_id not in visited_reqs:
new_req.input_token_ids.append(token)
new_req.output_token_ids.append(token)
new_req.start_position += 1
new_req.accumulated_normalization += abs(selection.min_log_prob)

else:
visited_req = visited_reqs[new_req.instance_id]
new_req = visited_req.replicate_self()
new_req.input_token_ids.append(token)
new_req.output_token_ids.append(token)

new_req.cumulative_log_prob = selection.log_prob
visited_reqs[new_req.instance_id] = new_req
new_reqs.add(new_req)
if token == eos_token_id:
self.completed_reqs.add(new_req)

for req in self.exec_reqs:
if req not in new_reqs:
req.free_cache_pages()

for req in self.completed_reqs:
req.free_cache_pages()

self.exec_reqs = list(new_reqs)

def _final_score(self, exec_req: InferenceExecRequest):
return (
exec_req.cumulative_log_prob - exec_req.accumulated_normalization
) / len(exec_req.output_token_ids)

def find_top_beam(self) -> InferenceExecRequest:
completed_reqs = list(self.completed_reqs)
if not completed_reqs:
completed_reqs = self.exec_reqs
max_score = self._final_score(completed_reqs[0])
selected_req = completed_reqs[0]
for req in completed_reqs[1:]:
score = self._final_score(req)
if score > max_score:
selected_req = req
max_score = score

return selected_req

def __del__(self):
for req in self.exec_reqs:
req.free_cache_pages()

for req in self.completed_reqs:
req.free_cache_pages()


@dataclass
class BeamSearchDecodeStrategyConfig(DecodeStrategyConfig):
n_beams: int
temperature: float
return_top_k: bool = False


class BeamSearchDecodeStrategy(DecodeStrategy):
beam_map: dict[str, BeamGroup] = {}

def __init__(
self,
decode_strategy_config: BeamSearchDecodeStrategyConfig,
):
self._decode_strategy_config = decode_strategy_config

@property
def decode_strategy_config(self):
return self._decode_strategy_config

def create_beam(self, requests: list[InferenceExecRequest]) -> BeamGroup:
beam_group_id = str(uuid4())
for req in requests:
req.beam_group_id = beam_group_id

beam_group = BeamGroup(
beam_group_id,
self.decode_strategy_config.n_beams,
self.decode_strategy_config.temperature,
requests,
)
BeamSearchDecodeStrategy.beam_map[beam_group_id] = beam_group
return beam_group

def delete_beam(self, beam_group_id: str):
# `del` on a local name would not remove the entry from `beam_map`;
# pop the group out of the class-level map so it can be garbage collected.
BeamSearchDecodeStrategy.beam_map.pop(beam_group_id)

async def decode(
self,
exec_req: InferenceExecRequest,
) -> List[int] | List[List[int]]:
config = self.decode_strategy_config
decode_reqs = [exec_req]
for _ in range(config.n_beams - 1):
decode_req = exec_req.replicate_self()
decode_reqs.append(decode_req)

beam_group = self.create_beam(decode_reqs)
for _ in range(config.max_completion_tokens):
if len(beam_group.completed_reqs) == config.n_beams:
break

for req in beam_group.exec_reqs:
if req in beam_group.completed_reqs:
continue
req.reset(InferencePhase.DECODE)
config.batcher_callback(req)

await beam_group.wait()
beam_group.process_beams(config.eos_token_id)

result = None
if config.return_top_k:
reqs = beam_group.completed_reqs
for req in beam_group.exec_reqs:
reqs.add(req)
result = [req.output_token_ids for req in reqs]

else:
result = beam_group.find_top_beam().output_token_ids

self.delete_beam(beam_group.beam_group_id)
config.streaming_callback(result)
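As a standalone illustration of the scoring math above (log-softmax so cumulative probabilities can be kept as running sums of log-probabilities, an unordered top-k per step, a shift by the most negative candidate recorded in `accumulated_normalization`, and a length-normalized final score), the following self-contained numpy sketch mirrors `log_softmax`, `topk`, and `_final_score` outside of shortfin. The toy logits and two-step scenario are made up for illustration.

# Standalone numpy sketch of the beam-scoring math (toy values, no shortfin).
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Same numerically stable form as BeamGroup.log_softmax.
    c = logits.max()
    return logits - c - np.log(np.exp(logits - c).sum())

def topk(log_probs: np.ndarray, k: int):
    # Unordered top-k, as in BeamGroup.topk; ordering happens later via sort.
    idx = np.argpartition(log_probs, -k)[-k:]
    return log_probs[idx], idx

vocab_logits = np.array([2.0, 1.0, 0.5, -1.0])   # fake 4-token vocabulary
step_log_probs = log_softmax(vocab_logits)       # every entry is <= 0
values, tokens = topk(step_log_probs, k=2)

# Summing per-step log-probabilities equals the log of the product of the
# per-step probabilities, which is why cumulative_log_prob can be kept as a
# running sum without underflow. Pretending the same distribution appears at
# two consecutive steps, the best beam's cumulative score would be:
two_step_cumulative = float(values.max()) * 2

# evaluate_topk additionally shifts each step's candidates by that step's
# most negative candidate (min_log_prob) and records the shift in
# accumulated_normalization; _final_score undoes the shift and normalizes
# by length:
#     score = (cumulative_log_prob - accumulated_normalization) / num_tokens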
@@ -0,0 +1,34 @@
import shortfin.array as sfnp

from .base_decode_strategy import DecodeStrategy, DecodeStrategyConfig
from ..messages import InferenceExecRequest, InferencePhase


class GreedyDecodeStrategy(DecodeStrategy):
def __init__(
self,
decode_strategy_config: DecodeStrategyConfig,
):
self._decode_strategy_config = decode_strategy_config

@property
def decode_strategy_config(self):
return self._decode_strategy_config

async def decode(
self,
exec_req: InferenceExecRequest,
):
config = self.decode_strategy_config
for _ in range(config.max_completion_tokens):
exec_req.reset(InferencePhase.DECODE)
config.batcher_callback(exec_req)
await exec_req.done
token = sfnp.argmax(exec_req.result_logits)
token_int = token.items[0]
config.streaming_callback(token_int)
if token_int == config.eos_token_id:
break
exec_req.input_token_ids.append(token_int)
exec_req.output_token_ids.append(token_int)
exec_req.start_position += 1
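For completeness, a hedged sketch of the caller side of `DecodeStrategyConfig.streaming_callback`: greedy decoding streams one token id per step, while beam search delivers the winning beam's token list (or a list of per-beam lists when `return_top_k` is set), so a callback has to accept both shapes. The accumulation below is illustrative and not part of the PR; keeping only the first beam in the top-k case is an arbitrary choice for the sketch.

# Illustrative streaming_callback (not in this PR).
from typing import List

collected: List[int] = []

def streaming_callback(tokens) -> None:
    if isinstance(tokens, list):
        if tokens and isinstance(tokens[0], list):
            collected.extend(tokens[0])   # beam search with return_top_k: keep the first beam
        else:
            collected.extend(tokens)      # beam search: winning beam's output_token_ids
    else:
        collected.append(int(tokens))     # greedy: one token id per decode step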