From f4f30323a819ae587befd28dc72452d1858542d9 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Thu, 27 Feb 2025 09:19:16 -0800 Subject: [PATCH] LLM Async mode (#1200) * add async mode for llm_map and implement exponential backoff in openai llm impl Signed-off-by: Henry Lindeman * add async anthropic implementation Signed-off-by: Henry Lindeman * add async gemini integration. I didn't seem to get rate-limiting errors in testing so I think the client handles that automatically Signed-off-by: Henry Lindeman * add async mode to unittests Signed-off-by: Henry Lindeman --------- Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/llms/anthropic.py | 58 +++++++++++++++---- lib/sycamore/sycamore/llms/gemini.py | 46 ++++++++++----- lib/sycamore/sycamore/llms/openai.py | 23 ++++++-- .../tests/unit/transforms/test_base_llm.py | 25 +++++--- lib/sycamore/sycamore/transforms/base_llm.py | 15 ++++- 5 files changed, 128 insertions(+), 39 deletions(-) diff --git a/lib/sycamore/sycamore/llms/anthropic.py b/lib/sycamore/sycamore/llms/anthropic.py index fffc23769..d274e8066 100644 --- a/lib/sycamore/sycamore/llms/anthropic.py +++ b/lib/sycamore/sycamore/llms/anthropic.py @@ -2,6 +2,8 @@ from enum import Enum import logging from typing import Any, Optional, Union +import asyncio +import random from PIL import Image @@ -12,6 +14,7 @@ from sycamore.utils.import_utils import requires_modules DEFAULT_MAX_TOKENS = 1000 +INITIAL_BACKOFF = 1 class AnthropicModels(Enum): @@ -125,6 +128,7 @@ def __init__( # We import this here so we can share utility code with the Bedrock # LLM implementation without requiring an Anthropic dependency. from anthropic import Anthropic as AnthropicClient + from anthropic import AsyncAnthropic as AsyncAnthropicClient self.model_name = model_name @@ -137,6 +141,7 @@ def __init__( self.model = model self._client = AnthropicClient() + self._async_client = AsyncAnthropicClient() super().__init__(self.model.value, cache) def __reduce__(self): @@ -153,18 +158,8 @@ def is_chat_mode(self) -> bool: def format_image(self, image: Image.Image) -> dict[str, Any]: return format_image(image) - def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict: - ret = self._llm_cache_get(prompt, llm_kwargs) - if isinstance(ret, dict): - return ret - - kwargs = get_generate_kwargs(prompt, llm_kwargs) - - start = datetime.now() - - response = self._client.messages.create(model=self.model.value, **kwargs) - - wall_latency = datetime.now() - start + def _metadata_from_response(self, kwargs, response, starttime) -> dict: + wall_latency = datetime.now() - starttime in_tokens = response.usage.input_tokens out_tokens = response.usage.output_tokens output = response.content[0].text @@ -176,6 +171,18 @@ def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict "out_tokens": out_tokens, } self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens) + return ret + + def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict: + ret = self._llm_cache_get(prompt, llm_kwargs) + if isinstance(ret, dict): + return ret + + kwargs = get_generate_kwargs(prompt, llm_kwargs) + start = datetime.now() + + response = self._client.messages.create(model=self.model.value, **kwargs) + ret = self._metadata_from_response(kwargs, response, start) logging.debug(f"Generated response from Anthropic model: {ret}") self._llm_cache_set(prompt, llm_kwargs, ret) @@ -184,3 +191,30 @@ def generate_metadata(self, *, prompt: 
RenderedPrompt, llm_kwargs: Optional[dict def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: d = self.generate_metadata(prompt=prompt, llm_kwargs=llm_kwargs) return d["output"] + + async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + from anthropic import RateLimitError, APIConnectionError + + ret = self._llm_cache_get(prompt, llm_kwargs) + if isinstance(ret, dict): + return ret["output"] + + kwargs = get_generate_kwargs(prompt, llm_kwargs) + start = datetime.now() + done = False + retries = 0 + while not done: + try: + response = await self._async_client.messages.create(model=self.model.value, **kwargs) + done = True + except (RateLimitError, APIConnectionError): + backoff = INITIAL_BACKOFF * (2**retries) + jitter = random.uniform(0, 0.1 * backoff) + await asyncio.sleep(backoff + jitter) + retries += 1 + + ret = self._metadata_from_response(kwargs, response, start) + logging.debug(f"Generated response from Anthropic model: {ret}") + + self._llm_cache_set(prompt, llm_kwargs, ret) + return ret["output"] diff --git a/lib/sycamore/sycamore/llms/gemini.py b/lib/sycamore/sycamore/llms/gemini.py index 9583deda0..ac37d81e2 100644 --- a/lib/sycamore/sycamore/llms/gemini.py +++ b/lib/sycamore/sycamore/llms/gemini.py @@ -91,7 +91,7 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] content_list: list[types.Content] = [] for message in prompt.messages: if message.role == "system": - config["system_message"] = message.content + config["system_instruction"] = message.content continue role = "model" if message.role == "assistant" else "user" content = types.Content(parts=[types.Part.from_text(text=message.content)], role=role) @@ -108,6 +108,21 @@ def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] kwargs["content"] = content_list return kwargs + def _metadata_from_response(self, kwargs, response, starttime) -> dict: + wall_latency = datetime.datetime.now() - starttime + md = response.usage_metadata + in_tokens = int(md.prompt_token_count) if md and md.prompt_token_count else 0 + out_tokens = int(md.candidates_token_count) if md and md.candidates_token_count else 0 + output = " ".join(part.text if part else "" for part in response.candidates[0].content.parts) + ret = { + "output": output, + "wall_latency": wall_latency, + "in_tokens": in_tokens, + "out_tokens": out_tokens, + } + self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens) + return ret + def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict: ret = self._llm_cache_get(prompt, llm_kwargs) if isinstance(ret, dict): @@ -120,21 +135,26 @@ def generate_metadata(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict response = self._client.models.generate_content( model=self.model.name, contents=kwargs["content"], config=kwargs["config"] ) - wall_latency = datetime.datetime.now() - start - md = response.usage_metadata - in_tokens = int(md.prompt_token_count) if md and md.prompt_token_count else 0 - out_tokens = int(md.candidates_token_count) if md and md.candidates_token_count else 0 - output = " ".join(part.text if part else "" for part in response.candidates[0].content.parts) - ret = { - "output": output, - "wall_latency": wall_latency, - "in_tokens": in_tokens, - "out_tokens": out_tokens, - } - self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens) + ret = self._metadata_from_response(kwargs, response, start) 
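The async Anthropic path above retries RateLimitError / APIConnectionError with exponential backoff plus jitter, and the OpenAI change later in this patch uses the same pattern. A rough standalone sketch of that control flow, for reference only: the call_with_backoff helper and the max_retries cap are illustrative assumptions (the patch's own loop retries without a cap and catches only the two provider errors).

import asyncio
import random

INITIAL_BACKOFF = 1  # seconds; mirrors the constant added to anthropic.py

async def call_with_backoff(call, *, max_retries=6):
    # Retry an awaitable factory on transient errors, doubling the wait each
    # time and adding up to 10% random jitter so concurrent tasks spread out.
    for retries in range(max_retries):
        try:
            return await call()
        except Exception:  # the patch narrows this to RateLimitError / APIConnectionError
            backoff = INITIAL_BACKOFF * (2**retries)
            jitter = random.uniform(0, 0.1 * backoff)
            await asyncio.sleep(backoff + jitter)
    return await call()  # final attempt; lets the exception propagate
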
self._llm_cache_set(prompt, llm_kwargs, ret) return ret def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: d = self.generate_metadata(prompt=prompt, llm_kwargs=llm_kwargs) return d["output"] + + async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + ret = self._llm_cache_get(prompt, llm_kwargs) + if isinstance(ret, dict): + return ret["output"] + assert ret is None + + kwargs = self.get_generate_kwargs(prompt, llm_kwargs) + + start = datetime.datetime.now() + response = await self._client.aio.models.generate_content( + model=self.model.name, contents=kwargs["content"], config=kwargs["config"] + ) + ret = self._metadata_from_response(kwargs, response, start) + self._llm_cache_set(prompt, llm_kwargs, ret) + return ret["output"] diff --git a/lib/sycamore/sycamore/llms/openai.py b/lib/sycamore/sycamore/llms/openai.py index 7730da9ea..f9f8cbf0d 100644 --- a/lib/sycamore/sycamore/llms/openai.py +++ b/lib/sycamore/sycamore/llms/openai.py @@ -7,6 +7,8 @@ from PIL import Image from typing import Any, Dict, Optional, Tuple, Union from datetime import datetime +import asyncio +import random from openai import AzureOpenAI as AzureOpenAIClient from openai import AsyncAzureOpenAI as AsyncAzureOpenAIClient @@ -15,6 +17,7 @@ from openai import max_retries as DEFAULT_MAX_RETRIES from openai.lib.azure import AzureADTokenProvider from openai.lib._parsing import type_to_response_format_param +from openai import APIConnectionError import pydantic @@ -29,6 +32,7 @@ # Base URL for Helicone API, if configured using the SYCAMORE_HELICONE_API_KEY environment variable. HELICONE_BASE_URL = "https://oai.helicone.ai/v1" +INITIAL_BACKOFF = 0.2 class OpenAIClientType(Enum): @@ -416,10 +420,20 @@ async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[d if llm_kwargs is None: raise ValueError("Must include llm_kwargs to generate future call") - if prompt.response_format is not None: - ret = await self._generate_awaitable_using_openai_structured(prompt, llm_kwargs) - else: - ret = await self._generate_awaitable_using_openai(prompt, llm_kwargs) + done = False + retries = 0 + while not done: + try: + if prompt.response_format is not None: + ret = await self._generate_awaitable_using_openai_structured(prompt, llm_kwargs) + else: + ret = await self._generate_awaitable_using_openai(prompt, llm_kwargs) + done = True + except APIConnectionError: + backoff = INITIAL_BACKOFF * (2**retries) + jitter = random.uniform(0, 0.1 * backoff) + await asyncio.sleep(backoff + jitter) + retries += 1 self._llm_cache_set(prompt, llm_kwargs, ret) return ret @@ -432,6 +446,7 @@ async def _generate_awaitable_using_openai(self, prompt: RenderedPrompt, llm_kwa model=self._model_name, **kwargs ) response_text = completion.choices[0].message.content + wall_latency = datetime.now() - starttime else: completion = await self.client_wrapper.get_async_client().completions.create( model=self._model_name, **kwargs diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py index 0710eb8fe..5811be6b5 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_base_llm.py @@ -1,5 +1,5 @@ from sycamore.data import Document, Element -from sycamore.llms.llms import LLM +from sycamore.llms.llms import LLM, LLMMode from sycamore.llms.prompts import RenderedPrompt, SycamorePrompt from sycamore.llms.prompts.prompts 
import RenderedMessage from sycamore.transforms.base_llm import LLMMap, LLMMapElements @@ -17,6 +17,9 @@ def is_chat_mode(self) -> bool: def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: return "".join(m.content for m in prompt.messages) + async def generate_async(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + return self.generate(prompt=prompt, llm_kwargs=llm_kwargs) + class FakeDocPrompt(SycamorePrompt): def render_document(self, doc: Document) -> RenderedPrompt: @@ -41,12 +44,13 @@ def test_wrong_prompt_fails_fast(self): _ = LLMMap(None, prompt, "out", llm) assert "FakeEltPrompt" in str(einfo.value) - def test_happy_path(self): + @pytest.mark.parametrize("mode", [LLMMode.SYNC, LLMMode.ASYNC]) + def test_happy_path(self, mode): prompt = FakeDocPrompt() llm = FakeLLM() doc1 = Document({"text_representation": "ooga"}) doc2 = Document({"text_representation": "booga"}) - map = LLMMap(None, prompt, "out", llm) + map = LLMMap(None, prompt, "out", llm, llm_mode=mode) outdocs = map.llm_map([doc1, doc2]) assert outdocs[0].text_representation == "ooga" @@ -54,7 +58,8 @@ def test_happy_path(self): assert outdocs[1].text_representation == "booga" assert outdocs[1].properties["out"] == "booga" - def test_validate(self): + @pytest.mark.parametrize("mode", [LLMMode.SYNC, LLMMode.ASYNC]) + def test_validate(self, mode): prompt = FakeDocPrompt() llm = FakeLLM() doc1 = Document({"text_representation": "ooga"}) @@ -66,7 +71,7 @@ def valfn(d: Document) -> bool: count += 1 return count > 1 - map = LLMMap(None, prompt, "out", llm, validate=valfn) + map = LLMMap(None, prompt, "out", llm, validate=valfn, llm_mode=mode) _ = map.llm_map([doc1, doc2]) assert count == 2 @@ -80,7 +85,8 @@ def test_wrong_prompt_fails_fast(self): _ = LLMMapElements(None, prompt, "out", llm) assert "FakeDocPrompt" in str(einfo.value) - def test_happy_path(self): + @pytest.mark.parametrize("mode", [LLMMode.SYNC, LLMMode.ASYNC]) + def test_happy_path(self, mode): prompt = FakeEltPrompt() llm = FakeLLM() doc1 = Document( @@ -91,7 +97,7 @@ def test_happy_path(self): } ) doc2 = Document({"doc_id": "2", "elements": [{"text_representation": "booga"}, {}]}) - map = LLMMapElements(None, prompt, "out", llm) + map = LLMMapElements(None, prompt, "out", llm, llm_mode=mode) outdocs = map.llm_map_elements([doc1, doc2]) assert outdocs[0].elements[0].properties["out"] == "oogayo" @@ -99,7 +105,8 @@ def test_happy_path(self): assert outdocs[1].elements[0].properties["out"] == "Nonebooga" assert outdocs[1].elements[1].properties["out"] == "NoneNone" - def test_postprocess(self): + @pytest.mark.parametrize("mode", [LLMMode.SYNC, LLMMode.ASYNC]) + def test_postprocess(self, mode): prompt = FakeEltPrompt() llm = FakeLLM() doc1 = Document( @@ -117,7 +124,7 @@ def valfn(e: Element) -> bool: count += 1 return count > 1 - map = LLMMapElements(None, prompt, "out", llm, validate=valfn) + map = LLMMapElements(None, prompt, "out", llm, validate=valfn, llm_mode=mode) _ = map.llm_map_elements([doc1, doc2]) assert count == 4 diff --git a/lib/sycamore/sycamore/transforms/base_llm.py b/lib/sycamore/sycamore/transforms/base_llm.py index 224ac8f87..099a069ef 100644 --- a/lib/sycamore/sycamore/transforms/base_llm.py +++ b/lib/sycamore/sycamore/transforms/base_llm.py @@ -5,6 +5,14 @@ from sycamore.plan_nodes import Node from sycamore.transforms.map import MapBatch from sycamore.data import Document, Element +import asyncio + + +async def _infer_prompts_async(prompts: list[RenderedPrompt], llm: LLM) -> 
list[str]: + el = asyncio.get_running_loop() + awaitables = [llm.generate_async(prompt=p, llm_kwargs={}) for p in prompts] + tasks = [el.create_task(aw) for aw in awaitables] + return await asyncio.gather(*tasks) def _infer_prompts( @@ -22,7 +30,12 @@ def _infer_prompts( res.append(s) return res elif llm_mode == LLMMode.ASYNC: - raise NotImplementedError("Haven't done async yet") + nonempty = [(i, p) for i, p in enumerate(prompts) if len(p.messages) > 0] + res = [""] * len(prompts) + rsps = asyncio.run(_infer_prompts_async([p for _, p in nonempty], llm)) + for (i, _), rs in zip(nonempty, rsps): + res[i] = rs + return res elif llm_mode == LLMMode.BATCH: raise NotImplementedError("Haven't done batch yet") else:
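
The ASYNC branch above fans the non-empty prompts out as concurrent generate_async tasks and writes each response back at the index of the prompt that produced it, leaving empty prompts as "". A minimal self-contained sketch of that fan-out/reassembly pattern; fake_generate is an illustrative stand-in for llm.generate_async and is not part of the patch.

import asyncio

async def fake_generate(prompt: str) -> str:
    # Stand-in for LLM.generate_async; a real call would hit the provider API.
    await asyncio.sleep(0)
    return prompt.upper()

def infer(prompts: list[str]) -> list[str]:
    # Skip empty prompts, run the rest concurrently, and put each response
    # back at the original position of the prompt that produced it.
    nonempty = [(i, p) for i, p in enumerate(prompts) if p]

    async def gather_all():
        return await asyncio.gather(*(fake_generate(p) for _, p in nonempty))

    res = [""] * len(prompts)
    for (i, _), out in zip(nonempty, asyncio.run(gather_all())):
        res[i] = out
    return res

print(infer(["ooga", "", "booga"]))  # -> ['OOGA', '', 'BOOGA']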