
Commit 671a6a1

added AsyncLM
1 parent b652c8e commit 671a6a1

File tree

7 files changed: +533 −18 lines


Diff for: dspy/clients/__init__.py (+2)

@@ -1,4 +1,5 @@
 from dspy.clients.lm import LM
+from dspy.clients.async_lm import AsyncLM
 from dspy.clients.provider import Provider, TrainingJob
 from dspy.clients.base_lm import BaseLM, inspect_history
 from dspy.clients.embedding import Embedder
@@ -37,6 +38,7 @@ def disable_litellm_logging():
 
 __all__ = [
     "LM",
+    "AsyncLM",
     "Provider",
     "TrainingJob",
     "BaseLM",

Diff for: dspy/clients/async_lm.py (+138)

@@ -0,0 +1,138 @@
+import os
+from typing import Any, Awaitable, Dict, cast
+
+import litellm
+from anyio.streams.memory import MemoryObjectSendStream
+from litellm.types.router import RetryPolicy
+
+import dspy
+from dspy.clients.lm import LM, request_cache
+from dspy.utils import with_callbacks
+
+
+class AsyncLM(LM):
+    @with_callbacks
+    def __call__(self, prompt=None, messages=None, **kwargs) -> Awaitable:
+        async def _async_call(prompt, messages, **kwargs):
+            # Build the request.
+            cache = kwargs.pop("cache", self.cache)
+            messages = messages or [{"role": "user", "content": prompt}]
+            kwargs = {**self.kwargs, **kwargs}
+
+            # Make the request and handle LRU & disk caching.
+            if self.model_type == "chat":
+                completion = cached_litellm_completion if cache else litellm_acompletion
+            else:
+                completion = cached_litellm_text_completion if cache else litellm_text_acompletion
+
+            response = await completion(
+                request=dict(model=self.model, messages=messages, **kwargs),
+                num_retries=self.num_retries,
+            )
+            outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
+            self._log_entry(prompt, messages, kwargs, response, outputs)
+            return outputs
+
+        return _async_call(prompt, messages, **kwargs)
+
+
+@request_cache(maxsize=None)
+async def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
+    return await litellm_acompletion(
+        request,
+        cache={"no-cache": False, "no-store": False},
+        num_retries=num_retries,
+    )
+
+
+async def litellm_acompletion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
+    retry_kwargs = dict(
+        retry_policy=_get_litellm_retry_policy(num_retries),
+        # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
+        # to completion()), the default value of max_retries is non-zero for certain providers, and
+        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0.
+        max_retries=0,
+    )
+
+    stream = dspy.settings.send_stream
+    if stream is None:
+        return await litellm.acompletion(
+            cache=cache,
+            **retry_kwargs,
+            **request,
+        )
+
+    # The stream is already opened, and will be closed by the caller.
+    stream = cast(MemoryObjectSendStream, stream)
+
+    async def stream_completion():
+        response = await litellm.acompletion(
+            cache=cache,
+            stream=True,
+            **retry_kwargs,
+            **request,
+        )
+        chunks = []
+        async for chunk in response:
+            chunks.append(chunk)
+            await stream.send(chunk)
+        return litellm.stream_chunk_builder(chunks)
+
+    return await stream_completion()
+
+
+@request_cache(maxsize=None)
+async def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
+    return await litellm_text_acompletion(
+        request,
+        num_retries=num_retries,
+        cache={"no-cache": False, "no-store": False},
+    )
+
+
+async def litellm_text_acompletion(
+    request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}
+):
+    # Extract the provider and model from the model string.
+    # TODO: Not all the models are in the format of "provider/model"
+    model = request.pop("model").split("/", 1)
+    provider, model = model[0] if len(model) > 1 else "openai", model[-1]
+
+    # Use the API key and base from the request, or from the environment.
+    api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
+    api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
+
+    # Build the prompt from the messages.
+    prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])
+
+    return await litellm.atext_completion(
+        cache=cache,
+        model=f"text-completion-openai/{model}",
+        api_key=api_key,
+        api_base=api_base,
+        prompt=prompt,
+        num_retries=num_retries,
+        **request,
+    )
+
+
+def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
+    """
+    Get a LiteLLM retry policy for retrying requests when transient API errors occur.
+    Args:
+        num_retries: The number of times to retry a request if it fails transiently due to
+            network error, rate limiting, etc. Requests are retried with exponential
+            backoff.
+    Returns:
+        A LiteLLM RetryPolicy instance.
+    """
+    return RetryPolicy(
+        TimeoutErrorRetries=num_retries,
+        RateLimitErrorRetries=num_retries,
+        InternalServerErrorRetries=num_retries,
+        ContentPolicyViolationErrorRetries=num_retries,
+        # We don't retry on errors that are unlikely to be transient
+        # (e.g. bad request, invalid auth credentials)
+        BadRequestErrorRetries=0,
+        AuthenticationErrorRetries=0,
+    )
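
For context, a minimal usage sketch based on the code above. The model name is a placeholder; construction is assumed to mirror LM, since AsyncLM subclasses it, and __call__ returns an awaitable that resolves to the list of completion strings.

import asyncio

from dspy.clients import AsyncLM


async def main():
    # Placeholder model name; AsyncLM reuses LM's constructor and settings.
    lm = AsyncLM("openai/gpt-4o-mini")

    # __call__ builds the request and returns an awaitable; awaiting it runs
    # the (optionally cached) litellm acompletion call and logs the entry.
    outputs = await lm(messages=[{"role": "user", "content": "Say hello."}])
    print(outputs[0])


asyncio.run(main())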

Diff for: dspy/clients/lm.py (+4, −2)

@@ -113,6 +113,10 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         else:
             outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
 
+        self._log_entry(prompt, messages, kwargs, response, outputs)
+        return outputs
+
+    def _log_entry(self, prompt, messages, kwargs, response, outputs):
         # Logging, with removed api key & where `cost` is None on cache hit.
         kwargs = {k: v for k, v in kwargs.items() if not k.startswith("api_")}
         entry = dict(prompt=prompt, messages=messages, kwargs=kwargs, response=response)
@@ -129,8 +133,6 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         self.history.append(entry)
         self.update_global_history(entry)
 
-        return outputs
-
     def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
         launch_kwargs = launch_kwargs or self.launch_kwargs
         self.provider.launch(self.model, launch_kwargs)
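
This change extracts the logging tail of LM.__call__ into a _log_entry helper so that AsyncLM (above) can reuse it from its async path. A rough illustration of the shared bookkeeping, with placeholder model names and assuming provider credentials are configured:

import asyncio

from dspy.clients import LM, AsyncLM


async def main():
    sync_lm = LM("openai/gpt-4o-mini")       # placeholder model name
    async_lm = AsyncLM("openai/gpt-4o-mini")

    sync_lm("What is 2 + 2?")                # logs via _log_entry, then returns outputs
    await async_lm("What is 2 + 2?")         # same helper, reached from _async_call

    # Each call appends one entry (prompt, messages, kwargs, response, ...)
    # to the instance's history via the shared helper.
    print(len(sync_lm.history), len(async_lm.history))


asyncio.run(main())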

0 commit comments
