import os
import uuid
from datetime import datetime
from typing import Any, Awaitable, Dict, cast

import litellm
from anyio.streams.memory import MemoryObjectSendStream
from litellm.types.router import RetryPolicy

import dspy
from dspy.clients.lm import LM, request_cache
from dspy.utils import with_callbacks


class AsyncLM(LM):
    @with_callbacks
    def __call__(self, prompt=None, messages=None, **kwargs) -> Awaitable:
        """Request completions asynchronously.

        Returns an awaitable that resolves to the list of completion strings.
        """

        async def _async_call(prompt, messages, **kwargs):
            # Build the request.
            cache = kwargs.pop("cache", self.cache)
            messages = messages or [{"role": "user", "content": prompt}]
            kwargs = {**self.kwargs, **kwargs}

            # Make the request and handle LRU & disk caching.
            if self.model_type == "chat":
                completion = cached_litellm_completion if cache else litellm_acompletion
            else:
                completion = cached_litellm_text_completion if cache else litellm_text_acompletion

            response = await completion(
                request=dict(model=self.model, messages=messages, **kwargs),
                num_retries=self.num_retries,
            )
            outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
            self._log_entry(prompt, messages, kwargs, response, outputs)
            return outputs

        return _async_call(prompt, messages, **kwargs)


@request_cache(maxsize=None)
async def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
    # Memoize identical requests and allow LiteLLM's cache to serve repeated calls.
    return await litellm_acompletion(
        request,
        cache={"no-cache": False, "no-store": False},
        num_retries=num_retries,
    )


async def litellm_acompletion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
    retry_kwargs = dict(
        retry_policy=_get_litellm_retry_policy(num_retries),
        # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
        # to completion()), the default value of max_retries is non-zero for certain providers, and
        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0.
        max_retries=0,
    )

    stream = dspy.settings.send_stream
    if stream is None:
        return await litellm.acompletion(
            cache=cache,
            **retry_kwargs,
            **request,
        )

    # The stream is already opened, and will be closed by the caller.
    stream = cast(MemoryObjectSendStream, stream)

    async def stream_completion():
        # Forward chunks to the caller's stream as they arrive, while also
        # accumulating them so a complete response can be rebuilt and returned.
        response = await litellm.acompletion(
            cache=cache,
            stream=True,
            **retry_kwargs,
            **request,
        )
        chunks = []
        async for chunk in response:
            chunks.append(chunk)
            await stream.send(chunk)
        return litellm.stream_chunk_builder(chunks)

    return await stream_completion()


@request_cache(maxsize=None)
async def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
    return await litellm_text_acompletion(
        request,
        num_retries=num_retries,
        cache={"no-cache": False, "no-store": False},
    )


async def litellm_text_acompletion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
    # Extract the provider and model from the model string.
    # TODO: Not all models follow the "provider/model" format.
    model = request.pop("model").split("/", 1)
    provider, model = model[0] if len(model) > 1 else "openai", model[-1]

    # Use the API key and base from the request, or from the environment.
    api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
    api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

    # Build the prompt from the messages.
    prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])

    return await litellm.atext_completion(
        cache=cache,
        model=f"text-completion-openai/{model}",
        api_key=api_key,
        api_base=api_base,
        prompt=prompt,
        num_retries=num_retries,
        **request,
    )


def _get_litellm_retry_policy(num_retries: int) -> RetryPolicy:
    """
    Get a LiteLLM retry policy for retrying requests when transient API errors occur.

    Args:
        num_retries: The number of times to retry a request if it fails transiently due to
            network error, rate limiting, etc. Requests are retried with exponential
            backoff.

    Returns:
        A LiteLLM RetryPolicy instance.
    """
    return RetryPolicy(
        TimeoutErrorRetries=num_retries,
        RateLimitErrorRetries=num_retries,
        InternalServerErrorRetries=num_retries,
        ContentPolicyViolationErrorRetries=num_retries,
        # We don't retry on errors that are unlikely to be transient
        # (e.g. bad request, invalid auth credentials).
        BadRequestErrorRetries=0,
        AuthenticationErrorRetries=0,
    )
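

# --- Usage sketch (illustrative, not part of the original patch) -------------
# A minimal way to exercise AsyncLM end to end, assuming the LM base class
# accepts a LiteLLM-style model string and that the matching API key (here,
# OPENAI_API_KEY) is set in the environment. The model name and message are
# placeholder assumptions, not values prescribed by this module.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        lm = AsyncLM("openai/gpt-4o-mini", cache=False)
        outputs = await lm(messages=[{"role": "user", "content": "Say hello in one word."}])
        print(outputs[0])

    asyncio.run(_demo())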