
Shortfin llm beam search #1011

Draft
wants to merge 13 commits into
base: main

Conversation

stbaione
Contributor

Initial implementation of beam_search for the LLM server. Putting it up as a draft for now to get some feedback on the code, and because it needs unit/integration tests. Too big/important of a change to merge without it.

The beam_search-specific logic is contained within `beam_manager.py`, while `generate.py` just contains the logic for managing the `InferenceExecRequest`s and orchestrating the overall flow.

This also keeps all of the logic above the Batcher level, which greatly reduced the number of changes compared to my previous attempt.

At a high level, the idea is that we:

  1. Run prefill normally. Initialize the KV cache and obtain the first token.
  2. Create `n_beams` `InferenceExecRequest`s. Here we replicate the initial req that we used for prefill, including replicating the KV cache pages.
  3. Group each of these reqs under a `BeamGroup`. This is a helper class that actually performs beam_search token selection and tracks the reqs.
  4. Submit all reqs to the batcher and wait for them all to finish.
  5. For each req, select the top_k tokens based on cumulative log probability.
  6. Select the overall top_k tokens from all reqs.
  7. Update our beams, doing any replication/beam collapses if needed.
  8. If a req generates an eos_token, add it to the set of completed_reqs.
  9. Repeat until either max_completion_tokens is reached or all beams have generated an eos_token.
  10. When we return, we either select the beam with the highest cumulative probability or return all beams, depending on the request params.
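The selection loop in steps 5-8 can be sketched as follows. This is a minimal illustration only; `Beam`, `topk_fn`, and `beam_search_step` are hypothetical names, not the actual classes in this PR, and the real implementation additionally tracks `InferenceExecRequest`s and KV-cache pages per beam:

```python
from dataclasses import dataclass


@dataclass
class Beam:
    tokens: list
    cum_log_prob: float = 0.0


def beam_search_step(beams, topk_fn, n_beams, eos_token):
    """One iteration of steps 5-8: expand each beam, keep the overall best."""
    candidates = []
    for beam in beams:
        # topk_fn returns [(token, log_prob), ...] for this beam's next step
        for token, log_prob in topk_fn(beam):
            candidates.append(
                Beam(beam.tokens + [token], beam.cum_log_prob + log_prob)
            )
    # Keep the overall top n_beams candidates across all beams (step 6)
    candidates.sort(key=lambda b: b.cum_log_prob, reverse=True)
    active, completed = [], []
    for cand in candidates[:n_beams]:
        # A beam that just emitted eos_token moves to the completed set (step 8)
        (completed if cand.tokens[-1] == eos_token else active).append(cand)
    return active, completed
```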

Selecting top-k

Extra attention in this section would be appreciated.

For each beam:

<begin_loop>

  1. Obtain logits from the decode invocation.
  2. Apply log_softmax to the logits. By taking the log of the softmax, we can use addition to track the cumulative probabilities instead of multiplying the raw probabilities. If we multiplied raw probabilities, our cumulative probabilities would become smaller and smaller until we lose precision. source (search for "From products to addition")
  3. Select the top_k values and tokens from the -1 axis. Track the cumulative log probs for each possible token.

<end_loop>

We then return the top possible selections in sorted order, ranked by which potential tokens would yield the beams with the highest cumulative probability.
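The per-beam selection above can be sketched in plain Python. This is a hedged illustration using only the stdlib; the function name and signature are mine, not the PR's API:

```python
import math


def select_top_k(logits, cum_log_prob, k):
    """Rank candidate next tokens by cumulative log probability.

    Summing log-probs stays numerically stable where multiplying raw
    probabilities would underflow after many decode steps.
    """
    # log_softmax with the usual max-shift for numerical stability
    m = max(logits)
    log_sum = math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - m - log_sum for x in logits]
    # Top-k (token, cumulative log-prob) pairs, best first
    ranked = sorted(enumerate(log_probs), key=lambda t: t[1], reverse=True)
    return [(tok, cum_log_prob + lp) for tok, lp in ranked[:k]]
```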

That said, this achieves a stable implementation of unsharded beam search for varying `n_beams` and varying batch sizes.

Both an initial implementation, and a checkpoint
Cleanup `BeamGroup.process_beams`
Ensure beam_group gets deleted
@stbaione stbaione requested a review from rsuderman February 27, 2025 21:09
src_view = page_table.view(src_page.index)
dst_view = page_table.view(dst_page.index)
# Copy the data
dst_view.copy_from(src_view)
Contributor

We shouldn't be copying every page but only incomplete pages. I.e. the final page if it is incomplete. The rest should be tracked with the retain count system.

@renxida should be able to point to it

Contributor Author

Messaged @renxida and asked if he can provide context for the page copying. I think I see what you're saying though. If a page is full, we just need to read from it, not write to it. So, we should be able to share the full pages, which is all except the last one. That last non-full page is where we may actually see differences, so it needs to be copied?

Contributor

We should do last-page copying for now, but the eventual goal in my mind was to have a KV cache good enough that we can just request the cache pages corresponding to the tokens, and the cache would automatically figure out that we only need to copy the last page.
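The last-page-only approach described in this thread could look something like the sketch below. The `Page`/`PageTable` stand-ins are hypothetical simplifications for illustration, not shortfin's actual cache API:

```python
from dataclasses import dataclass


@dataclass
class Page:
    index: int
    retain_count: int = 1


class PageTable:
    """Minimal stand-in: one flat buffer per page."""

    def __init__(self, n_pages, page_size):
        self.buf = [[0] * page_size for _ in range(n_pages)]

    def view(self, index):
        return self.buf[index]


def fork_pages(pages, table, free_pages):
    """Share every full page by bumping its retain count; copy only the
    final (possibly incomplete) page into a freshly acquired one."""
    for page in pages[:-1]:
        page.retain_count += 1  # full pages are read-only, so they can be shared
    new_last = free_pages.pop()
    # Only the last page may diverge between beams, so only it is copied
    table.view(new_last.index)[:] = table.view(pages[-1].index)
    return pages[:-1] + [new_last]
```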

self.rid,
)
new_exec_req.start_position = self.start_position
result_logits: sfnp.device_array = self.result_logits.for_transfer()
Contributor

It concerns me that we see result_logits being retained / surviving. If we do want it to survive, this buffer should be held via retainment and not copied.

Contributor Author

Could you elaborate on the issue and the difference between retainment and copying? I think there's some background/context that I'm missing.

Contributor

Two queries may have some overlapping sections. E.g., when we initially branch, there may be 4 blocks of content but the first 3 are identical. The idea with the page table is that we do not duplicate the first blocks that match; we just use the same pages. The retainment idea is that we track how many queries use each page and only release it once no active queries use it.
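The retain-count scheme described here can be sketched as a tiny reference-counted pool. Again, `Page` and `PagePool` are hypothetical names for illustration, not the actual shortfin implementation:

```python
from dataclasses import dataclass


@dataclass
class Page:
    index: int
    retain_count: int = 0


class PagePool:
    """Reference-counted page pool: a page is freed only when no active
    query still references it."""

    def __init__(self, n_pages):
        self.free = [Page(i) for i in range(n_pages)]

    def acquire(self):
        page = self.free.pop()
        page.retain_count = 1
        return page

    def retain(self, page):
        page.retain_count += 1  # another query now shares this page

    def release(self, page):
        page.retain_count -= 1
        if page.retain_count == 0:
            self.free.append(page)  # no active query uses it anymore
```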


stbaione added 6 commits March 3, 2025 15:42
…arge negative numbers,

Track accumulated normalization through decode process,
Apply length correction when selecting top beam
using `DecodeStrategy` base class
Make sorting `ExecRequestSelection` a private function for easier testing
Allow `n_beams` as an argument in server_config, server CLI, and `sampling_params`,
Use `lock` when copying cache pages to prevent race condition,
Still trying to figure out only copying the last page