1
+ # Anthropic provider
2
+ # Links:
3
+ # Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use
4
+
1
5
import anthropic
2
6
import json
3
7
from aisuite .provider import Provider
7
11
# Define a constant for the default max_tokens value
8
12
DEFAULT_MAX_TOKENS = 4096
9
13
10
- # Links:
11
- # Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use
12
-
13
-
14
- class AnthropicProvider (Provider ):
15
- # Add these at the class level, after the class definition
16
- FINISH_REASON_MAPPING = {
17
- "end_turn" : "stop" ,
18
- "max_tokens" : "length" ,
19
- "tool_use" : "tool_calls" ,
20
- # Add more mappings as needed
21
- }
22
14
15
class AnthropicMessageConverter:
    """Translate between the framework's OpenAI-style chat messages/responses
    and Anthropic's message format (request and response directions)."""

    # Role constants
    ROLE_USER = "user"
    ROLE_ASSISTANT = "assistant"
    ROLE_TOOL = "tool"
    ROLE_SYSTEM = "system"

    # Finish reason mapping
    # Maps Anthropic `stop_reason` values onto OpenAI-style `finish_reason`
    # strings; unmapped reasons fall back to "stop" (see _get_finish_reason).
    FINISH_REASON_MAPPING = {
        "end_turn": "stop",
        "max_tokens": "length",
        "tool_use": "tool_calls",
    }
