Skip to content

Commit f4e524c

Browse files
committed
Tool calling support. Part I
Add tool calling support for below providers - OpenAI, Groq, Anthropic, AWS, & Mistral. OpenAI compatible SDKs need to changes for tool calling support. Adding utility ToolManager for users to easily supply tools, and parse model's request for tool usage.
1 parent 1b5da0e commit f4e524c

File tree

11 files changed

+796
-88
lines changed

11 files changed

+796
-88
lines changed

aisuite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .client import Client
2+
from .framework.message import Message

aisuite/framework/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .provider_interface import ProviderInterface
22
from .chat_completion_response import ChatCompletionResponse
3+
from .message import Message

aisuite/framework/choice.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from aisuite.framework.message import Message
2+
from typing import Literal, Optional
23

34

45
class Choice:
56
def __init__(self):
6-
self.message = Message()
7+
self.finish_reason: Optional[Literal["stop", "tool_calls"]] = None
8+
self.message = Message(
9+
content=None, tool_calls=None, role="assistant", refusal=None
10+
)

aisuite/framework/message.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,22 @@
11
"""Interface to hold contents of api responses when they do not conform to the OpenAI style response"""
22

3+
from pydantic import BaseModel
4+
from typing import Literal, Optional
35

4-
class Message:
5-
def __init__(self):
6-
self.content = None
6+
7+
class Function(BaseModel):
8+
arguments: str
9+
name: str
10+
11+
12+
class ChatCompletionMessageToolCall(BaseModel):
13+
id: str
14+
function: Function
15+
type: Literal["function"]
16+
17+
18+
class Message(BaseModel):
19+
content: Optional[str]
20+
tool_calls: Optional[list[ChatCompletionMessageToolCall]]
21+
role: Optional[Literal["user", "assistant", "system"]]
22+
refusal: Optional[str]
Lines changed: 202 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,31 @@
11
import anthropic
2+
import json
23
from aisuite.provider import Provider
34
from aisuite.framework import ChatCompletionResponse
5+
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function
46

57
# Define a constant for the default max_tokens value
68
DEFAULT_MAX_TOKENS = 4096
79

10+
# Links:
11+
# Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use
12+
813

914
class AnthropicProvider(Provider):
15+
# Add these at the class level, after the class definition
16+
FINISH_REASON_MAPPING = {
17+
"end_turn": "stop",
18+
"max_tokens": "length",
19+
"tool_use": "tool_calls",
20+
# Add more mappings as needed
21+
}
22+
23+
# Role constants
24+
ROLE_USER = "user"
25+
ROLE_ASSISTANT = "assistant"
26+
ROLE_TOOL = "tool"
27+
ROLE_SYSTEM = "system"
28+
1029
def __init__(self, **config):
1130
"""
1231
Initialize the Anthropic provider with the given configuration.
@@ -15,26 +34,195 @@ def __init__(self, **config):
1534

1635
self.client = anthropic.Anthropic(**config)
1736

37+
def convert_request(self, messages):
38+
"""Convert framework messages to Anthropic format."""
39+
return [self._convert_single_message(msg) for msg in messages]
40+
41+
def _convert_single_message(self, msg):
42+
"""Convert a single message to Anthropic format."""
43+
if isinstance(msg, dict):
44+
return self._convert_dict_message(msg)
45+
return self._convert_message_object(msg)
46+
47+
def _convert_dict_message(self, msg):
48+
"""Convert a dictionary message to Anthropic format."""
49+
if msg["role"] == self.ROLE_TOOL:
50+
return self._create_tool_result_message(msg["tool_call_id"], msg["content"])
51+
elif msg["role"] == self.ROLE_ASSISTANT and "tool_calls" in msg:
52+
return self._create_assistant_tool_message(
53+
msg["content"], msg["tool_calls"]
54+
)
55+
return {"role": msg["role"], "content": msg["content"]}
56+
57+
def _convert_message_object(self, msg):
58+
"""Convert a Message object to Anthropic format."""
59+
if msg.role == self.ROLE_TOOL:
60+
return self._create_tool_result_message(msg.tool_call_id, msg.content)
61+
elif msg.role == self.ROLE_ASSISTANT and msg.tool_calls:
62+
return self._create_assistant_tool_message(msg.content, msg.tool_calls)
63+
return {"role": msg.role, "content": msg.content}
64+
65+
def _create_tool_result_message(self, tool_call_id, content):
66+
"""Create a tool result message in Anthropic format."""
67+
return {
68+
"role": self.ROLE_USER,
69+
"content": [
70+
{
71+
"type": "tool_result",
72+
"tool_use_id": tool_call_id,
73+
"content": content,
74+
}
75+
],
76+
}
77+
78+
def _create_assistant_tool_message(self, content, tool_calls):
79+
"""Create an assistant message with tool calls in Anthropic format."""
80+
message_content = []
81+
if content:
82+
message_content.append({"type": "text", "text": content})
83+
84+
for tool_call in tool_calls:
85+
tool_input = (
86+
tool_call["function"]["arguments"]
87+
if isinstance(tool_call, dict)
88+
else tool_call.function.arguments
89+
)
90+
message_content.append(
91+
{
92+
"type": "tool_use",
93+
"id": (
94+
tool_call["id"] if isinstance(tool_call, dict) else tool_call.id
95+
),
96+
"name": (
97+
tool_call["function"]["name"]
98+
if isinstance(tool_call, dict)
99+
else tool_call.function.name
100+
),
101+
"input": json.loads(tool_input),
102+
}
103+
)
104+
105+
return {"role": self.ROLE_ASSISTANT, "content": message_content}
106+
18107
def chat_completions_create(self, model, messages, **kwargs):
19-
# Check if the fist message is a system message
20-
if messages[0]["role"] == "system":
108+
"""Create a chat completion using the Anthropic API."""
109+
system_message = self._extract_system_message(messages)
110+
kwargs = self._prepare_kwargs(kwargs)
111+
converted_messages = self.convert_request(messages)
112+
113+
response = self.client.messages.create(
114+
model=model, system=system_message, messages=converted_messages, **kwargs
115+
)
116+
return self.convert_response(response)
117+
118+
def _extract_system_message(self, messages):
119+
"""Extract system message if present, otherwise return empty list."""
120+
if messages and messages[0]["role"] == "system":
21121
system_message = messages[0]["content"]
22-
messages = messages[1:]
23-
else:
24-
system_message = []
122+
messages.pop(0)
123+
return system_message
124+
return []
25125

26-
# kwargs.setdefault('max_tokens', DEFAULT_MAX_TOKENS)
27-
if "max_tokens" not in kwargs:
28-
kwargs["max_tokens"] = DEFAULT_MAX_TOKENS
126+
def _prepare_kwargs(self, kwargs):
127+
"""Prepare kwargs for the API call."""
128+
kwargs = kwargs.copy() # Create a copy to avoid modifying the original
129+
kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)
29130

30-
return self.normalize_response(
31-
self.client.messages.create(
32-
model=model, system=system_message, messages=messages, **kwargs
33-
)
131+
if "tools" in kwargs:
132+
kwargs["tools"] = self._convert_tool_spec(kwargs["tools"])
133+
134+
return kwargs
135+
136+
def convert_response_with_tool_use(self, response):
137+
"""Convert Anthropic tool use response to the framework's format."""
138+
# Find the tool_use content
139+
tool_call = next(
140+
(content for content in response.content if content.type == "tool_use"),
141+
None,
34142
)
35143

