Skip to content

Commit bc8cc4f

Browse files
Merge pull request #4977 from BerriAI/litellm_async_cohere_calls
fix(cohere.py): support async cohere embedding calls
2 parents 3a70849 + 653aefd commit bc8cc4f

File tree

3 files changed

+132
-30
lines changed

3 files changed

+132
-30
lines changed

litellm/llms/cohere.py

Lines changed: 115 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
import traceback
88
import types
99
from enum import Enum
10-
from typing import Callable, Optional
10+
from typing import Any, Callable, Optional, Union
1111

1212
import httpx # type: ignore
1313
import requests # type: ignore
1414

1515
import litellm
16+
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
17+
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
1618
from litellm.utils import Choices, Message, ModelResponse, Usage
1719

1820

@@ -249,14 +251,98 @@ def completion(
249251
return model_response
250252

251253

254+
def _process_embedding_response(
255+
embeddings: list,
256+
model_response: litellm.EmbeddingResponse,
257+
model: str,
258+
encoding: Any,
259+
input: list,
260+
) -> litellm.EmbeddingResponse:
261+
output_data = []
262+
for idx, embedding in enumerate(embeddings):
263+
output_data.append(
264+
{"object": "embedding", "index": idx, "embedding": embedding}
265+
)
266+
model_response.object = "list"
267+
model_response.data = output_data
268+
model_response.model = model
269+
input_tokens = 0
270+
for text in input:
271+
input_tokens += len(encoding.encode(text))
272+
273+
setattr(
274+
model_response,
275+
"usage",
276+
Usage(
277+
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
278+
),
279+
)
280+
281+
return model_response
282+
283+
284+
async def async_embedding(
285+
model: str,
286+
data: dict,
287+
input: list,
288+
model_response: litellm.utils.EmbeddingResponse,
289+
timeout: Union[float, httpx.Timeout],
290+
logging_obj: LiteLLMLoggingObj,
291+
optional_params: dict,
292+
api_base: str,
293+
api_key: Optional[str],
294+
headers: dict,
295+
encoding: Callable,
296+
client: Optional[AsyncHTTPHandler] = None,
297+
):
298+
299+
## LOGGING
300+
logging_obj.pre_call(
301+
input=input,
302+
api_key=api_key,
303+
additional_args={
304+
"complete_input_dict": data,
305+
"headers": headers,
306+
"api_base": api_base,
307+
},
308+
)
309+
## COMPLETION CALL
310+
if client is None:
311+
client = AsyncHTTPHandler(concurrent_limit=1)
312+
313+
response = await client.post(api_base, headers=headers, data=json.dumps(data))
314+
315+
## LOGGING
316+
logging_obj.post_call(
317+
input=input,
318+
api_key=api_key,
319+
additional_args={"complete_input_dict": data},
320+
original_response=response,
321+
)
322+
323+
embeddings = response.json()["embeddings"]
324+
325+
## PROCESS RESPONSE ##
326+
return _process_embedding_response(
327+
embeddings=embeddings,
328+
model_response=model_response,
329+
model=model,
330+
encoding=encoding,
331+
input=input,
332+
)
333+
334+
252335
def embedding(
253336
model: str,
254337
input: list,
255338
model_response: litellm.EmbeddingResponse,
339+
logging_obj: LiteLLMLoggingObj,
340+
optional_params: dict,
341+
encoding: Any,
256342
api_key: Optional[str] = None,
257-
logging_obj=None,
258-
encoding=None,
259-
optional_params=None,
343+
aembedding: Optional[bool] = None,
344+
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
345+
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
260346
):
261347
headers = validate_environment(api_key)
262348
embed_url = "https://api.cohere.ai/v1/embed"
@@ -273,8 +359,26 @@ def embedding(
273359
api_key=api_key,
274360
additional_args={"complete_input_dict": data},
275361
)
362+
363+
## ROUTING
364+
if aembedding is True:
365+
return async_embedding(
366+
model=model,
367+
data=data,
368+
input=input,
369+
model_response=model_response,
370+
timeout=timeout,
371+
logging_obj=logging_obj,
372+
optional_params=optional_params,
373+
api_base=embed_url,
374+
api_key=api_key,
375+
headers=headers,
376+
encoding=encoding,
377+
)
276378
## COMPLETION CALL
277-
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
379+
if client is None or not isinstance(client, HTTPHandler):
380+
client = HTTPHandler(concurrent_limit=1)
381+
response = client.post(embed_url, headers=headers, data=json.dumps(data))
278382
## LOGGING
279383
logging_obj.post_call(
280384
input=input,
@@ -296,23 +400,11 @@ def embedding(
296400
if response.status_code != 200:
297401
raise CohereError(message=response.text, status_code=response.status_code)
298402
embeddings = response.json()["embeddings"]
299-
output_data = []
300-
for idx, embedding in enumerate(embeddings):
301-
output_data.append(
302-
{"object": "embedding", "index": idx, "embedding": embedding}
303-
)
304-
model_response.object = "list"
305-
model_response.data = output_data
306-
model_response.model = model
307-
input_tokens = 0
308-
for text in input:
309-
input_tokens += len(encoding.encode(text))
310403

311-
setattr(
312-
model_response,
313-
"usage",
314-
Usage(
315-
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
316-
),
404+
return _process_embedding_response(
405+
embeddings=embeddings,
406+
model_response=model_response,
407+
model=model,
408+
encoding=encoding,
409+
input=input,
317410
)
318-
return model_response

litellm/main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3114,6 +3114,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
31143114
or custom_llm_provider == "vertex_ai"
31153115
or custom_llm_provider == "databricks"
31163116
or custom_llm_provider == "watsonx"
3117+
or custom_llm_provider == "cohere"
31173118
or custom_llm_provider == "huggingface"
31183119
): # currently implemented aiohttp calls for just azure and openai, soon all.
31193120
# Await normally
@@ -3441,9 +3442,12 @@ def embedding(
34413442
input=input,
34423443
optional_params=optional_params,
34433444
encoding=encoding,
3444-
api_key=cohere_key,
3445+
api_key=cohere_key, # type: ignore
34453446
logging_obj=logging,
34463447
model_response=EmbeddingResponse(),
3448+
aembedding=aembedding,
3449+
timeout=float(timeout),
3450+
client=client,
34473451
)
34483452
elif custom_llm_provider == "huggingface":
34493453
api_key = (

litellm/tests/test_embedding.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,20 @@ def test_openai_azure_embedding_optional_arg(mocker):
257257
# test_openai_embedding()
258258

259259

260-
def test_cohere_embedding():
260+
@pytest.mark.parametrize("sync_mode", [True, False])
261+
@pytest.mark.asyncio
262+
async def test_cohere_embedding(sync_mode):
261263
try:
262264
# litellm.set_verbose=True
263-
response = embedding(
264-
model="embed-english-v2.0",
265-
input=["good morning from litellm", "this is another item"],
266-
input_type="search_query",
267-
)
265+
data = {
266+
"model": "embed-english-v2.0",
267+
"input": ["good morning from litellm", "this is another item"],
268+
"input_type": "search_query",
269+
}
270+
if sync_mode:
271+
response = embedding(**data)
272+
else:
273+
response = await litellm.aembedding(**data)
268274
print(f"response:", response)
269275

270276
assert isinstance(response.usage, litellm.Usage)

0 commit comments

Comments
 (0)