Commit
Merge pull request #4977 from BerriAI/litellm_async_cohere_calls
fix(cohere.py): support async cohere embedding calls
krrishdholakia authored Jul 30, 2024
2 parents 3a70849 + 653aefd commit bc8cc4f
Showing 3 changed files with 132 additions and 30 deletions.
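In practice, the merged change means Cohere embedding calls can now be awaited. A minimal usage sketch, mirroring the async branch of the test added below (assumes a valid COHERE_API_KEY in the environment; this example is not part of the commit itself):

import asyncio

import litellm

async def main():
    # aembedding() now routes this Cohere model through the new async path
    response = await litellm.aembedding(
        model="embed-english-v2.0",
        input=["good morning from litellm", "this is another item"],
        input_type="search_query",
    )
    print(response.usage)  # prompt token counts computed from the inputs

asyncio.run(main())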
138 changes: 115 additions & 23 deletions litellm/llms/cohere.py
@@ -7,12 +7,14 @@
 import traceback
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, Union

 import httpx  # type: ignore
 import requests  # type: ignore

 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -249,14 +251,98 @@ def completion(
     return model_response


+def _process_embedding_response(
+    embeddings: list,
+    model_response: litellm.EmbeddingResponse,
+    model: str,
+    encoding: Any,
+    input: list,
+) -> litellm.EmbeddingResponse:
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {"object": "embedding", "index": idx, "embedding": embedding}
+        )
+    model_response.object = "list"
+    model_response.data = output_data
+    model_response.model = model
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+
+    setattr(
+        model_response,
+        "usage",
+        Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        ),
+    )
+
+    return model_response
+
+
+async def async_embedding(
+    model: str,
+    data: dict,
+    input: list,
+    model_response: litellm.utils.EmbeddingResponse,
+    timeout: Union[float, httpx.Timeout],
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    api_base: str,
+    api_key: Optional[str],
+    headers: dict,
+    encoding: Callable,
+    client: Optional[AsyncHTTPHandler] = None,
+):
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={
+            "complete_input_dict": data,
+            "headers": headers,
+            "api_base": api_base,
+        },
+    )
+    ## COMPLETION CALL
+    if client is None:
+        client = AsyncHTTPHandler(concurrent_limit=1)
+
+    response = await client.post(api_base, headers=headers, data=json.dumps(data))
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+
+    embeddings = response.json()["embeddings"]
+
+    ## PROCESS RESPONSE ##
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
+
+
 def embedding(
     model: str,
     input: list,
     model_response: litellm.EmbeddingResponse,
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    encoding: Any,
     api_key: Optional[str] = None,
-    logging_obj=None,
-    encoding=None,
-    optional_params=None,
+    aembedding: Optional[bool] = None,
+    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     headers = validate_environment(api_key)
     embed_url = "https://api.cohere.ai/v1/embed"
@@ -273,8 +359,26 @@ def embedding(
         api_key=api_key,
         additional_args={"complete_input_dict": data},
     )
+
+    ## ROUTING
+    if aembedding is True:
+        return async_embedding(
+            model=model,
+            data=data,
+            input=input,
+            model_response=model_response,
+            timeout=timeout,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_base=embed_url,
+            api_key=api_key,
+            headers=headers,
+            encoding=encoding,
+        )
     ## COMPLETION CALL
-    response = requests.post(embed_url, headers=headers, data=json.dumps(data))
+    if client is None or not isinstance(client, HTTPHandler):
+        client = HTTPHandler(concurrent_limit=1)
+    response = client.post(embed_url, headers=headers, data=json.dumps(data))
     ## LOGGING
     logging_obj.post_call(
         input=input,
@@ -296,23 +400,11 @@
     if response.status_code != 200:
         raise CohereError(message=response.text, status_code=response.status_code)
     embeddings = response.json()["embeddings"]
-    output_data = []
-    for idx, embedding in enumerate(embeddings):
-        output_data.append(
-            {"object": "embedding", "index": idx, "embedding": embedding}
-        )
-    model_response.object = "list"
-    model_response.data = output_data
-    model_response.model = model
-    input_tokens = 0
-    for text in input:
-        input_tokens += len(encoding.encode(text))
-
-    setattr(
-        model_response,
-        "usage",
-        Usage(
-            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-        ),
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
     )
-    return model_response
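Note the shape of the routing above: when aembedding is true, the synchronous embedding() returns the un-awaited coroutine from async_embedding(), and the async caller is expected to await it. A minimal, self-contained sketch of that dispatch pattern (all names here are illustrative, not the library's actual code):

import asyncio
from typing import Any, Coroutine, Optional, Union

async def _work_async(data: dict) -> dict:
    # stand-in for async_embedding(): performs the non-blocking call
    await asyncio.sleep(0)
    return {"path": "async", **data}

def work(
    data: dict, awork: Optional[bool] = None
) -> Union[dict, Coroutine[Any, Any, dict]]:
    # mirrors the ## ROUTING block: hand back the un-awaited coroutine
    # so an async caller can await it; sync callers never set the flag
    if awork is True:
        return _work_async(data)
    return {"path": "sync", **data}

print(work({"n": 1})["path"])                           # -> sync
print(asyncio.run(work({"n": 1}, awork=True))["path"])  # -> async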
6 changes: 5 additions & 1 deletion litellm/main.py
@@ -3114,6 +3114,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "databricks"
or custom_llm_provider == "watsonx"
or custom_llm_provider == "cohere"
or custom_llm_provider == "huggingface"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
@@ -3441,9 +3442,12 @@ def embedding(
             input=input,
             optional_params=optional_params,
             encoding=encoding,
-            api_key=cohere_key,
+            api_key=cohere_key,  # type: ignore
             logging_obj=logging,
             model_response=EmbeddingResponse(),
+            aembedding=aembedding,
+            timeout=float(timeout),
+            client=client,
         )
     elif custom_llm_provider == "huggingface":
         api_key = (
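The single added line in aembedding's provider gate is what lets the coroutine returned by the Cohere embedding() path be awaited directly. A simplified sketch of that gate, assuming (based on the "Await normally" comment above) that providers outside the list fall back to a thread-pool call; the real function handles far more providers and plumbing:

import asyncio
import functools

# illustrative subset of the providers gated in main.py
ASYNC_CAPABLE = {"vertex_ai", "databricks", "watsonx", "cohere", "huggingface"}

async def aembedding_sketch(custom_llm_provider: str, sync_embedding, **kwargs):
    if custom_llm_provider in ASYNC_CAPABLE:
        # "Await normally": embedding() returns a coroutine when
        # called with aembedding=True
        return await sync_embedding(aembedding=True, **kwargs)
    # assumed fallback: run the blocking call in the default executor
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, functools.partial(sync_embedding, **kwargs)
    )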
18 changes: 12 additions & 6 deletions litellm/tests/test_embedding.py
@@ -257,14 +257,20 @@ def test_openai_azure_embedding_optional_arg(mocker):
 # test_openai_embedding()


-def test_cohere_embedding():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_cohere_embedding(sync_mode):
     try:
         # litellm.set_verbose=True
-        response = embedding(
-            model="embed-english-v2.0",
-            input=["good morning from litellm", "this is another item"],
-            input_type="search_query",
-        )
+        data = {
+            "model": "embed-english-v2.0",
+            "input": ["good morning from litellm", "this is another item"],
+            "input_type": "search_query",
+        }
+        if sync_mode:
+            response = embedding(**data)
+        else:
+            response = await litellm.aembedding(**data)
         print(f"response:", response)

         assert isinstance(response.usage, litellm.Usage)
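The reworked test uses pytest.mark.parametrize to run one coroutine in both modes, with pytest-asyncio driving the event loop. A self-contained illustration of the same pattern with stand-in embedding functions (requires the pytest-asyncio plugin; every name here is hypothetical):

import pytest

def embed_sync(texts: list) -> list:
    # stand-in for litellm.embedding
    return [len(t) for t in texts]

async def embed_async(texts: list) -> list:
    # stand-in for litellm.aembedding
    return [len(t) for t in texts]

@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_embed_both_modes(sync_mode):
    texts = ["good morning from litellm", "this is another item"]
    if sync_mode:
        result = embed_sync(texts)
    else:
        result = await embed_async(texts)
    assert len(result) == len(texts)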
