
Commit d631c84

added AsyncLM
1 parent b652c8e commit d631c84

File tree: 6 files changed (+174 −3 lines)

Diff for: dspy/clients/__init__.py (+2)

@@ -1,4 +1,5 @@
 from dspy.clients.lm import LM
+from dspy.clients.async_lm import AsyncLM
 from dspy.clients.provider import Provider, TrainingJob
 from dspy.clients.base_lm import BaseLM, inspect_history
 from dspy.clients.embedding import Embedder
@@ -37,6 +38,7 @@ def disable_litellm_logging():
 
 __all__ = [
     "LM",
+    "AsyncLM",
     "Provider",
     "TrainingJob",
     "BaseLM",

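With this re-export, the new client is reachable both from the subpackage and from the top-level dspy namespace. A minimal sketch of the two equivalent import paths (no request is made here):

import dspy
from dspy.clients import AsyncLM

# Both names refer to the same class listed in __all__ above.
assert dspy.AsyncLM is AsyncLM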
Diff for: dspy/clients/async_lm.py (new file, +137)

import os
import uuid
from datetime import datetime
from typing import Dict, Any, Awaitable, cast

from anyio.streams.memory import MemoryObjectSendStream
from litellm.types.router import RetryPolicy

import litellm

import dspy
from dspy.clients.lm import request_cache, LM
from dspy.utils import with_callbacks


class AsyncLM(LM):
    @with_callbacks
    def __call__(self, prompt=None, messages=None, **kwargs) -> Awaitable:
        async def _async_call(prompt, messages, **kwargs):
            # Build the request.
            cache = kwargs.pop("cache", self.cache)
            messages = messages or [{"role": "user", "content": prompt}]
            kwargs = {**self.kwargs, **kwargs}

            # Make the request and handle LRU & disk caching.
            if self.model_type == "chat":
                completion = cached_litellm_completion if cache else litellm_acompletion
            else:
                completion = cached_litellm_text_completion if cache else litellm_text_acompletion

            response = await completion(
                request=dict(model=self.model, messages=messages, **kwargs),
                num_retries=self.num_retries,
            )
            outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
            self._log_entry(prompt, messages, kwargs, response, outputs)
            return outputs

        return _async_call(prompt, messages, **kwargs)

@request_cache(maxsize=None)
async def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
    return await litellm_acompletion(
        request,
        cache={"no-cache": False, "no-store": False},
        num_retries=num_retries,
    )


async def litellm_acompletion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
    retry_kwargs = dict(
        retry_policy=_get_litellm_retry_policy(num_retries),
        # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
        # to completion()), the default value of max_retries is non-zero for certain providers, and
        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0.
        max_retries=0,
    )

    stream = dspy.settings.send_stream
    if stream is None:
        return await litellm.acompletion(
            cache=cache,
            **retry_kwargs,
            **request,
        )

    # The stream is already opened, and will be closed by the caller.
    stream = cast(MemoryObjectSendStream, stream)

    async def stream_completion():
        response = await litellm.acompletion(
            cache=cache,
            stream=True,
            **retry_kwargs,
            **request,
        )
        chunks = []
        async for chunk in response:
            chunks.append(chunk)
            await stream.send(chunk)
        return litellm.stream_chunk_builder(chunks)

    return await stream_completion()


@request_cache(maxsize=None)
async def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
    return await litellm_text_acompletion(
        request,
        num_retries=num_retries,
        cache={"no-cache": False, "no-store": False},
    )


async def litellm_text_acompletion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
    # Extract the provider and model from the model string.
    # TODO: Not all the models are in the format of "provider/model"
    model = request.pop("model").split("/", 1)
    provider, model = model[0] if len(model) > 1 else "openai", model[-1]

    # Use the API key and base from the request, or from the environment.
    api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
    api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

    # Build the prompt from the messages.
    prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])

    return await litellm.atext_completion(
        cache=cache,
        model=f"text-completion-openai/{model}",
        api_key=api_key,
        api_base=api_base,
        prompt=prompt,
        num_retries=num_retries,
        **request,
    )

def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
    """
    Get a LiteLLM retry policy for retrying requests when transient API errors occur.
    Args:
        num_retries: The number of times to retry a request if it fails transiently due to
            network error, rate limiting, etc. Requests are retried with exponential
            backoff.
    Returns:
        A LiteLLM RetryPolicy instance.
    """
    return RetryPolicy(
        TimeoutErrorRetries=num_retries,
        RateLimitErrorRetries=num_retries,
        InternalServerErrorRetries=num_retries,
        ContentPolicyViolationErrorRetries=num_retries,
        # We don't retry on errors that are unlikely to be transient
        # (e.g. bad request, invalid auth credentials)
        BadRequestErrorRetries=0,
        AuthenticationErrorRetries=0,
    )

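AsyncLM.__call__ returns an awaitable rather than the completions themselves, so it must be awaited inside a coroutine. A minimal usage sketch, assuming a hypothetical OpenAI model name and an OPENAI_API_KEY set in the environment:

import asyncio

import dspy

# Hypothetical model name for illustration; any LiteLLM-supported chat model should work.
lm = dspy.AsyncLM(model="openai/gpt-4o-mini", model_type="chat")

async def main():
    # Like LM, the call resolves to a list of completion strings.
    outputs = await lm("Say hi in one word.")
    print(outputs[0])

asyncio.run(main())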
Diff for: dspy/clients/lm.py (+4 −2)

@@ -113,6 +113,10 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         else:
             outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
 
+        self._log_entry(prompt, messages, kwargs, response, outputs)
+        return outputs
+
+    def _log_entry(self, prompt, messages, kwargs, response, outputs):
         # Logging, with removed api key & where `cost` is None on cache hit.
         kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
         entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
@@ -129,8 +133,6 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         self.history.append(entry)
         self.update_global_history(entry)
 
-        return outputs
-
     def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
         launch_kwargs = launch_kwargs or self.launch_kwargs
         self.provider.launch(self.model, launch_kwargs)

Diff for: pyproject.toml (+2 −1)

@@ -72,7 +72,7 @@ docs = [
     "sphinx-reredirects>=0.1.2",
     "sphinx-automodapi==0.16.0",
 ]
-dev = ["pytest>=6.2.5"]
+dev = ["pytest>=6.2.5", "pytest-asyncio>=0.25.0"]
 fastembed = ["fastembed>=0.2.0"]
 
 [project.urls]
@@ -152,6 +152,7 @@ ipykernel = "^6.29.4"
 semver = "^3.0.2"
 pillow = "^10.1.0"
 litellm = { version = "^1.51.0", extras = ["proxy"] }
+pytest-asyncio = "^0.25.0"
 
 [tool.poetry.extras]
 chromadb = ["chromadb"]

Diff for: requirements-dev.txt (+1)

@@ -4,6 +4,7 @@ litellm[proxy]==1.53.7
 pillow==10.4.0
 pre-commit==3.7.0
 pytest==8.3.3
+pytest-asyncio==0.25.0
 pytest-env==1.1.3
 pytest-mock==3.12.0
 ruff==0.3.0

Diff for: tests/clients/test_lm.py (+28)

@@ -29,6 +29,20 @@ def test_chat_lms_can_be_queried(litellm_test_server):
     assert azure_openai_lm("azure openai query") == expected_response
 
 
+@pytest.mark.asyncio
+async def test_async_chat_lms_can_be_queried(litellm_test_server):
+    api_base, _ = litellm_test_server
+    expected_response = ["Hi!"]
+
+    openai_lm = dspy.AsyncLM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        model_type="chat",
+    )
+    assert await openai_lm("openai query") == expected_response
+
+
 def test_text_lms_can_be_queried(litellm_test_server):
     api_base, _ = litellm_test_server
     expected_response = ["Hi!"]
@@ -50,6 +64,20 @@ def test_text_lms_can_be_queried(litellm_test_server):
     assert azure_openai_lm("azure openai query") == expected_response
 
 
+@pytest.mark.asyncio
+async def test_async_text_lms_can_be_queried(litellm_test_server):
+    api_base, _ = litellm_test_server
+    expected_response = ["Hi!"]
+
+    openai_lm = dspy.AsyncLM(
+        model="openai/dspy-test-model",
+        api_base=api_base,
+        api_key="fakekey",
+        model_type="text",
+    )
+    assert await openai_lm("openai query") == expected_response
+
+
 def test_lm_calls_support_callables(litellm_test_server):
     api_base, _ = litellm_test_server

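The tests above exercise single awaited calls; the practical payoff of AsyncLM is issuing several requests concurrently. A rough sketch of that pattern, again with a hypothetical model name and an API key assumed in the environment:

import asyncio

import dspy

# Hypothetical model name, for illustration only.
lm = dspy.AsyncLM(model="openai/gpt-4o-mini", model_type="chat")

async def main():
    prompts = ["What is DSPy?", "What is LiteLLM?", "What is anyio?"]
    # Each call produces a coroutine, so gather runs the three requests concurrently.
    results = await asyncio.gather(*(lm(p) for p in prompts))
    for prompt, outputs in zip(prompts, results):
        print(f"{prompt} -> {outputs[0]}")

asyncio.run(main())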