LLM Batch inference (#1202)
* add batch mode and openai batch implementation

Signed-off-by: Henry Lindeman <[email protected]>

* add anthropic batch mode

Signed-off-by: Henry Lindeman <[email protected]>

* mypy

Signed-off-by: Henry Lindeman <[email protected]>

* review comments. drop tqdms

Signed-off-by: Henry Lindeman <[email protected]>

---------

Signed-off-by: Henry Lindeman <[email protected]>
HenryL27 authored Feb 28, 2025
1 parent f4f3032 commit 889844a
Showing 4 changed files with 95 additions and 2 deletions.
39 changes: 39 additions & 0 deletions lib/sycamore/sycamore/llms/anthropic.py
@@ -4,6 +4,7 @@
from typing import Any, Optional, Union
import asyncio
import random
import time

from PIL import Image

@@ -15,6 +16,7 @@

DEFAULT_MAX_TOKENS = 1000
INITIAL_BACKOFF = 1
BATCH_POLL_INTERVAL = 10


class AnthropicModels(Enum):
@@ -203,6 +205,7 @@ async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[d
start = datetime.now()
done = False
retries = 0
response = None
while not done:
try:
response = await self._async_client.messages.create(model=self.model.value, **kwargs)
@@ -218,3 +221,39 @@ async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[d

self._llm_cache_set(prompt, llm_kwargs, ret)
return ret["output"]

def generate_batch(self, *, prompts: list[RenderedPrompt], llm_kwargs: Optional[dict] = None) -> list[str]:
from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
from anthropic.types.messages.batch_create_params import Request

cache_hits = [self._llm_cache_get(p, llm_kwargs) for p in prompts]

calls = []
for p, ch, i in zip(prompts, cache_hits, range(len(prompts))):
if ch is not None:
continue
kwargs = get_generate_kwargs(p, llm_kwargs)
kwargs["model"] = self.model.value
kwargs["max_tokens"] = kwargs.get("max_tokens", 1024)
mparams = MessageCreateParamsNonStreaming(**kwargs) # type: ignore
rq = Request(custom_id=str(i), params=mparams)
calls.append(rq)

starttime = datetime.now()
batch = self._client.messages.batches.create(requests=calls)

while batch.processing_status == "in_progress":
time.sleep(BATCH_POLL_INTERVAL)
batch = self._client.messages.batches.retrieve(batch.id)

results = self._client.messages.batches.results(batch.id)
for rs, call in zip(results, calls):
if rs.result.type != "succeeded":
raise ValueError(f"Call failed: {rs}")
id = int(rs.custom_id)
in_kwargs = get_generate_kwargs(prompts[id], llm_kwargs)
ret = self._metadata_from_response(in_kwargs, rs.result.message, starttime)
cache_hits[id] = ret
self._llm_cache_set(prompts[id], llm_kwargs, ret)

return [ch["output"] for ch in cache_hits]
4 changes: 4 additions & 0 deletions lib/sycamore/sycamore/llms/llms.py
@@ -88,6 +88,10 @@ async def generate_async_old(self, *, prompt_kwargs: dict[str, Any], llm_kwargs:
raise ValueError("Either 'prompt' or 'messages' must be specified in prompt_kwargs")
return await self.generate_async(prompt=rendered, llm_kwargs=llm_kwargs)

def generate_batch(self, *, prompts: list[RenderedPrompt], llm_kwargs: Optional[dict] = None) -> list[str]:
"""Generates a series of responses from the LLM for the given series of prompts. Order is preserved."""
raise NotImplementedError("This LLM does not support batched generation")

def __str__(self):
return f"{self.__class__.__name__}({self._model_name})"

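For context (not part of this commit): because the base class raises NotImplementedError, callers that must work with any LLM subclass can fall back to the existing single-prompt generate() method. An illustrative sketch:

# Illustrative fallback pattern, not from this diff: prefer batched generation,
# fall back to one-at-a-time calls for LLMs that do not override generate_batch.
def generate_all(llm, prompts, llm_kwargs=None):
    try:
        return llm.generate_batch(prompts=prompts, llm_kwargs=llm_kwargs)
    except NotImplementedError:
        return [llm.generate(prompt=p, llm_kwargs=llm_kwargs) for p in prompts]
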
52 changes: 51 additions & 1 deletion lib/sycamore/sycamore/llms/openai.py
@@ -9,6 +9,9 @@
from datetime import datetime
import asyncio
import random
import json
import io
import time

from openai import AzureOpenAI as AzureOpenAIClient
from openai import AsyncAzureOpenAI as AsyncAzureOpenAIClient
@@ -18,7 +21,7 @@
from openai.lib.azure import AzureADTokenProvider
from openai.lib._parsing import type_to_response_format_param
from openai import APIConnectionError

from openai.types.chat.chat_completion import ChatCompletion

import pydantic

@@ -33,6 +36,7 @@
# Base URL for Helicone API, if configured using the SYCAMORE_HELICONE_API_KEY environment variable.
HELICONE_BASE_URL = "https://oai.helicone.ai/v1"
INITIAL_BACKOFF = 0.2
BATCH_POLL_INTERVAL = 10


class OpenAIClientType(Enum):
@@ -488,3 +492,49 @@ async def _generate_awaitable_using_openai_structured(
# 1.) The LLM ran out of output context length (usually due to hallucination of repeating the same phrase)
# 2.) The LLM refused to respond to the request because it did not meet guidelines
raise e

def generate_batch(self, *, prompts: list[RenderedPrompt], llm_kwargs: Optional[dict] = None) -> list[str]:
cache_hits = [self._llm_cache_get(p, llm_kwargs) for p in prompts]

calls = []
for p, ch, i in zip(prompts, cache_hits, range(len(prompts))):
if ch is not None:
continue
kwargs = self._get_generate_kwargs(p, llm_kwargs)
kwargs["model"] = self.model.name
call = {"custom_id": str(i), "method": "POST", "url": "/v1/chat/completions", "body": kwargs}
calls.append(call)
f = io.BytesIO()
for i, c in enumerate(calls):
f.write(json.dumps(c).encode("utf-8"))
if i != len(calls) - 1:
f.write(b"\n")
client = self.client_wrapper.get_client()
starttime = datetime.now()
batch_in_file = client.files.create(file=f, purpose="batch")
batch = client.batches.create(
input_file_id=batch_in_file.id, endpoint="/v1/chat/completions", completion_window="24h"
)
while batch.status in ("validating", "in_progress", "finalizing"):
time.sleep(BATCH_POLL_INTERVAL)
batch = client.batches.retrieve(batch.id)

wall_latency = datetime.now() - starttime
if batch.error_file_id:
errors = client.files.content(batch.error_file_id)
logging.error(errors.text)
raise ValueError(f"LLM batch call failed: {batch}")
if batch.output_file_id:
responses = client.files.content(batch.output_file_id)
for rs, call in zip(responses.iter_lines(), calls):
rdata = json.loads(rs)
id = int(rdata["custom_id"])
cc = ChatCompletion.model_construct(**rdata["response"]["body"])
response_text = cc.choices[0].message.content
ct, pt = self.validate_tokens(cc)
kws = call["body"]
self.add_llm_metadata(kws, response_text, wall_latency, ct, pt)
cache_hits[id] = response_text
self._llm_cache_set(prompts[id], llm_kwargs, response_text)
return cache_hits
raise ValueError(f"LLM batch call terminated with no output file or error file: {batch}")
2 changes: 1 addition & 1 deletion lib/sycamore/sycamore/transforms/base_llm.py
@@ -37,7 +37,7 @@ def _infer_prompts(
res[i] = rs
return res
elif llm_mode == LLMMode.BATCH:
raise NotImplementedError("Haven't done batch yet")
return llm.generate_batch(prompts=prompts)
else:
raise NotImplementedError("Unknown LLM Mode")

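For context (not part of this commit): in batch mode the transform now forwards all rendered prompts in a single provider-side job rather than looping over synchronous calls, and output order is preserved. Roughly:

# Both forms are order-preserving; BATCH issues one provider-side batch job instead of N calls.
sync_results = [llm.generate(prompt=p) for p in prompts]
batch_results = llm.generate_batch(prompts=prompts)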
