diff --git a/litellm/llms/cohere.py b/litellm/llms/cohere.py
index d946a8ddeef5..f3c2770b3bfb 100644
--- a/litellm/llms/cohere.py
+++ b/litellm/llms/cohere.py
@@ -4,12 +4,14 @@
 import traceback
 import types
 from enum import Enum
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, Union
 
 import httpx  # type: ignore
 import requests  # type: ignore
 
 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import Choices, Message, ModelResponse, Usage
 
 
@@ -246,14 +248,98 @@ def completion(
     return model_response
 
 
+def _process_embedding_response(
+    embeddings: list,
+    model_response: litellm.EmbeddingResponse,
+    model: str,
+    encoding: Any,
+    input: list,
+) -> litellm.EmbeddingResponse:
+    output_data = []
+    for idx, embedding in enumerate(embeddings):
+        output_data.append(
+            {"object": "embedding", "index": idx, "embedding": embedding}
+        )
+    model_response.object = "list"
+    model_response.data = output_data
+    model_response.model = model
+    input_tokens = 0
+    for text in input:
+        input_tokens += len(encoding.encode(text))
+
+    setattr(
+        model_response,
+        "usage",
+        Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        ),
+    )
+
+    return model_response
+
+
+async def async_embedding(
+    model: str,
+    data: dict,
+    input: list,
+    model_response: litellm.utils.EmbeddingResponse,
+    timeout: Union[float, httpx.Timeout],
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    api_base: str,
+    api_key: Optional[str],
+    headers: dict,
+    encoding: Callable,
+    client: Optional[AsyncHTTPHandler] = None,
+):
+
+    ## LOGGING
+    logging_obj.pre_call(
+        input=input,
+        api_key=api_key,
+        additional_args={
+            "complete_input_dict": data,
+            "headers": headers,
+            "api_base": api_base,
+        },
+    )
+    ## COMPLETION CALL
+    if client is None:
+        client = AsyncHTTPHandler(concurrent_limit=1)
+
+    response = await client.post(api_base, headers=headers, data=json.dumps(data))
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=input,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+        original_response=response,
+    )
+
+    embeddings = response.json()["embeddings"]
+
+    ## PROCESS RESPONSE ##
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
+    )
+
+
 def embedding(
     model: str,
     input: list,
     model_response: litellm.EmbeddingResponse,
+    logging_obj: LiteLLMLoggingObj,
+    optional_params: dict,
+    encoding: Any,
     api_key: Optional[str] = None,
-    logging_obj=None,
-    encoding=None,
-    optional_params=None,
+    aembedding: Optional[bool] = None,
+    timeout: Union[float, httpx.Timeout] = httpx.Timeout(None),
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
 ):
     headers = validate_environment(api_key)
     embed_url = "https://api.cohere.ai/v1/embed"
@@ -270,8 +356,26 @@ def embedding(
         api_key=api_key,
         additional_args={"complete_input_dict": data},
     )
+
+    ## ROUTING
+    if aembedding is True:
+        return async_embedding(
+            model=model,
+            data=data,
+            input=input,
+            model_response=model_response,
+            timeout=timeout,
+            logging_obj=logging_obj,
+            optional_params=optional_params,
+            api_base=embed_url,
+            api_key=api_key,
+            headers=headers,
+            encoding=encoding,
+        )
     ## COMPLETION CALL
-    response = requests.post(embed_url, headers=headers, data=json.dumps(data))
+    if client is None or not isinstance(client, HTTPHandler):
+        client = HTTPHandler(concurrent_limit=1)
+    response = client.post(embed_url, headers=headers, data=json.dumps(data))
     ## LOGGING
     logging_obj.post_call(
         input=input,
@@ -293,23 +397,11 @@ def embedding(
     if response.status_code != 200:
         raise CohereError(message=response.text, status_code=response.status_code)
     embeddings = response.json()["embeddings"]
-    output_data = []
-    for idx, embedding in enumerate(embeddings):
-        output_data.append(
-            {"object": "embedding", "index": idx, "embedding": embedding}
-        )
-    model_response.object = "list"
-    model_response.data = output_data
-    model_response.model = model
-    input_tokens = 0
-    for text in input:
-        input_tokens += len(encoding.encode(text))
-    setattr(
-        model_response,
-        "usage",
-        Usage(
-            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
-        ),
+    return _process_embedding_response(
+        embeddings=embeddings,
+        model_response=model_response,
+        model=model,
+        encoding=encoding,
+        input=input,
     )
-    return model_response
 
diff --git a/litellm/main.py b/litellm/main.py
index 3a52ae29b0b5..528cbf071048 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3114,6 +3114,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
         or custom_llm_provider == "vertex_ai"
         or custom_llm_provider == "databricks"
         or custom_llm_provider == "watsonx"
+        or custom_llm_provider == "cohere"
     ):  # currently implemented aiohttp calls for just azure and openai, soon all.
         # Await normally
         init_response = await loop.run_in_executor(None, func_with_context)
@@ -3440,9 +3441,12 @@ def embedding(
             input=input,
             optional_params=optional_params,
             encoding=encoding,
-            api_key=cohere_key,
+            api_key=cohere_key,  # type: ignore
             logging_obj=logging,
             model_response=EmbeddingResponse(),
+            aembedding=aembedding,
+            timeout=float(timeout),
+            client=client,
         )
     elif custom_llm_provider == "huggingface":
         api_key = (
diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py
index 79ba8bc3ee6b..c44967a9abb2 100644
--- a/litellm/tests/test_embedding.py
+++ b/litellm/tests/test_embedding.py
@@ -257,14 +257,20 @@ def test_openai_azure_embedding_optional_arg(mocker):
 # test_openai_embedding()
 
 
-def test_cohere_embedding():
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio
+async def test_cohere_embedding(sync_mode):
     try:
         # litellm.set_verbose=True
-        response = embedding(
-            model="embed-english-v2.0",
-            input=["good morning from litellm", "this is another item"],
-            input_type="search_query",
-        )
+        data = {
+            "model": "embed-english-v2.0",
+            "input": ["good morning from litellm", "this is another item"],
+            "input_type": "search_query",
+        }
+        if sync_mode:
+            response = embedding(**data)
+        else:
+            response = await litellm.aembedding(**data)
         print(f"response:", response)
         assert isinstance(response.usage, litellm.Usage)
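
Usage sketch (not part of the patch): the parametrized test above exercises both code paths, and the snippet below shows the same calls in isolation. It assumes a valid COHERE_API_KEY is set in the environment; input_type is a Cohere-specific kwarg that litellm forwards through optional_params.

import asyncio

import litellm
from litellm import embedding

# Sync path: with this change, cohere.embedding() builds an HTTPHandler
# when no client is supplied, instead of calling requests.post directly.
response = embedding(
    model="embed-english-v2.0",
    input=["good morning from litellm", "this is another item"],
    input_type="search_query",
)
print(response.usage)

# Async path: litellm.aembedding() passes aembedding=True down to
# cohere.embedding(), which returns the async_embedding() coroutine;
# the request is then made through an AsyncHTTPHandler and awaited.
async def main():
    response = await litellm.aembedding(
        model="embed-english-v2.0",
        input=["good morning from litellm", "this is another item"],
        input_type="search_query",
    )
    print(response.usage)

asyncio.run(main())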