Shortfin llm beam search #1011
base: main
Conversation
However, this achieves a stable implementation of unsharded beam search for varying `n_beams` and varying batch sizes. It serves as both an initial implementation and a checkpoint.
Cleanup `BeamGroup.process_beams`
…in-llm-beam-search
Ensure beam_group gets deleted
src_view = page_table.view(src_page.index)
dst_view = page_table.view(dst_page.index)
# Copy the data
dst_view.copy_from(src_view)
We shouldn't be copying every page but only incomplete pages. I.e. the final page if it is incomplete. The rest should be tracked with the retain count system.
@renxida should be able to point to it
Messaged @renxida and asked if he can provide context for the page copying. I think I see what you're saying though. If a page is full, we just need to read from it, not write to it. So, we should be able to share the full pages, which is all except the last one. That last non-full one is where we may actually see differences, so needs to be copied?
we should do last-page copying for now, but the eventual goal in my mind was to have a kv cache so good that we can just request the cache for pages corresponding to the tokens and the cache would automatically figure out we only need to copy the last page.
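To make the "share full pages, copy only the last incomplete one" idea concrete, here is a minimal, self-contained sketch. The page table here is just a NumPy array and `allocate_page` / `fork_pages` are made-up helpers for illustration, not the actual shortfin cache API:

```python
import numpy as np

PAGE_SIZE = 16                                           # tokens per page (illustrative)
page_table = np.zeros((64, PAGE_SIZE), dtype=np.int64)   # toy page table
free_pages = list(range(64))                             # indices of unallocated pages

def allocate_page() -> int:
    """Hypothetical allocator: hand out the next free page index."""
    return free_pages.pop()

def fork_pages(src_pages: list[int], tokens_in_last_page: int) -> list[int]:
    """Fork a request's pages for a new beam: full pages are shared
    (the retain-count system keeps them alive), and only a partially
    filled last page is physically copied."""
    shared = src_pages[:-1]                      # full pages: read-only, safe to share
    last = src_pages[-1]
    if tokens_in_last_page == PAGE_SIZE:
        return shared + [last]                   # last page is also full, share it too
    new_page = allocate_page()
    page_table[new_page] = page_table[last]      # copy only the incomplete page
    return shared + [new_page]
```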
self.rid,
)
new_exec_req.start_position = self.start_position
result_logits: sfnp.device_array = self.result_logits.for_transfer()
It concerns me that we see `result_logits` being retained / surviving. If we do want it to survive, this buffer should be held via retainment and not copied.
Could you elaborate on the issue and the difference between retainment and copying? I think there's some background/context that I'm missing.
Two queries may have some overlapping sections. E.g. when we initially branch there may be 4 blocks of contents but the first 3 are identical. The idea with the page table is that we do not duplicate the first blocks that match, we just use the same pages. The retainment idea is we track how many queries use each page and only release it once no active queries use it.
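As a toy illustration of that retainment idea (names invented here; the real bookkeeping in shortfin's cache may differ), the tracking could look like this:

```python
class PageRefTracker:
    """Toy retain/release bookkeeping: a page is only freed once no
    active query references it."""

    def __init__(self) -> None:
        self._refcounts: dict[int, int] = {}   # page index -> number of active users

    def retain(self, page_index: int) -> None:
        self._refcounts[page_index] = self._refcounts.get(page_index, 0) + 1

    def release(self, page_index: int) -> int:
        """Drop one reference; return the remaining count (0 means reusable)."""
        remaining = self._refcounts[page_index] - 1
        if remaining == 0:
            del self._refcounts[page_index]
        else:
            self._refcounts[page_index] = remaining
        return remaining
```

When a beam branches, each shared page would get one additional retain per new beam; when a beam finishes, its pages are released, and only pages whose count hits zero go back to the free list.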
…arge negative numbers, Track accumulated normalization through decode process, Apply length correction when selecting top beam
using `DecodeStrategy` base class
Make sorting `ExecRequestSelection` a private function for easier testing
Allow `n_beams` as an argument in server_config, server CLI, and `sampling_params`, Use `lock` when copying cache pages to prevent race condition, Still trying to figure out only copying the last page
Initial implementation of `beam_search` for the LLM server. Putting it up as a draft for now to get some feedback on the code, and because it needs unit/integration tests. Too big/important of a change to merge without it.

The beam_search-specific logic is contained within `beam_manager.py`, while `generate.py` just contains logic for managing the `InferenceExecRequest`s and orchestrating the overall flow. This also keeps all of the logic above the Batcher level, which minimized changes a lot from what I previously tried.
At a high level, the idea is that we:

1. Create `n_beams` `InferenceExecRequest`s. Here we `replicate` the initial req that we used for prefill, including replicating the KVCache pages.
2. Track the reqs with a `BeamGroup`. This is a helper class that handles actually performing beam_search token selection, and tracking the reqs.
3. For each req, select the `top_k` tokens, based on cumulative log probability.
4. Select the overall `top_k` tokens from all reqs.
5. If a req generates an `eos_token`, add it to the set of `completed_reqs`.
6. Repeat until `max_completion_tokens` is reached, or all beams generated an `eos_token` (see the sketch below).
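Putting those steps together, the decode loop could look roughly like this. The class and method names are illustrative stand-ins, not necessarily the interfaces in `beam_manager.py`/`generate.py`:

```python
def run_beam_search(prefill_req, n_beams, max_completion_tokens, eos_token):
    # NOTE: names below are illustrative stand-ins for the PR's classes.
    # Step 1: replicate the prefill request (and its KVCache pages) per beam.
    reqs = [prefill_req.replicate() for _ in range(n_beams)]

    # Step 2: hand the reqs to a BeamGroup, which owns token selection.
    beam_group = BeamGroup(reqs, eos_token=eos_token)

    for _ in range(max_completion_tokens):
        for req in beam_group.active_reqs:
            req.decode_step()            # produce logits for this beam

        # Steps 3-5: per-beam top_k, global top_k across beams, and moving
        # any beam that emitted eos_token into completed_reqs.
        beam_group.process_beams()

        if not beam_group.active_reqs:   # every beam finished early
            break

    # Step 6 follow-up: pick the highest-probability completed beam.
    return beam_group.select_top_beam()
```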
Selecting top-k
Extra attention in this section would be appreciated.
For each beam:
<begin_loop>

1. Apply `log_softmax` to the logits. By taking the log of softmax, we can use addition to track the cumulative probabilities, instead of multiplication with the raw probabilities. If we did multiplication with raw probabilities, our cumulative probabilities would become smaller and smaller, until we lose precision. source (search for 'From products to addition')
2. Take the `top_k` values and tokens from the `-1` axis. Track the cumulative log probs for each possible token.

<end_loop>
We then return the top possible selections, in sorted order, based on which potential tokens would yield the beams with the highest probability.
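As a self-contained NumPy illustration of this selection (the real implementation operates on shortfin device arrays, so treat this as a sketch of the math rather than the actual code):

```python
import numpy as np

def select_top_k(logits: np.ndarray, cumulative_log_prob: float, k: int):
    """Per-beam top-k selection using log-softmax so cumulative
    probabilities are tracked by addition rather than multiplication."""
    # Numerically stable log-softmax over the vocab (-1) axis.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

    # Top-k token ids along the -1 axis, plus their cumulative scores.
    top_tokens = np.argsort(log_probs, axis=-1)[..., ::-1][..., :k]
    top_scores = cumulative_log_prob + np.take_along_axis(log_probs, top_tokens, axis=-1)
    return top_tokens, top_scores

# Example: pick the 2 best continuations for one beam.
tokens, scores = select_top_k(np.array([2.0, 0.5, 1.0, -1.0]), cumulative_log_prob=-0.7, k=2)
```

Sorting the pooled `(token, score)` candidates from all beams by `top_scores` then gives the sorted selections described above.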