Tool calling support. Part I #102
Merged
+5,985
−3,432
12 commits
4ef37c9 Tool calling support. Part I (rohitprasad15)
6163974 Adding tool calling support for Azure Inference API. (rohitprasad15)
db67176 Few Refactorings. (rohitprasad15)
d42d753 Gemini support. Azure code-changes. (rohitprasad15)
33fda16 Unit Tests for message conversion. (rohitprasad15)
4df5bf7 Add support for remaining providers including HF. (rohitprasad15)
d9b7485 Update poetry.lock (rohitprasad15)
5b547b8 Fix tests broken due to pydantic checks. (rohitprasad15)
70ae662 Fix __init__.py (rohitprasad15)
1a7be38 Removed debug prints, and other cleanup. (rohitprasad15)
ca67c72 Tool calling support in xAI, Mistral, Together, Cohere, Groq, Sambanova. (rohitprasad15)
2d5b919 Skipping a test due to incompatible version deps. (rohitprasad15)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .client import Client | ||
from .framework.message import Message | ||
from .utils.tool_manager import Tools |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .provider_interface import ProviderInterface | ||
from .chat_completion_response import ChatCompletionResponse | ||
from .message import Message |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
from aisuite.framework.message import Message | ||
from typing import Literal, Optional | ||
|
||
|
||
class Choice: | ||
def __init__(self): | ||
self.message = Message() | ||
self.finish_reason: Optional[Literal["stop", "tool_calls"]] = None | ||
self.message = Message( | ||
content=None, tool_calls=None, role="assistant", refusal=None | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,22 @@ | ||
"""Interface to hold contents of api responses when they do not conform to the OpenAI style response""" | ||
|
||
from pydantic import BaseModel | ||
from typing import Literal, Optional | ||
|
||
class Message: | ||
def __init__(self): | ||
self.content = None | ||
|
||
class Function(BaseModel): | ||
arguments: str | ||
name: str | ||
|
||
|
||
class ChatCompletionMessageToolCall(BaseModel): | ||
id: str | ||
function: Function | ||
type: Literal["function"] | ||
|
||
|
||
class Message(BaseModel): | ||
content: Optional[str] | ||
tool_calls: Optional[list[ChatCompletionMessageToolCall]] | ||
role: Optional[Literal["user", "assistant", "system"]] | ||
refusal: Optional[str] |
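
To make the shape of these models concrete, here is a minimal sketch that mirrors them with standard-library dataclasses instead of pydantic. That substitution is an assumption made only so the snippet runs without dependencies; the PR itself uses `BaseModel`. The tool name `get_weather` and the id `call_1` are hypothetical. The point it illustrates is that `Function.arguments` holds a JSON-encoded string, so callers decode it before use.

```python
import json
from dataclasses import dataclass
from typing import List, Literal, Optional


# Dataclass stand-ins for the pydantic models in the diff (assumption:
# used here only to avoid a third-party dependency).
@dataclass
class Function:
    arguments: str  # JSON-encoded string, not a dict
    name: str


@dataclass
class ChatCompletionMessageToolCall:
    id: str
    function: Function
    type: Literal["function"]


@dataclass
class Message:
    content: Optional[str]
    tool_calls: Optional[List[ChatCompletionMessageToolCall]]
    role: Optional[Literal["user", "assistant", "system"]]
    refusal: Optional[str]


# Hypothetical assistant message that requests one tool call.
call = ChatCompletionMessageToolCall(
    id="call_1",
    function=Function(name="get_weather", arguments=json.dumps({"city": "Paris"})),
    type="function",
)
message = Message(content=None, tool_calls=[call], role="assistant", refusal=None)

# The arguments must be decoded from JSON before they can be used.
decoded_args = json.loads(message.tool_calls[0].function.arguments)
```

Unlike the pydantic version, dataclasses do not validate field types at runtime, so this sketch shows structure only.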
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,224 @@ | ||
# Anthropic provider | ||
# Links: | ||
# Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use | ||
|
||
import anthropic | ||
import json | ||
from aisuite.provider import Provider | ||
from aisuite.framework import ChatCompletionResponse | ||
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function | ||
|
||
# Define a constant for the default max_tokens value | ||
DEFAULT_MAX_TOKENS = 4096 | ||
|
||
|
||
class AnthropicProvider(Provider): | ||
def __init__(self, **config): | ||
""" | ||
Initialize the Anthropic provider with the given configuration. | ||
Pass the entire configuration dictionary to the Anthropic client constructor. | ||
""" | ||
class AnthropicMessageConverter: | ||
# Role constants | ||
ROLE_USER = "user" | ||
ROLE_ASSISTANT = "assistant" | ||
ROLE_TOOL = "tool" | ||
ROLE_SYSTEM = "system" | ||
|
||
self.client = anthropic.Anthropic(**config) | ||
# Finish reason mapping | ||
FINISH_REASON_MAPPING = { | ||
"end_turn": "stop", | ||
"max_tokens": "length", | ||
"tool_use": "tool_calls", | ||
} | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
# Check if the first message is a system message | ||
if messages[0]["role"] == "system": | ||
def convert_request(self, messages): | ||
"""Convert framework messages to Anthropic format.""" | ||
system_message = self._extract_system_message(messages) | ||
converted_messages = [self._convert_single_message(msg) for msg in messages] | ||
return system_message, converted_messages | ||
|
||
def convert_response(self, response): | ||
"""Normalize the response from the Anthropic API to match OpenAI's response format.""" | ||
normalized_response = ChatCompletionResponse() | ||
normalized_response.choices[0].finish_reason = self._get_finish_reason(response) | ||
normalized_response.usage = self._get_usage_stats(response) | ||
normalized_response.choices[0].message = self._get_message(response) | ||
return normalized_response | ||
|
||
def _convert_single_message(self, msg): | ||
"""Convert a single message to Anthropic format.""" | ||
if isinstance(msg, dict): | ||
return self._convert_dict_message(msg) | ||
return self._convert_message_object(msg) | ||
|
||
def _convert_dict_message(self, msg): | ||
"""Convert a dictionary message to Anthropic format.""" | ||
if msg["role"] == self.ROLE_TOOL: | ||
return self._create_tool_result_message(msg["tool_call_id"], msg["content"]) | ||
elif msg["role"] == self.ROLE_ASSISTANT and "tool_calls" in msg: | ||
return self._create_assistant_tool_message( | ||
msg["content"], msg["tool_calls"] | ||
) | ||
return {"role": msg["role"], "content": msg["content"]} | ||
|
||
def _convert_message_object(self, msg): | ||
"""Convert a Message object to Anthropic format.""" | ||
if msg.role == self.ROLE_TOOL: | ||
return self._create_tool_result_message(msg.tool_call_id, msg.content) | ||
elif msg.role == self.ROLE_ASSISTANT and msg.tool_calls: | ||
return self._create_assistant_tool_message(msg.content, msg.tool_calls) | ||
return {"role": msg.role, "content": msg.content} | ||
|
||
def _create_tool_result_message(self, tool_call_id, content): | ||
"""Create a tool result message in Anthropic format.""" | ||
return { | ||
"role": self.ROLE_USER, | ||
"content": [ | ||
{ | ||
"type": "tool_result", | ||
"tool_use_id": tool_call_id, | ||
"content": content, | ||
} | ||
], | ||
} | ||
|
||
def _create_assistant_tool_message(self, content, tool_calls): | ||
"""Create an assistant message with tool calls in Anthropic format.""" | ||
message_content = [] | ||
if content: | ||
message_content.append({"type": "text", "text": content}) | ||
|
||
for tool_call in tool_calls: | ||
tool_input = ( | ||
tool_call["function"]["arguments"] | ||
if isinstance(tool_call, dict) | ||
else tool_call.function.arguments | ||
) | ||
message_content.append( | ||
{ | ||
"type": "tool_use", | ||
"id": ( | ||
tool_call["id"] if isinstance(tool_call, dict) else tool_call.id | ||
), | ||
"name": ( | ||
tool_call["function"]["name"] | ||
if isinstance(tool_call, dict) | ||
else tool_call.function.name | ||
), | ||
"input": json.loads(tool_input), | ||
} | ||
) | ||
Comment on lines +87 to +105
Similar thought, converting to either an object or dict at the beginning of this would create a consistency that would not require the |
||
|
||
return {"role": self.ROLE_ASSISTANT, "content": message_content} | ||
|
||
def _extract_system_message(self, messages): | ||
"""Extract system message if present, otherwise return empty list.""" | ||
# TODO: This is a temporary solution to extract the system message. | ||
# User can pass multiple system messages, which can be mingled with other messages. | ||
# This needs to be fixed to handle this case. | ||
if messages and messages[0]["role"] == "system": | ||
system_message = messages[0]["content"] | ||
messages = messages[1:] | ||
else: | ||
system_message = [] | ||
messages.pop(0) | ||
return system_message | ||
return [] | ||
|
||
def _get_finish_reason(self, response): | ||
"""Get the normalized finish reason.""" | ||
return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop") | ||
|
||
# kwargs.setdefault('max_tokens', DEFAULT_MAX_TOKENS) | ||
if "max_tokens" not in kwargs: | ||
kwargs["max_tokens"] = DEFAULT_MAX_TOKENS | ||
def _get_usage_stats(self, response): | ||
"""Get the usage statistics.""" | ||
return { | ||
"prompt_tokens": response.usage.input_tokens, | ||
"completion_tokens": response.usage.output_tokens, | ||
"total_tokens": response.usage.input_tokens + response.usage.output_tokens, | ||
} | ||
|
||
def _get_message(self, response): | ||
"""Get the appropriate message based on response type.""" | ||
if response.stop_reason == "tool_use": | ||
tool_message = self.convert_response_with_tool_use(response) | ||
if tool_message: | ||
return tool_message | ||
|
||
return Message( | ||
content=response.content[0].text, | ||
role="assistant", | ||
tool_calls=None, | ||
refusal=None, | ||
) | ||
|
||
return self.normalize_response( | ||
self.client.messages.create( | ||
model=model, system=system_message, messages=messages, **kwargs | ||
def convert_response_with_tool_use(self, response): | ||
"""Convert Anthropic tool use response to the framework's format.""" | ||
tool_call = next( | ||
(content for content in response.content if content.type == "tool_use"), | ||
None, | ||
) | ||
|
||
if tool_call: | ||
function = Function( | ||
name=tool_call.name, arguments=json.dumps(tool_call.input) | ||
) | ||
tool_call_obj = ChatCompletionMessageToolCall( | ||
id=tool_call.id, function=function, type="function" | ||
) | ||
text_content = next( | ||
( | ||
content.text | ||
for content in response.content | ||
if content.type == "text" | ||
), | ||
"", | ||
) | ||
|
||
return Message( | ||
content=text_content or None, | ||
tool_calls=[tool_call_obj] if tool_call else None, | ||
role="assistant", | ||
refusal=None, | ||
) | ||
return None | ||
|
||
def convert_tool_spec(self, openai_tools): | ||
"""Convert OpenAI tool specification to Anthropic format.""" | ||
anthropic_tools = [] | ||
|
||
for tool in openai_tools: | ||
if tool.get("type") != "function": | ||
continue | ||
|
||
function = tool["function"] | ||
anthropic_tool = { | ||
"name": function["name"], | ||
"description": function["description"], | ||
"input_schema": { | ||
"type": "object", | ||
"properties": function["parameters"]["properties"], | ||
"required": function["parameters"].get("required", []), | ||
}, | ||
} | ||
anthropic_tools.append(anthropic_tool) | ||
|
||
return anthropic_tools | ||
|
||
|
||
class AnthropicProvider(Provider): | ||
def __init__(self, **config): | ||
"""Initialize the Anthropic provider with the given configuration.""" | ||
self.client = anthropic.Anthropic(**config) | ||
self.converter = AnthropicMessageConverter() | ||
|
||
def chat_completions_create(self, model, messages, **kwargs): | ||
"""Create a chat completion using the Anthropic API.""" | ||
kwargs = self._prepare_kwargs(kwargs) | ||
system_message, converted_messages = self.converter.convert_request(messages) | ||
|
||
response = self.client.messages.create( | ||
model=model, system=system_message, messages=converted_messages, **kwargs | ||
) | ||
return self.converter.convert_response(response) | ||
|
||
def normalize_response(self, response): | ||
"""Normalize the response from the Anthropic API to match OpenAI's response format.""" | ||
normalized_response = ChatCompletionResponse() | ||
normalized_response.choices[0].message.content = response.content[0].text | ||
return normalized_response | ||
def _prepare_kwargs(self, kwargs): | ||
"""Prepare kwargs for the API call.""" | ||
kwargs = kwargs.copy() | ||
kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS) | ||
|
||
if "tools" in kwargs: | ||
kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"]) | ||
|
||
return kwargs |
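
As a rough illustration of what `convert_tool_spec` does to the `tools` argument, the sketch below reproduces its logic as a standalone function and runs it on one OpenAI-style tool definition. The `get_weather` tool and its schema are hypothetical examples, not part of the PR.

```python
def convert_tool_spec(openai_tools):
    """Convert OpenAI-style tool specs to Anthropic-style tool specs."""
    anthropic_tools = []
    for tool in openai_tools:
        # Only function tools are translated; other types are skipped.
        if tool.get("type") != "function":
            continue
        function = tool["function"]
        anthropic_tools.append(
            {
                "name": function["name"],
                "description": function["description"],
                # OpenAI's "parameters" maps onto Anthropic's "input_schema".
                "input_schema": {
                    "type": "object",
                    "properties": function["parameters"]["properties"],
                    "required": function["parameters"].get("required", []),
                },
            }
        )
    return anthropic_tools


# Hypothetical tool definition in the OpenAI format.
openai_tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city.",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

anthropic_tools = convert_tool_spec(openai_tools)
```

Note that the function reads `function["description"]` and `function["parameters"]["properties"]` directly, so a tool spec that omits either key would raise `KeyError` in this form.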
It looks like both `_convert_dict_message` and `_convert_message_object` logically do the same thing. Would it be better to convert to one format and keep only one of the convert-message methods, or the other way around?
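
One possible shape for this suggestion, as a sketch rather than the PR's actual code: normalize every incoming message to a plain dict first, then run a single conversion path. The helper names `to_dict`, `tool_call_to_dict`, and `convert_message` are hypothetical, and the sketch assumes that object messages expose `role`, `content`, and optionally `tool_calls` and `tool_call_id` as attributes.

```python
import json
from types import SimpleNamespace


def tool_call_to_dict(tool_call):
    """Normalize a tool call (dict or object) into a plain dict."""
    if isinstance(tool_call, dict):
        return tool_call
    return {
        "id": tool_call.id,
        "type": "function",
        "function": {
            "name": tool_call.function.name,
            "arguments": tool_call.function.arguments,
        },
    }


def to_dict(msg):
    """Normalize a message (dict or object) into a plain dict."""
    if isinstance(msg, dict):
        return msg
    tool_calls = getattr(msg, "tool_calls", None)
    return {
        "role": msg.role,
        "content": msg.content,
        "tool_call_id": getattr(msg, "tool_call_id", None),
        "tool_calls": [tool_call_to_dict(tc) for tc in tool_calls] if tool_calls else None,
    }


def convert_message(msg):
    """Single conversion path to the Anthropic message format."""
    msg = to_dict(msg)
    if msg["role"] == "tool":
        return {
            "role": "user",
            "content": [
                {
                    "type": "tool_result",
                    "tool_use_id": msg["tool_call_id"],
                    "content": msg["content"],
                }
            ],
        }
    if msg["role"] == "assistant" and msg.get("tool_calls"):
        content = []
        if msg.get("content"):
            content.append({"type": "text", "text": msg["content"]})
        for tool_call in msg["tool_calls"]:
            tool_call = tool_call_to_dict(tool_call)
            content.append(
                {
                    "type": "tool_use",
                    "id": tool_call["id"],
                    "name": tool_call["function"]["name"],
                    "input": json.loads(tool_call["function"]["arguments"]),
                }
            )
        return {"role": "assistant", "content": content}
    return {"role": msg["role"], "content": msg["content"]}


# The same assistant message in object form and in dict form (hypothetical data).
obj_msg = SimpleNamespace(
    role="assistant",
    content=None,
    tool_calls=[
        SimpleNamespace(
            id="call_1",
            function=SimpleNamespace(name="get_weather", arguments='{"city": "Paris"}'),
        )
    ],
)
dict_msg = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_1",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    ],
}

converted_from_obj = convert_message(obj_msg)
converted_from_dict = convert_message(dict_msg)
```

With this arrangement the `isinstance` checks live only in the two normalizing helpers, and both input forms produce the same Anthropic payload. One behavioral difference from the diff: this sketch treats a `tool_calls` value of `None` as absent for dict messages, whereas the diff's dict path tests only for the key's presence.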