LLM Async mode (#1200)
* Add async mode for llm_map and implement exponential backoff in the OpenAI LLM implementation

Signed-off-by: Henry Lindeman <[email protected]>

* Add async Anthropic implementation

Signed-off-by: Henry Lindeman <[email protected]>

* Add async Gemini integration. I didn't hit rate-limiting errors in testing, so the client appears to handle that automatically

Signed-off-by: Henry Lindeman <[email protected]>

* Add async mode to unit tests

Signed-off-by: Henry Lindeman <[email protected]>

---------

Signed-off-by: Henry Lindeman <[email protected]>
HenryL27 authored Feb 27, 2025
1 parent b7fe1b6 commit f4f3032
Showing 5 changed files with 128 additions and 39 deletions.
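
As context for the per-file diffs below, here is a minimal sketch of what the commit message above describes: running an LLMMap transform with the new asynchronous mode. The stub LLM and prompt are modeled on the FakeLLM/FakeDocPrompt helpers in the test diff further down; the stub's constructor arguments and the base-class signature it assumes (model name plus optional cache) are assumptions, and a real pipeline would pass an OpenAI, Anthropic, or Gemini LLM instead.

from typing import Optional

from sycamore.data import Document
from sycamore.llms.llms import LLM, LLMMode
from sycamore.llms.prompts import RenderedPrompt, SycamorePrompt
from sycamore.llms.prompts.prompts import RenderedMessage
from sycamore.transforms.base_llm import LLMMap


class EchoLLM(LLM):
    # Stub LLM modeled on FakeLLM in test_base_llm.py: echoes the prompt text back.
    def __init__(self):
        super().__init__("echo", None)  # assumed base-class signature: (model_name, cache)

    def is_chat_mode(self) -> bool:
        return True

    def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
        return "".join(m.content for m in prompt.messages)

    async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
        return self.generate(prompt=prompt, llm_kwargs=llm_kwargs)


class EchoPrompt(SycamorePrompt):
    # Prompt modeled on FakeDocPrompt: one user message per document.
    def render_document(self, doc: Document) -> RenderedPrompt:
        return RenderedPrompt(messages=[RenderedMessage(role="user", content=doc.text_representation or "")])


docs = [Document({"text_representation": "ooga"}), Document({"text_representation": "booga"})]
llm_map = LLMMap(None, EchoPrompt(), "out", EchoLLM(), llm_mode=LLMMode.ASYNC)
out = llm_map.llm_map(docs)
assert [d.properties["out"] for d in out] == ["ooga", "booga"]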
58 changes: 46 additions & 12 deletions lib/sycamore/sycamore/llms/anthropic.py
@@ -2,6 +2,8 @@
from enum import Enum
import logging
from typing import Any, Optional, Union
import asyncio
import random

from PIL import Image

@@ -12,6 +14,7 @@
from sycamore.utils.import_utils import requires_modules

DEFAULT_MAX_TOKENS = 1000
INITIAL_BACKOFF = 1


class AnthropicModels(Enum):
@@ -125,6 +128,7 @@ def __init__(
# We import this here so we can share utility code with the Bedrock
# LLM implementation without requiring an Anthropic dependency.
from anthropic import Anthropic as AnthropicClient
from anthropic import AsyncAnthropic as AsyncAnthropicClient

self.model_name = model_name

@@ -137,6 +141,7 @@ def __init__(
self.model = model

self._client = AnthropicClient()
self._async_client = AsyncAnthropicClient()
super().__init__(self.model.value, cache)

def __reduce__(self):
@@ -153,18 +158,8 @@ def is_chat_mode(self) -> bool:
def format_image(self, image: Image.Image) -> dict[str, Any]:
return format_image(image)

def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
ret = self._llm_cache_get(prompt, llm_kwargs)
if isinstance(ret, dict):
return ret

kwargs = get_generate_kwargs(prompt, llm_kwargs)

start = datetime.now()

response = self._client.messages.create(model=self.model.value, **kwargs)

wall_latency = datetime.now() - start
def _metadata_from_response(self, kwargs, response, starttime) -> dict:
wall_latency = datetime.now() - starttime
in_tokens = response.usage.input_tokens
out_tokens = response.usage.output_tokens
output = response.content[0].text
@@ -176,6 +171,18 @@ def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict
"out_tokens": out_tokens,
}
self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
return ret

def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
ret = self._llm_cache_get(prompt, llm_kwargs)
if isinstance(ret, dict):
return ret

kwargs = get_generate_kwargs(prompt, llm_kwargs)
start = datetime.now()

response = self._client.messages.create(model=self.model.value, **kwargs)
ret = self._metadata_from_response(kwargs, response, start)
logging.debug(f"Generated response from Anthropic model: {ret}")

self._llm_cache_set(prompt, llm_kwargs, ret)
@@ -184,3 +191,30 @@ def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict
def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
d = self.generate_metadata(prompt=prompt, llm_kwargs=llm_kwargs)
return d["output"]

async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
from anthropic import RateLimitError, APIConnectionError

ret = self._llm_cache_get(prompt, llm_kwargs)
if isinstance(ret, dict):
return ret["output"]

kwargs = get_generate_kwargs(prompt, llm_kwargs)
start = datetime.now()
done = False
retries = 0
while not done:
try:
response = await self._async_client.messages.create(model=self.model.value, **kwargs)
done = True
except (RateLimitError, APIConnectionError):
backoff = INITIAL_BACKOFF * (2**retries)
jitter = random.uniform(0, 0.1 * backoff)
await asyncio.sleep(backoff + jitter)
retries += 1

ret = self._metadata_from_response(kwargs, response, start)
logging.debug(f"Generated response from Anthropic model: {ret}")

self._llm_cache_set(prompt, llm_kwargs, ret)
return ret["output"]
46 changes: 33 additions & 13 deletions lib/sycamore/sycamore/llms/gemini.py
@@ -91,7 +91,7 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]
content_list: list[types.Content] = []
for message in prompt.messages:
if message.role == "system":
config["system_message"] = message.content
config["system_instruction"] = message.content
continue
role = "model" if message.role == "assistant" else "user"
content = types.Content(parts=[types.Part.from_text(text=message.content)], role=role)
@@ -108,6 +108,21 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict]
kwargs["content"] = content_list
return kwargs

def _metadata_from_response(self, kwargs, response, starttime) -> dict:
wall_latency = datetime.datetime.now() - starttime
md = response.usage_metadata
in_tokens = int(md.prompt_token_count) if md and md.prompt_token_count else 0
out_tokens = int(md.candidates_token_count) if md and md.candidates_token_count else 0
output = " ".join(part.text if part else "" for part in response.candidates[0].content.parts)
ret = {
"output": output,
"wall_latency": wall_latency,
"in_tokens": in_tokens,
"out_tokens": out_tokens,
}
self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
return ret

def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
ret = self._llm_cache_get(prompt, llm_kwargs)
if isinstance(ret, dict):
@@ -120,21 +135,26 @@ def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict
response = self._client.models.generate_content(
model=self.model.name, contents=kwargs["content"], config=kwargs["config"]
)
wall_latency = datetime.datetime.now() - start
md = response.usage_metadata
in_tokens = int(md.prompt_token_count) if md and md.prompt_token_count else 0
out_tokens = int(md.candidates_token_count) if md and md.candidates_token_count else 0
output = " ".join(part.text if part else "" for part in response.candidates[0].content.parts)
ret = {
"output": output,
"wall_latency": wall_latency,
"in_tokens": in_tokens,
"out_tokens": out_tokens,
}
self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens)
ret = self._metadata_from_response(kwargs, response, start)
self._llm_cache_set(prompt, llm_kwargs, ret)
return ret

def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
d = self.generate_metadata(prompt=prompt, llm_kwargs=llm_kwargs)
return d["output"]

async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
ret = self._llm_cache_get(prompt, llm_kwargs)
if isinstance(ret, dict):
return ret["output"]
assert ret is None

kwargs = self.get_generate_kwargs(prompt, llm_kwargs)

start = datetime.datetime.now()
response = await self._client.aio.models.generate_content(
model=self.model.name, contents=kwargs["content"], config=kwargs["config"]
)
ret = self._metadata_from_response(kwargs, response, start)
self._llm_cache_set(prompt, llm_kwargs, ret)
return ret["output"]
23 changes: 19 additions & 4 deletions lib/sycamore/sycamore/llms/openai.py
@@ -7,6 +7,8 @@
from PIL import Image
from typing import Any, Dict, Optional, Tuple, Union
from datetime import datetime
import asyncio
import random

from openai import AzureOpenAI as AzureOpenAIClient
from openai import AsyncAzureOpenAI as AsyncAzureOpenAIClient
@@ -15,6 +17,7 @@
from openai import max_retries as DEFAULT_MAX_RETRIES
from openai.lib.azure import AzureADTokenProvider
from openai.lib._parsing import type_to_response_format_param
from openai import APIConnectionError


import pydantic
@@ -29,6 +32,7 @@

# Base URL for Helicone API, if configured using the SYCAMORE_HELICONE_API_KEY environment variable.
HELICONE_BASE_URL = "https://oai.helicone.ai/v1"
INITIAL_BACKOFF = 0.2


class OpenAIClientType(Enum):
@@ -416,10 +420,20 @@ async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[d
if llm_kwargs is None:
raise ValueError("Must include llm_kwargs to generate future call")

if prompt.response_format is not None:
ret = await self._generate_awaitable_using_openai_structured(prompt, llm_kwargs)
else:
ret = await self._generate_awaitable_using_openai(prompt, llm_kwargs)
done = False
retries = 0
while not done:
try:
if prompt.response_format is not None:
ret = await self._generate_awaitable_using_openai_structured(prompt, llm_kwargs)
else:
ret = await self._generate_awaitable_using_openai(prompt, llm_kwargs)
done = True
except APIConnectionError:
backoff = INITIAL_BACKOFF * (2**retries)
jitter = random.uniform(0, 0.1 * backoff)
await asyncio.sleep(backoff + jitter)
retries += 1

self._llm_cache_set(prompt, llm_kwargs, ret)
return ret
@@ -432,6 +446,7 @@ async def _generate_awaitable_using_openai(self, prompt: RenderedPrompt, llm_kwa
model=self._model_name, **kwargs
)
response_text = completion.choices[0].message.content
wall_latency = datetime.now() - starttime
else:
completion = await self.client_wrapper.get_async_client().completions.create(
model=self._model_name, **kwargs
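
The same retry shape appears in both the Anthropic and OpenAI async paths above: wait INITIAL_BACKOFF * 2**retries seconds plus up to 10% random jitter, then try again. A generic, self-contained sketch of that pattern follows; the helper name and the retry cap are additions here (the diff itself retries indefinitely).

import asyncio
import random
from typing import Awaitable, Callable, Type, TypeVar

T = TypeVar("T")


async def retry_with_backoff(
    make_call: Callable[[], Awaitable[T]],
    retryable: tuple[Type[BaseException], ...],
    initial_backoff: float = 0.2,
    max_retries: int = 8,
) -> T:
    # Retry make_call on the given exceptions, doubling the wait each time and
    # adding up to 10% random jitter, mirroring the loops in the diff above.
    retries = 0
    while True:
        try:
            return await make_call()
        except retryable:
            if retries >= max_retries:
                raise
            backoff = initial_backoff * (2**retries)
            jitter = random.uniform(0, 0.1 * backoff)
            await asyncio.sleep(backoff + jitter)
            retries += 1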
25 changes: 16 additions & 9 deletions lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py
@@ -1,5 +1,5 @@
from sycamore.data import Document, Element
from sycamore.llms.llms import LLM
from sycamore.llms.llms import LLM, LLMMode
from sycamore.llms.prompts import RenderedPrompt, SycamorePrompt
from sycamore.llms.prompts.prompts import RenderedMessage
from sycamore.transforms.base_llm import LLMMap, LLMMapElements
@@ -17,6 +17,9 @@ def is_chat_mode(self) -> bool:
def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
return "".join(m.content for m in prompt.messages)

async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str:
return self.generate(prompt=prompt, llm_kwargs=llm_kwargs)


class FakeDocPrompt(SycamorePrompt):
def render_document(self, doc: Document) -> RenderedPrompt:
@@ -41,20 +44,22 @@ def test_wrong_prompt_fails_fast(self):
_ = LLMMap(None, prompt, "out", llm)
assert "FakeEltPrompt" in str(einfo.value)

def test_happy_path(self):
@pytest.mark.parametrize("mode", [LLMMode.SYNC, LLMMode.ASYNC])
def test_happy_path(self, mode):
prompt = FakeDocPrompt()
llm = FakeLLM()
doc1 = Document({"text_representation": "ooga"})
doc2 = Document({"text_representation": "booga"})
map = LLMMap(None, prompt, "out", llm)
map = LLMMap(None, prompt, "out", llm, llm_mode=mode)
outdocs = map.llm_map([doc1, doc2])

assert outdocs[0].text_representation == "ooga"
assert outdocs[0].properties["out"] == "ooga"
assert outdocs[1].text_representation == "booga"
assert outdocs[1].properties["out"] == "booga"

def test_validate(self):
@pytest.mark.parametrize("mode", [LLMMode.SYNC, LLMMode.ASYNC])
def test_validate(self, mode):
prompt = FakeDocPrompt()
llm = FakeLLM()
doc1 = Document({"text_representation": "ooga"})
@@ -66,7 +71,7 @@ def valfn(d: Document) -> bool:
count += 1
return count > 1

map = LLMMap(None, prompt, "out", llm, validate=valfn)
map = LLMMap(None, prompt, "out", llm, validate=valfn, llm_mode=mode)
_ = map.llm_map([doc1, doc2])

assert count == 2
@@ -80,7 +85,8 @@ def test_wrong_prompt_fails_fast(self):
_ = LLMMapElements(None, prompt, "out", llm)
assert "FakeDocPrompt" in str(einfo.value)

def test_happy_path(self):
@pytest.mark.parametrize("mode", [LLMMode.SYNC, LLMMode.ASYNC])
def test_happy_path(self, mode):
prompt = FakeEltPrompt()
llm = FakeLLM()
doc1 = Document(
@@ -91,15 +97,16 @@ def test_happy_path(self):
}
)
doc2 = Document({"doc_id": "2", "elements": [{"text_representation": "booga"}, {}]})
map = LLMMapElements(None, prompt, "out", llm)
map = LLMMapElements(None, prompt, "out", llm, llm_mode=mode)
outdocs = map.llm_map_elements([doc1, doc2])

assert outdocs[0].elements[0].properties["out"] == "oogayo"
assert outdocs[0].elements[1].properties["out"] == "oogaho"
assert outdocs[1].elements[0].properties["out"] == "Nonebooga"
assert outdocs[1].elements[1].properties["out"] == "NoneNone"

def test_postprocess(self):
@pytest.mark.parametrize("mode", [LLMMode.SYNC, LLMMode.ASYNC])
def test_postprocess(self, mode):
prompt = FakeEltPrompt()
llm = FakeLLM()
doc1 = Document(
@@ -117,7 +124,7 @@ def valfn(e: Element) -> bool:
count += 1
return count > 1

map = LLMMapElements(None, prompt, "out", llm, validate=valfn)
map = LLMMapElements(None, prompt, "out", llm, validate=valfn, llm_mode=mode)
_ = map.llm_map_elements([doc1, doc2])

assert count == 4
15 changes: 14 additions & 1 deletion lib/sycamore/sycamore/transforms/base_llm.py
@@ -5,6 +5,14 @@
from sycamore.plan_nodes import Node
from sycamore.transforms.map import MapBatch
from sycamore.data import Document, Element
import asyncio


async def _infer_prompts_async(prompts: list[RenderedPrompt], llm: LLM) -> list[str]:
el = asyncio.get_running_loop()
awaitables = [llm.generate_async(prompt=p, llm_kwargs={}) for p in prompts]
tasks = [el.create_task(aw) for aw in awaitables]
return await asyncio.gather(*tasks)


def _infer_prompts(
@@ -22,7 +30,12 @@ def _infer_prompts(
res.append(s)
return res
elif llm_mode == LLMMode.ASYNC:
raise NotImplementedError("Haven't done async yet")
nonempty = [(i, p) for i, p in enumerate(prompts) if len(p.messages) > 0]
res = [""] * len(prompts)
rsps = asyncio.run(_infer_prompts_async([p for _, p in nonempty], llm))
for (i, _), rs in zip(nonempty, rsps):
res[i] = rs
return res
elif llm_mode == LLMMode.BATCH:
raise NotImplementedError("Haven't done batch yet")
else:
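
To illustrate the ordering logic in the ASYNC branch of _infer_prompts above: empty prompts are skipped before the gather, and each response is written back to its original index. A simplified standalone sketch, with a trivial coroutine standing in for generate_async:

import asyncio


async def fake_generate_async(text: str) -> str:
    # Stand-in for an LLM's generate_async coroutine.
    return text.upper()


def run_async_preserving_order(texts: list[str]) -> list[str]:
    # Mirrors the ASYNC branch: skip empty inputs, gather the rest concurrently,
    # then slot each response back into its original position.
    nonempty = [(i, t) for i, t in enumerate(texts) if t]
    res = [""] * len(texts)

    async def gather_all() -> list[str]:
        return await asyncio.gather(*(fake_generate_async(t) for _, t in nonempty))

    for (i, _), out in zip(nonempty, asyncio.run(gather_all())):
        res[i] = out
    return res


print(run_async_preserving_order(["ooga", "", "booga"]))  # ['OOGA', '', 'BOOGA']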