36-
def normalize_response(self, response):
144+
if tool_call:
145+
function = Function(
146+
name=tool_call.name, arguments=json.dumps(tool_call.input)
147+
)
148+
tool_call_obj = ChatCompletionMessageToolCall(
149+
id=tool_call.id, function=function, type="function"
150+
)
151+
# Get the text content if any
152+
text_content = next(
153+
(
154+
content.text
155+
for content in response.content
156+
if content.type == "text"
157+
),
158+
"",
159+
)
160+
161+
return Message(
162+
content=text_content or None,
163+
tool_calls=[tool_call_obj] if tool_call else None,
164+
role="assistant",
165+
refusal=None,
166+
)
167+
return None
168+
169+
def convert_response(self, response):
37170
"""Normalize the response from the Anthropic API to match OpenAI's response format."""
38171
normalized_response = ChatCompletionResponse()
39-
normalized_response.choices[0].message.content = response.content[0].text
172+
173+
normalized_response.choices[0].finish_reason = self._get_finish_reason(response)
174+
normalized_response.usage = self._get_usage_stats(response)
175+
normalized_response.choices[0].message = self._get_message(response)
176+
40177
return normalized_response
178+
179+
def _get_finish_reason(self, response):
180+
"""Get the normalized finish reason."""
181+
return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop")
182+
183+
def _get_usage_stats(self, response):
184+
"""Get the usage statistics."""
185+
return {
186+
"prompt_tokens": response.usage.input_tokens,
187+
"completion_tokens": response.usage.output_tokens,
188+
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
189+
}
190+
191+
def _get_message(self, response):
192+
"""Get the appropriate message based on response type."""
193+
if response.stop_reason == "tool_use":
194+
tool_message = self.convert_response_with_tool_use(response)
195+
if tool_message:
196+
return tool_message
197+
198+
return Message(
199+
content=response.content[0].text,
200+
role="assistant",
201+
tool_calls=None,
202+
refusal=None,
203+
)
204+
205+
def _convert_tool_spec(self, openai_tools):
206+
"""Convert OpenAI tool specification to Anthropic format."""
207+
anthropic_tools = []
208+
209+
for tool in openai_tools:
210+
# Only handle function-type tools from OpenAI
211+
if tool.get("type") != "function":
212+
continue
213+
214+
function = tool["function"]
215+
216+
anthropic_tool = {
217+
"name": function["name"],
218+
"description": function["description"],
219+
"input_schema": {
220+
"type": "object",
221+
"properties": function["parameters"]["properties"],
222+
"required": function["parameters"].get("required", []),
223+
},
224+
}
225+
226+
anthropic_tools.append(anthropic_tool)
227+
228+
return anthropic_tools

0 commit comments

Comments
 (0)