36
28
37
29
def convert_request (self , messages ):
38
30
"""Convert framework messages to Anthropic format."""
39
- return [self ._convert_single_message (msg ) for msg in messages ]
31
+ system_message = self ._extract_system_message (messages )
32
+ converted_messages = [self ._convert_single_message (msg ) for msg in messages ]
33
+ return system_message , converted_messages
34
+
35
+ def convert_response (self , response ):
36
+ """Normalize the response from the Anthropic API to match OpenAI's response format."""
37
+ normalized_response = ChatCompletionResponse ()
38
+ normalized_response .choices [0 ].finish_reason = self ._get_finish_reason (response )
39
+ normalized_response .usage = self ._get_usage_stats (response )
40
+ normalized_response .choices [0 ].message = self ._get_message (response )
41
+ return normalized_response
40
42
41
43
def _convert_single_message (self , msg ):
42
44
"""Convert a single message to Anthropic format."""
@@ -104,38 +106,45 @@ def _create_assistant_tool_message(self, content, tool_calls):
104
106
105
107
return {"role" : self .ROLE_ASSISTANT , "content" : message_content }
106
108
107
- def chat_completions_create (self , model , messages , ** kwargs ):
108
- """Create a chat completion using the Anthropic API."""
109
- system_message = self ._extract_system_message (messages )
110
- kwargs = self ._prepare_kwargs (kwargs )
111
- converted_messages = self .convert_request (messages )
112
-
113
- response = self .client .messages .create (
114
- model = model , system = system_message , messages = converted_messages , ** kwargs
115
- )
116
- return self .convert_response (response )
117
-
118
109
def _extract_system_message (self , messages ):
119
110
"""Extract system message if present, otherwise return empty list."""
111
+ # TODO: This is a temporary solution to extract the system message.
112
+ # User can pass multiple system messages, which can mingled with other messages.
113
+ # This needs to be fixed to handle this case.
120
114
if messages and messages [0 ]["role" ] == "system" :
121
115
system_message = messages [0 ]["content" ]
122
116
messages .pop (0 )
123
117
return system_message
124
118
return []
125
119
126
- def _prepare_kwargs (self , kwargs ):
127
- """Prepare kwargs for the API call."""
128
- kwargs = kwargs .copy () # Create a copy to avoid modifying the original
129
- kwargs .setdefault ("max_tokens" , DEFAULT_MAX_TOKENS )
120
+ def _get_finish_reason (self , response ):
121
+ """Get the normalized finish reason."""
122
+ return self .FINISH_REASON_MAPPING .get (response .stop_reason , "stop" )
130
123
131
- if "tools" in kwargs :
132
- kwargs ["tools" ] = self ._convert_tool_spec (kwargs ["tools" ])
124
+ def _get_usage_stats (self , response ):
125
+ """Get the usage statistics."""
126
+ return {
127
+ "prompt_tokens" : response .usage .input_tokens ,
128
+ "completion_tokens" : response .usage .output_tokens ,
129
+ "total_tokens" : response .usage .input_tokens + response .usage .output_tokens ,
130
+ }
133
131
134
- return kwargs
132
+ def _get_message (self , response ):
133
+ """Get the appropriate message based on response type."""
134
+ if response .stop_reason == "tool_use" :
135
+ tool_message = self .convert_response_with_tool_use (response )
136
+ if tool_message :
137
+ return tool_message
138
+
139
+ return Message (
140
+ content = response .content [0 ].text ,
141
+ role = "assistant" ,
142
+ tool_calls = None ,
143
+ refusal = None ,
144
+ )
135
145
136
146
def convert_response_with_tool_use (self , response ):
137
147
"""Convert Anthropic tool use response to the framework's format."""
138
- # Find the tool_use content
139
148
tool_call = next (
140
149
(content for content in response .content if content .type == "tool_use" ),
141
150
None ,
@@ -148,7 +157,6 @@ def convert_response_with_tool_use(self, response):
148
157
tool_call_obj = ChatCompletionMessageToolCall (
149
158
id = tool_call .id , function = function , type = "function"
150
159
)
151
- # Get the text content if any
152
160
text_content = next (
153
161
(
154
162
content .text
@@ -166,53 +174,15 @@ def convert_response_with_tool_use(self, response):
166
174
)
167
175
return None
168
176
169
- def convert_response (self , response ):
170
- """Normalize the response from the Anthropic API to match OpenAI's response format."""
171
- normalized_response = ChatCompletionResponse ()
172
-
173
- normalized_response .choices [0 ].finish_reason = self ._get_finish_reason (response )
174
- normalized_response .usage = self ._get_usage_stats (response )
175
- normalized_response .choices [0 ].message = self ._get_message (response )
176
-
177
- return normalized_response
178
-
179
- def _get_finish_reason (self , response ):
180
- """Get the normalized finish reason."""
181
- return self .FINISH_REASON_MAPPING .get (response .stop_reason , "stop" )
182
-
183
- def _get_usage_stats (self , response ):
184
- """Get the usage statistics."""
185
- return {
186
- "prompt_tokens" : response .usage .input_tokens ,
187
- "completion_tokens" : response .usage .output_tokens ,
188
- "total_tokens" : response .usage .input_tokens + response .usage .output_tokens ,
189
- }
190
-
191
- def _get_message (self , response ):
192
- """Get the appropriate message based on response type."""
193
- if response .stop_reason == "tool_use" :
194
- tool_message = self .convert_response_with_tool_use (response )
195
- if tool_message :
196
- return tool_message
197
-
198
- return Message (
199
- content = response .content [0 ].text ,
200
- role = "assistant" ,
201
- tool_calls = None ,
202
- refusal = None ,
203
- )
204
-
205
- def _convert_tool_spec (self , openai_tools ):
177
+ def convert_tool_spec (self , openai_tools ):
206
178
"""Convert OpenAI tool specification to Anthropic format."""
207
179
anthropic_tools = []
208
180
209
181
for tool in openai_tools :
210
- # Only handle function-type tools from OpenAI
211
182
if tool .get ("type" ) != "function" :
212
183
continue
213
184
214
185
function = tool ["function" ]
215
-
216
186
anthropic_tool = {
217
187
"name" : function ["name" ],
218
188
"description" : function ["description" ],
@@ -222,7 +192,33 @@ def _convert_tool_spec(self, openai_tools):
222
192
"required" : function ["parameters" ].get ("required" , []),
223
193
},
224
194
}
225
-
226
195
anthropic_tools .append (anthropic_tool )
227
196
228
197
return anthropic_tools
198
+
199
+
200
class AnthropicProvider(Provider):
    """aisuite provider backed by the Anthropic messages API."""

    def __init__(self, **config):
        """Initialize the Anthropic provider with the given configuration."""
        # The entire config dict is forwarded verbatim to the Anthropic client.
        self.client = anthropic.Anthropic(**config)
        self.converter = AnthropicMessageConverter()

    def chat_completions_create(self, model, messages, **kwargs):
        """Create a chat completion using the Anthropic API."""
        prepared = self._prepare_kwargs(kwargs)
        system_message, converted_messages = self.converter.convert_request(messages)
        raw_response = self.client.messages.create(
            model=model,
            system=system_message,
            messages=converted_messages,
            **prepared,
        )
        return self.converter.convert_response(raw_response)

    def _prepare_kwargs(self, kwargs):
        """Prepare kwargs for the API call."""
        prepared = dict(kwargs)  # copy so the caller's dict is untouched
        prepared.setdefault("max_tokens", DEFAULT_MAX_TOKENS)
        if "tools" in prepared:
            prepared["tools"] = self.converter.convert_tool_spec(prepared["tools"])
        return prepared
0 commit comments