1
1
import anthropic
2
+ import json
2
3
from aisuite .provider import Provider
3
4
from aisuite .framework import ChatCompletionResponse
5
+ from aisuite .framework .message import Message , ChatCompletionMessageToolCall , Function
4
6
5
7
# Define a constant for the default max_tokens value
6
8
DEFAULT_MAX_TOKENS = 4096
7
9
10
+ # Links:
11
+ # Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use
12
+
8
13
9
14
class AnthropicProvider (Provider ):
15
+ # Add these at the class level, after the class definition
16
+ FINISH_REASON_MAPPING = {
17
+ "end_turn" : "stop" ,
18
+ "max_tokens" : "length" ,
19
+ "tool_use" : "tool_calls" ,
20
+ # Add more mappings as needed
21
+ }
22
+
23
+ # Role constants
24
+ ROLE_USER = "user"
25
+ ROLE_ASSISTANT = "assistant"
26
+ ROLE_TOOL = "tool"
27
+ ROLE_SYSTEM = "system"
28
+
10
29
def __init__ (self , ** config ):
11
30
"""
12
31
Initialize the Anthropic provider with the given configuration.
@@ -15,26 +34,195 @@ def __init__(self, **config):
15
34
16
35
self .client = anthropic .Anthropic (** config )
17
36
37
+ def convert_request (self , messages ):
38
+ """Convert framework messages to Anthropic format."""
39
+ return [self ._convert_single_message (msg ) for msg in messages ]
40
+
41
+ def _convert_single_message (self , msg ):
42
+ """Convert a single message to Anthropic format."""
43
+ if isinstance (msg , dict ):
44
+ return self ._convert_dict_message (msg )
45
+ return self ._convert_message_object (msg )
46
+
47
+ def _convert_dict_message (self , msg ):
48
+ """Convert a dictionary message to Anthropic format."""
49
+ if msg ["role" ] == self .ROLE_TOOL :
50
+ return self ._create_tool_result_message (msg ["tool_call_id" ], msg ["content" ])
51
+ elif msg ["role" ] == self .ROLE_ASSISTANT and "tool_calls" in msg :
52
+ return self ._create_assistant_tool_message (
53
+ msg ["content" ], msg ["tool_calls" ]
54
+ )
55
+ return {"role" : msg ["role" ], "content" : msg ["content" ]}
56
+
57
+ def _convert_message_object (self , msg ):
58
+ """Convert a Message object to Anthropic format."""
59
+ if msg .role == self .ROLE_TOOL :
60
+ return self ._create_tool_result_message (msg .tool_call_id , msg .content )
61
+ elif msg .role == self .ROLE_ASSISTANT and msg .tool_calls :
62
+ return self ._create_assistant_tool_message (msg .content , msg .tool_calls )
63
+ return {"role" : msg .role , "content" : msg .content }
64
+
65
+ def _create_tool_result_message (self , tool_call_id , content ):
66
+ """Create a tool result message in Anthropic format."""
67
+ return {
68
+ "role" : self .ROLE_USER ,
69
+ "content" : [
70
+ {
71
+ "type" : "tool_result" ,
72
+ "tool_use_id" : tool_call_id ,
73
+ "content" : content ,
74
+ }
75
+ ],
76
+ }
77
+
78
+ def _create_assistant_tool_message (self , content , tool_calls ):
79
+ """Create an assistant message with tool calls in Anthropic format."""
80
+ message_content = []
81
+ if content :
82
+ message_content .append ({"type" : "text" , "text" : content })
83
+
84
+ for tool_call in tool_calls :
85
+ tool_input = (
86
+ tool_call ["function" ]["arguments" ]
87
+ if isinstance (tool_call , dict )
88
+ else tool_call .function .arguments
89
+ )
90
+ message_content .append (
91
+ {
92
+ "type" : "tool_use" ,
93
+ "id" : (
94
+ tool_call ["id" ] if isinstance (tool_call , dict ) else tool_call .id
95
+ ),
96
+ "name" : (
97
+ tool_call ["function" ]["name" ]
98
+ if isinstance (tool_call , dict )
99
+ else tool_call .function .name
100
+ ),
101
+ "input" : json .loads (tool_input ),
102
+ }
103
+ )
104
+
105
+ return {"role" : self .ROLE_ASSISTANT , "content" : message_content }
106
+
18
107
def chat_completions_create (self , model , messages , ** kwargs ):
19
- # Check if the fist message is a system message
20
- if messages [0 ]["role" ] == "system" :
108
+ """Create a chat completion using the Anthropic API."""
109
+ system_message = self ._extract_system_message (messages )
110
+ kwargs = self ._prepare_kwargs (kwargs )
111
+ converted_messages = self .convert_request (messages )
112
+
113
+ response = self .client .messages .create (
114
+ model = model , system = system_message , messages = converted_messages , ** kwargs
115
+ )
116
+ return self .convert_response (response )
117
+
118
+ def _extract_system_message (self , messages ):
119
+ """Extract system message if present, otherwise return empty list."""
120
+ if messages and messages [0 ]["role" ] == "system" :
21
121
system_message = messages [0 ]["content" ]
22
- messages = messages [ 1 :]
23
- else :
24
- system_message = []
122
+ messages . pop ( 0 )
123
+ return system_message
124
+ return []
25
125
26
- # kwargs.setdefault('max_tokens', DEFAULT_MAX_TOKENS)
27
- if "max_tokens" not in kwargs :
28
- kwargs ["max_tokens" ] = DEFAULT_MAX_TOKENS
126
+ def _prepare_kwargs (self , kwargs ):
127
+ """Prepare kwargs for the API call."""
128
+ kwargs = kwargs .copy () # Create a copy to avoid modifying the original
129
+ kwargs .setdefault ("max_tokens" , DEFAULT_MAX_TOKENS )
29
130
30
- return self .normalize_response (
31
- self .client .messages .create (
32
- model = model , system = system_message , messages = messages , ** kwargs
33
- )
131
+ if "tools" in kwargs :
132
+ kwargs ["tools" ] = self ._convert_tool_spec (kwargs ["tools" ])
133
+
134
+ return kwargs
135
+
136
+ def convert_response_with_tool_use (self , response ):
137
+ """Convert Anthropic tool use response to the framework's format."""
138
+ # Find the tool_use content
139
+ tool_call = next (
140
+ (content for content in response .content if content .type == "tool_use" ),
141
+ None ,
34
142
)
35
143
36
- def normalize_response (self , response ):
144
+ if tool_call :
145
+ function = Function (
146
+ name = tool_call .name , arguments = json .dumps (tool_call .input )
147
+ )
148
+ tool_call_obj = ChatCompletionMessageToolCall (
149
+ id = tool_call .id , function = function , type = "function"
150
+ )
151
+ # Get the text content if any
152
+ text_content = next (
153
+ (
154
+ content .text
155
+ for content in response .content
156
+ if content .type == "text"
157
+ ),
158
+ "" ,
159
+ )
160
+
161
+ return Message (
162
+ content = text_content or None ,
163
+ tool_calls = [tool_call_obj ] if tool_call else None ,
164
+ role = "assistant" ,
165
+ refusal = None ,
166
+ )
167
+ return None
168
+
169
+ def convert_response (self , response ):
37
170
"""Normalize the response from the Anthropic API to match OpenAI's response format."""
38
171
normalized_response = ChatCompletionResponse ()
39
- normalized_response .choices [0 ].message .content = response .content [0 ].text
172
+
173
+ normalized_response .choices [0 ].finish_reason = self ._get_finish_reason (response )
174
+ normalized_response .usage = self ._get_usage_stats (response )
175
+ normalized_response .choices [0 ].message = self ._get_message (response )
176
+
40
177
return normalized_response
178
+
179
+ def _get_finish_reason (self , response ):
180
+ """Get the normalized finish reason."""
181
+ return self .FINISH_REASON_MAPPING .get (response .stop_reason , "stop" )
182
+
183
+ def _get_usage_stats (self , response ):
184
+ """Get the usage statistics."""
185
+ return {
186
+ "prompt_tokens" : response .usage .input_tokens ,
187
+ "completion_tokens" : response .usage .output_tokens ,
188
+ "total_tokens" : response .usage .input_tokens + response .usage .output_tokens ,
189
+ }
190
+
191
+ def _get_message (self , response ):
192
+ """Get the appropriate message based on response type."""
193
+ if response .stop_reason == "tool_use" :
194
+ tool_message = self .convert_response_with_tool_use (response )
195
+ if tool_message :
196
+ return tool_message
197
+
198
+ return Message (
199
+ content = response .content [0 ].text ,
200
+ role = "assistant" ,
201
+ tool_calls = None ,
202
+ refusal = None ,
203
+ )
204
+
205
+ def _convert_tool_spec (self , openai_tools ):
206
+ """Convert OpenAI tool specification to Anthropic format."""
207
+ anthropic_tools = []
208
+
209
+ for tool in openai_tools :
210
+ # Only handle function-type tools from OpenAI
211
+ if tool .get ("type" ) != "function" :
212
+ continue
213
+
214
+ function = tool ["function" ]
215
+
216
+ anthropic_tool = {
217
+ "name" : function ["name" ],
218
+ "description" : function ["description" ],
219
+ "input_schema" : {
220
+ "type" : "object" ,
221
+ "properties" : function ["parameters" ]["properties" ],
222
+ "required" : function ["parameters" ].get ("required" , []),
223
+ },
224
+ }
225
+
226
+ anthropic_tools .append (anthropic_tool )
227
+
228
+ return anthropic_tools
0 commit comments