Add cohere /v2/chat pass-through cost tracking support #8235

Open · wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -72,3 +72,4 @@ tests/local_testing/log.txt
.codegpt
litellm/proxy/_new_new_secret_config.yaml
litellm/proxy/custom_guardrail.py
.mypy_cache/*
357 changes: 357 additions & 0 deletions litellm/llms/cohere/chat/v2_transformation.py
@@ -0,0 +1,357 @@
import time
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.cohere import CohereV2ChatResponse
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
from litellm.types.utils import ModelResponse, Usage

from ..common_utils import CohereError
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment

if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any


class CohereV2ChatConfig(BaseConfig):
"""
    Configuration class for Cohere's /v2/chat API interface.

Args:
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
generation_id (str, optional): Unique identifier for the generated reply.
response_id (str, optional): Unique identifier for the response.
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
seed (int, optional): A seed to assist reproducibility of the model's response.
"""

preamble: Optional[str] = None
chat_history: Optional[list] = None
generation_id: Optional[str] = None
response_id: Optional[str] = None
conversation_id: Optional[str] = None
prompt_truncation: Optional[str] = None
connectors: Optional[list] = None
search_queries_only: Optional[bool] = None
documents: Optional[list] = None
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    k: Optional[int] = None
    p: Optional[float] = None
    frequency_penalty: Optional[float] = None
    presence_penalty: Optional[float] = None
tools: Optional[list] = None
tool_results: Optional[list] = None
seed: Optional[int] = None

def __init__(
self,
preamble: Optional[str] = None,
chat_history: Optional[list] = None,
generation_id: Optional[str] = None,
response_id: Optional[str] = None,
conversation_id: Optional[str] = None,
prompt_truncation: Optional[str] = None,
connectors: Optional[list] = None,
search_queries_only: Optional[bool] = None,
documents: Optional[list] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        k: Optional[int] = None,
        p: Optional[float] = None,
        frequency_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
tools: Optional[list] = None,
tool_results: Optional[list] = None,
seed: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
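
        # Illustrative usage (values invented): constructor kwargs are stored as
        # class-level attributes via setattr(self.__class__, ...), so they act as
        # shared defaults that get_config() can pick up later:
        #   CohereV2ChatConfig(temperature=0.7, max_tokens=256)
        #   assert CohereV2ChatConfig.temperature == 0.7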

def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return cohere_validate_environment(
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
api_key=api_key,
)

def get_supported_openai_params(self, model: str) -> List[str]:
return [
"stream",
"temperature",
"max_tokens",
"top_p",
"frequency_penalty",
"presence_penalty",
"stop",
"n",
"tools",
"tool_choice",
"seed",
"extra_headers",
]

def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
for param, value in non_default_params.items():
if param == "stream":
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "n":
optional_params["num_generations"] = value
if param == "top_p":
optional_params["p"] = value
if param == "frequency_penalty":
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "tools":
optional_params["tools"] = value
if param == "seed":
optional_params["seed"] = value
return optional_params
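
        # Illustrative mapping (example values invented): OpenAI-style params are
        # renamed to their Cohere equivalents; unrecognized params are dropped:
        #   map_openai_params(
        #       non_default_params={"top_p": 0.9, "n": 2, "stop": ["\n"]},
        #       optional_params={},
        #       model="command-r",
        #       drop_params=False,
        #   )
        #   -> {"p": 0.9, "num_generations": 2, "stop_sequences": ["\n"]}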

def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:

## Load Config
for k, v in litellm.CohereChatConfig.get_config().items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v

most_recent_message, chat_history = cohere_messages_pt_v2(
messages=messages, model=model, llm_provider="cohere_chat"
)

## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
optional_params["tools"] = cohere_tools
if isinstance(most_recent_message, dict):
optional_params["tool_results"] = [most_recent_message]
elif isinstance(most_recent_message, str):
optional_params["message"] = most_recent_message

        ## if the last chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else the cohere api fails
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
optional_params["force_single_step"] = True

return optional_params
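
        # Illustrative result (assumes cohere_messages_pt_v2 returns the latest
        # user turn as a plain string and that no extra defaults are set on
        # litellm.CohereChatConfig; values invented):
        #   transform_request(
        #       model="command-r",
        #       messages=[{"role": "user", "content": "Hi"}],
        #       optional_params={"temperature": 0.5},
        #       litellm_params={},
        #       headers={},
        #   )
        #   -> {"temperature": 0.5, "message": "Hi"}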

def transform_response(
self,
model: str,
raw_response: httpx.Response,
model_response: ModelResponse,
logging_obj: LiteLLMLoggingObj,
request_data: dict,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
encoding: Any,
api_key: Optional[str] = None,
json_mode: Optional[bool] = None,
) -> ModelResponse:

try:
raw_response_json = raw_response.json()
except Exception:
raise CohereError(
message=raw_response.text, status_code=raw_response.status_code
)

try:
cohere_v2_chat_response = CohereV2ChatResponse(**raw_response_json) # type: ignore
except Exception:
raise CohereError(message=raw_response.text, status_code=422)

cohere_content = cohere_v2_chat_response["message"].get("content", None)
if cohere_content is not None:
model_response.choices[0].message.content = "".join( # type: ignore
[
content.get("text", "")
for content in cohere_content
if content is not None
]
)

## ADD CITATIONS
if "citations" in cohere_v2_chat_response:
setattr(model_response, "citations", cohere_v2_chat_response["citations"])

## Tool calling response
cohere_tools_response = cohere_v2_chat_response["message"].get("tool_calls", [])
if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format
tool_calls: List[ChatCompletionToolCallChunk] = []
for index, tool in enumerate(cohere_tools_response):
tool_call: ChatCompletionToolCallChunk = {
**tool, # type: ignore
"index": index,
}
tool_calls.append(tool_call)
_message = litellm.Message(
tool_calls=tool_calls,
content=None,
)
model_response.choices[0].message = _message # type: ignore

        ## CALCULATING USAGE - use cohere `usage.tokens` for returning usage
token_usage = cohere_v2_chat_response["usage"].get("tokens", {})
prompt_tokens = token_usage.get("input_tokens", 0)
completion_tokens = token_usage.get("output_tokens", 0)

model_response.created = int(time.time())
model_response.model = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
setattr(model_response, "usage", usage)
return model_response
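
        # Illustrative parse (field values invented): a raw /v2/chat body like
        #   {"message": {"content": [{"type": "text", "text": "Hello"}]},
        #    "usage": {"tokens": {"input_tokens": 10, "output_tokens": 3}}}
        # yields choices[0].message.content == "Hello" and
        # usage == Usage(prompt_tokens=10, completion_tokens=3, total_tokens=13).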

def _construct_cohere_tool(
self,
tools: Optional[list] = None,
):
if tools is None:
tools = []
cohere_tools = []
for tool in tools:
cohere_tool = self._translate_openai_tool_to_cohere(tool)
cohere_tools.append(cohere_tool)
return cohere_tools

def _translate_openai_tool_to_cohere(
self,
openai_tool: dict,
):
# cohere tools look like this
"""
{
"name": "query_daily_sales_report",
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
"parameter_definitions": {
"day": {
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
"type": "str",
"required": True
}
}
}
"""

# OpenAI tools look like this
"""
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
"""
cohere_tool = {
"name": openai_tool["function"]["name"],
"description": openai_tool["function"]["description"],
"parameter_definitions": {},
}

        required_params = (
            openai_tool.get("function", {})
            .get("parameters", {})
            .get("required", [])
        )
        for param_name, param_def in openai_tool["function"]["parameters"][
            "properties"
        ].items():
            cohere_param_def = {
                "description": param_def.get("description", ""),
                "type": param_def.get("type", ""),
                "required": param_name in required_params,
            }
            cohere_tool["parameter_definitions"][param_name] = cohere_param_def

return cohere_tool
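
        # Worked example: running the OpenAI weather tool from the docstring above
        # through this method produces (note the "enum" constraint is dropped):
        #   {
        #       "name": "get_current_weather",
        #       "description": "Get the current weather in a given location",
        #       "parameter_definitions": {
        #           "location": {
        #               "description": "The city and state, e.g. San Francisco, CA",
        #               "type": "string",
        #               "required": True,
        #           },
        #           "unit": {"description": "", "type": "string", "required": False},
        #       },
        #   }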

def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return CohereModelResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)

def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return CohereError(status_code=status_code, message=error_message)