Skip to content

Commit 9b2eb17

Browse files
fix(cohere.py): support async cohere embedding calls
1 parent 185a685 commit 9b2eb17

File tree

3 files changed

+132
-30
lines changed

3 files changed

+132
-30
lines changed

litellm/llms/cohere.py

Lines changed: 115 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import traceback
55
import types
66
from enum import Enum
7-
from typing import Callable, Optional
7+
from typing import Any, Callable, Optional, Union
88

99
import httpx # type: ignore
1010
import requests # type: ignore
1111

1212
import litellm
13+
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
14+
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
1315
from litellm.utils import Choices, Message, ModelResponse, Usage
1416

1517

@@ -246,14 +248,98 @@ def completion(
246248
return model_response
247249

248250

251+
def _process_embedding_response(
    embeddings: list,
    model_response: litellm.EmbeddingResponse,
    model: str,
    encoding: Any,
    input: list,
) -> litellm.EmbeddingResponse:
    """Populate *model_response* with Cohere embeddings in OpenAI list format.

    Args:
        embeddings: Raw embedding vectors returned by the Cohere API.
        model_response: Response object mutated in place (and also returned).
        model: Model name recorded on the response.
        encoding: Tokenizer-like object with ``encode``; used for usage accounting.
        input: The original input texts, used only to count prompt tokens.

    Returns:
        The same ``model_response`` with ``object``, ``data``, ``model`` and
        ``usage`` filled in.
    """
    model_response.object = "list"
    model_response.data = [
        {"object": "embedding", "index": idx, "embedding": vector}
        for idx, vector in enumerate(embeddings)
    ]
    model_response.model = model

    # Cohere does not report token usage for embeddings, so approximate the
    # prompt token count locally via the provided encoding.
    input_tokens = sum(len(encoding.encode(text)) for text in input)
    setattr(
        model_response,
        "usage",
        Usage(
            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
        ),
    )

    return model_response
279+
280+
281+
async def async_embedding(
    model: str,
    data: dict,
    input: list,
    model_response: litellm.utils.EmbeddingResponse,
    timeout: Union[float, httpx.Timeout],
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    api_base: str,
    api_key: Optional[str],
    headers: dict,
    encoding: Callable,
    client: Optional[AsyncHTTPHandler] = None,
):
    """Async counterpart of the sync Cohere ``embedding`` path.

    Logs the pre-call, POSTs *data* to *api_base*, logs the raw response,
    and converts the returned ``embeddings`` field into an OpenAI-style
    ``EmbeddingResponse`` via ``_process_embedding_response``.

    Args:
        model: Cohere embedding model name.
        data: JSON-serializable request payload.
        input: Original input texts (used for logging and token counting).
        model_response: Response object to populate.
        timeout: Request timeout (currently unused here; kept for parity
            with the sync signature).
        logging_obj: LiteLLM logging helper for pre/post call hooks.
        optional_params: Extra provider params (kept for signature parity).
        api_base: Full Cohere embed endpoint URL.
        api_key: Cohere API key, for logging only (auth is in *headers*).
        headers: Pre-built request headers, including authorization.
        encoding: Tokenizer used to approximate prompt token usage.
        client: Optional pre-configured async HTTP client to reuse.

    Raises:
        CohereError: If the API responds with a non-200 status code.
    """

    ## LOGGING
    logging_obj.pre_call(
        input=input,
        api_key=api_key,
        additional_args={
            "complete_input_dict": data,
            "headers": headers,
            "api_base": api_base,
        },
    )
    ## COMPLETION CALL
    if client is None:
        client = AsyncHTTPHandler(concurrent_limit=1)

    response = await client.post(api_base, headers=headers, data=json.dumps(data))

    ## LOGGING
    logging_obj.post_call(
        input=input,
        api_key=api_key,
        additional_args={"complete_input_dict": data},
        original_response=response,
    )

    # Surface API failures the same way the sync path does, instead of
    # failing later with a KeyError on the missing "embeddings" field.
    if response.status_code != 200:
        raise CohereError(message=response.text, status_code=response.status_code)

    embeddings = response.json()["embeddings"]

    ## PROCESS RESPONSE ##
    return _process_embedding_response(
        embeddings=embeddings,
        model_response=model_response,
        model=model,
        encoding=encoding,
        input=input,
    )
330+
331+
249332
def embedding(
250333
model: str,
251334
input: list,
252335
model_response: litellm.EmbeddingResponse,
336+
logging_obj: LiteLLMLoggingObj,
337+
optional_params: dict,
338+
encoding: Any,
253339
api_key: Optional[str] = None,
254-
logging_obj=None,
255-
encoding=None,
256-
optional_params=None,
340+
aembedding: Optional[bool] = None,
341+
timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
342+
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
257343
):
258344
headers = validate_environment(api_key)
259345
embed_url = "https://api.cohere.ai/v1/embed"
@@ -270,8 +356,26 @@ def embedding(
270356
api_key=api_key,
271357
additional_args={"complete_input_dict": data},
272358
)
359+
360+
## ROUTING
361+
if aembedding is True:
362+
return async_embedding(
363+
model=model,
364+
data=data,
365+
input=input,
366+
model_response=model_response,
367+
timeout=timeout,
368+
logging_obj=logging_obj,
369+
optional_params=optional_params,
370+
api_base=embed_url,
371+
api_key=api_key,
372+
headers=headers,
373+
encoding=encoding,
374+
)
273375
## COMPLETION CALL
274-
response = requests.post(embed_url, headers=headers, data=json.dumps(data))
376+
if client is None or not isinstance(client, HTTPHandler):
377+
client = HTTPHandler(concurrent_limit=1)
378+
response = client.post(embed_url, headers=headers, data=json.dumps(data))
275379
## LOGGING
276380
logging_obj.post_call(
277381
input=input,
@@ -293,23 +397,11 @@ def embedding(
293397
if response.status_code != 200:
294398
raise CohereError(message=response.text, status_code=response.status_code)
295399
embeddings = response.json()["embeddings"]
296-
output_data = []
297-
for idx, embedding in enumerate(embeddings):
298-
output_data.append(
299-
{"object": "embedding", "index": idx, "embedding": embedding}
300-
)
301-
model_response.object = "list"
302-
model_response.data = output_data
303-
model_response.model = model
304-
input_tokens = 0
305-
for text in input:
306-
input_tokens += len(encoding.encode(text))
307400

308-
setattr(
309-
model_response,
310-
"usage",
311-
Usage(
312-
prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
313-
),
401+
return _process_embedding_response(
402+
embeddings=embeddings,
403+
model_response=model_response,
404+
model=model,
405+
encoding=encoding,
406+
input=input,
314407
)
315-
return model_response

litellm/main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3114,6 +3114,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
31143114
or custom_llm_provider == "vertex_ai"
31153115
or custom_llm_provider == "databricks"
31163116
or custom_llm_provider == "watsonx"
3117+
or custom_llm_provider == "cohere"
31173118
): # currently implemented aiohttp calls for just azure and openai, soon all.
31183119
# Await normally
31193120
init_response = await loop.run_in_executor(None, func_with_context)
@@ -3440,9 +3441,12 @@ def embedding(
34403441
input=input,
34413442
optional_params=optional_params,
34423443
encoding=encoding,
3443-
api_key=cohere_key,
3444+
api_key=cohere_key, # type: ignore
34443445
logging_obj=logging,
34453446
model_response=EmbeddingResponse(),
3447+
aembedding=aembedding,
3448+
timeout=float(timeout),
3449+
client=client,
34463450
)
34473451
elif custom_llm_provider == "huggingface":
34483452
api_key = (

litellm/tests/test_embedding.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,20 @@ def test_openai_azure_embedding_optional_arg(mocker):
257257
# test_openai_embedding()
258258

259259

260-
def test_cohere_embedding():
260+
@pytest.mark.parametrize("sync_mode", [True, False])
261+
@pytest.mark.asyncio
262+
async def test_cohere_embedding(sync_mode):
261263
try:
262264
# litellm.set_verbose=True
263-
response = embedding(
264-
model="embed-english-v2.0",
265-
input=["good morning from litellm", "this is another item"],
266-
input_type="search_query",
267-
)
265+
data = {
266+
"model": "embed-english-v2.0",
267+
"input": ["good morning from litellm", "this is another item"],
268+
"input_type": "search_query",
269+
}
270+
if sync_mode:
271+
response = embedding(**data)
272+
else:
273+
response = await litellm.aembedding(**data)
268274
print(f"response:", response)
269275

270276
assert isinstance(response.usage, litellm.Usage)

0 commit comments

Comments
 (0)