
Commit 5a55a09

Few Refactorings.

1 parent: 342481e

File tree: 4 files changed (+176, -105 lines)

aisuite/providers/anthropic_provider.py

Lines changed: 76 additions & 80 deletions
@@ -1,3 +1,7 @@
+# Anthropic provider
+# Links:
+# Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use
+
 import anthropic
 import json
 from aisuite.provider import Provider
@@ -7,36 +11,34 @@
 # Define a constant for the default max_tokens value
 DEFAULT_MAX_TOKENS = 4096
 
-# Links:
-# Tool calling docs - https://docs.anthropic.com/en/docs/build-with-claude/tool-use
-
-
-class AnthropicProvider(Provider):
-    # Add these at the class level, after the class definition
-    FINISH_REASON_MAPPING = {
-        "end_turn": "stop",
-        "max_tokens": "length",
-        "tool_use": "tool_calls",
-        # Add more mappings as needed
-    }
 
+class AnthropicMessageConverter:
     # Role constants
     ROLE_USER = "user"
     ROLE_ASSISTANT = "assistant"
     ROLE_TOOL = "tool"
     ROLE_SYSTEM = "system"
 
-    def __init__(self, **config):
-        """
-        Initialize the Anthropic provider with the given configuration.
-        Pass the entire configuration dictionary to the Anthropic client constructor.
-        """
-
-        self.client = anthropic.Anthropic(**config)
+    # Finish reason mapping
+    FINISH_REASON_MAPPING = {
+        "end_turn": "stop",
+        "max_tokens": "length",
+        "tool_use": "tool_calls",
+    }
 
     def convert_request(self, messages):
         """Convert framework messages to Anthropic format."""
-        return [self._convert_single_message(msg) for msg in messages]
+        system_message = self._extract_system_message(messages)
+        converted_messages = [self._convert_single_message(msg) for msg in messages]
+        return system_message, converted_messages
+
+    def convert_response(self, response):
+        """Normalize the response from the Anthropic API to match OpenAI's response format."""
+        normalized_response = ChatCompletionResponse()
+        normalized_response.choices[0].finish_reason = self._get_finish_reason(response)
+        normalized_response.usage = self._get_usage_stats(response)
+        normalized_response.choices[0].message = self._get_message(response)
+        return normalized_response
 
     def _convert_single_message(self, msg):
         """Convert a single message to Anthropic format."""
@@ -104,38 +106,45 @@ def _create_assistant_tool_message(self, content, tool_calls):
 
         return {"role": self.ROLE_ASSISTANT, "content": message_content}
 
-    def chat_completions_create(self, model, messages, **kwargs):
-        """Create a chat completion using the Anthropic API."""
-        system_message = self._extract_system_message(messages)
-        kwargs = self._prepare_kwargs(kwargs)
-        converted_messages = self.convert_request(messages)
-
-        response = self.client.messages.create(
-            model=model, system=system_message, messages=converted_messages, **kwargs
-        )
-        return self.convert_response(response)
-
     def _extract_system_message(self, messages):
         """Extract system message if present, otherwise return empty list."""
+        # TODO: This is a temporary solution to extract the system message.
+        # The user can pass multiple system messages, which can be mingled with other messages.
+        # This needs to be fixed to handle that case.
         if messages and messages[0]["role"] == "system":
             system_message = messages[0]["content"]
             messages.pop(0)
             return system_message
         return []
 
-    def _prepare_kwargs(self, kwargs):
-        """Prepare kwargs for the API call."""
-        kwargs = kwargs.copy()  # Create a copy to avoid modifying the original
-        kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)
+    def _get_finish_reason(self, response):
+        """Get the normalized finish reason."""
+        return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop")
 
-        if "tools" in kwargs:
-            kwargs["tools"] = self._convert_tool_spec(kwargs["tools"])
+    def _get_usage_stats(self, response):
+        """Get the usage statistics."""
+        return {
+            "prompt_tokens": response.usage.input_tokens,
+            "completion_tokens": response.usage.output_tokens,
+            "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
+        }
 
-        return kwargs
+    def _get_message(self, response):
+        """Get the appropriate message based on response type."""
+        if response.stop_reason == "tool_use":
+            tool_message = self.convert_response_with_tool_use(response)
+            if tool_message:
+                return tool_message
+
+        return Message(
+            content=response.content[0].text,
+            role="assistant",
+            tool_calls=None,
+            refusal=None,
+        )
 
     def convert_response_with_tool_use(self, response):
         """Convert Anthropic tool use response to the framework's format."""
-        # Find the tool_use content
         tool_call = next(
             (content for content in response.content if content.type == "tool_use"),
             None,
@@ -148,7 +157,6 @@ def convert_response_with_tool_use(self, response):
         tool_call_obj = ChatCompletionMessageToolCall(
             id=tool_call.id, function=function, type="function"
         )
-        # Get the text content if any
         text_content = next(
             (
                 content.text
@@ -166,53 +174,15 @@ def convert_response_with_tool_use(self, response):
         )
         return None
 
-    def convert_response(self, response):
-        """Normalize the response from the Anthropic API to match OpenAI's response format."""
-        normalized_response = ChatCompletionResponse()
-
-        normalized_response.choices[0].finish_reason = self._get_finish_reason(response)
-        normalized_response.usage = self._get_usage_stats(response)
-        normalized_response.choices[0].message = self._get_message(response)
-
-        return normalized_response
-
-    def _get_finish_reason(self, response):
-        """Get the normalized finish reason."""
-        return self.FINISH_REASON_MAPPING.get(response.stop_reason, "stop")
-
-    def _get_usage_stats(self, response):
-        """Get the usage statistics."""
-        return {
-            "prompt_tokens": response.usage.input_tokens,
-            "completion_tokens": response.usage.output_tokens,
-            "total_tokens": response.usage.input_tokens + response.usage.output_tokens,
-        }
-
-    def _get_message(self, response):
-        """Get the appropriate message based on response type."""
-        if response.stop_reason == "tool_use":
-            tool_message = self.convert_response_with_tool_use(response)
-            if tool_message:
-                return tool_message
-
-        return Message(
-            content=response.content[0].text,
-            role="assistant",
-            tool_calls=None,
-            refusal=None,
-        )
-
-    def _convert_tool_spec(self, openai_tools):
+    def convert_tool_spec(self, openai_tools):
         """Convert OpenAI tool specification to Anthropic format."""
         anthropic_tools = []
 
         for tool in openai_tools:
-            # Only handle function-type tools from OpenAI
            if tool.get("type") != "function":
                continue
 
            function = tool["function"]
-
            anthropic_tool = {
                "name": function["name"],
                "description": function["description"],
@@ -222,7 +192,33 @@ def _convert_tool_spec(self, openai_tools):
                    "required": function["parameters"].get("required", []),
                },
            }
-
            anthropic_tools.append(anthropic_tool)
 
        return anthropic_tools
+
+
+class AnthropicProvider(Provider):
+    def __init__(self, **config):
+        """Initialize the Anthropic provider with the given configuration."""
+        self.client = anthropic.Anthropic(**config)
+        self.converter = AnthropicMessageConverter()
+
+    def chat_completions_create(self, model, messages, **kwargs):
+        """Create a chat completion using the Anthropic API."""
+        kwargs = self._prepare_kwargs(kwargs)
+        system_message, converted_messages = self.converter.convert_request(messages)
+
+        response = self.client.messages.create(
+            model=model, system=system_message, messages=converted_messages, **kwargs
+        )
+        return self.converter.convert_response(response)
+
+    def _prepare_kwargs(self, kwargs):
+        """Prepare kwargs for the API call."""
+        kwargs = kwargs.copy()
+        kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)
+
+        if "tools" in kwargs:
+            kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"])
+
+        return kwargs
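
Net effect of this file's diff: request and response conversion move out of the provider into a dedicated AnthropicMessageConverter, and convert_request now returns the extracted system message together with the converted messages. A minimal sketch of the new call flow, assuming the classes live in aisuite/providers/anthropic_provider.py (the path is inferred from the sibling aws_provider.py below, not stated in this commit):

```python
# Sketch only; exercises the refactored converter as shown in the diff above.
from aisuite.providers.anthropic_provider import AnthropicMessageConverter

converter = AnthropicMessageConverter()
messages = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hello"},
]
system_message, converted = converter.convert_request(messages)

# Per the diff: system_message == "Be brief." (or [] when none is present),
# and `converted` holds the remaining messages in Anthropic format.
# Note that _extract_system_message pops the system entry, so the caller's
# `messages` list is mutated in place (see the TODO in the diff).
print(system_message, converted)
```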

aisuite/providers/aws_provider.py

Lines changed: 27 additions & 12 deletions
@@ -33,6 +33,10 @@ def convert_request(
             for message in messages
         ]
 
+        import pprint
+
+        pprint.pprint(messages)
+
         # Handle system message
         system_message = []
         if messages and messages[0]["role"] == "system":
@@ -61,6 +65,8 @@
                 }
             )
 
+        pprint.pprint(formatted_messages)
+
         return system_message, formatted_messages
 
     @staticmethod
@@ -150,21 +156,14 @@ def convert_assistant(message: Dict[str, Any]) -> Optional[Dict[str, Any]]:
 
         return {"role": "assistant", "content": content} if content else None
 
-
-class AwsProvider(Provider):
-    def __init__(self, **config):
-        """Initialize the AWS Bedrock provider with the given configuration."""
-        self.config = BedrockConfig(**config)
-        self.client = self.config.create_client()
-        self.transformer = BedrockMessageConverter()
-
-    def convert_response(self, response: Dict[str, Any]) -> ChatCompletionResponse:
+    @staticmethod
+    def convert_response(response: Dict[str, Any]) -> ChatCompletionResponse:
         """Normalize the response from the Bedrock API to match OpenAI's response format."""
         norm_response = ChatCompletionResponse()
 
         # Check if the model is requesting tool use
         if response.get("stopReason") == "tool_use":
-            tool_message = self.transformer.convert_response_tool_call(response)
+            tool_message = BedrockMessageConverter.convert_response_tool_call(response)
             if tool_message:
                 norm_response.choices[0].message = Message(**tool_message)
                 norm_response.choices[0].finish_reason = "tool_calls"
@@ -176,6 +175,18 @@ def convert_response(self, response: Dict[str, Any]) -> ChatCompletionResponse:
             ][0]["text"]
         return norm_response
 
+
+class AwsProvider(Provider):
+    def __init__(self, **config):
+        """Initialize the AWS Bedrock provider with the given configuration."""
+        self.config = BedrockConfig(**config)
+        self.client = self.config.create_client()
+        self.transformer = BedrockMessageConverter()
+
+    def convert_response(self, response: Dict[str, Any]) -> ChatCompletionResponse:
+        """Normalize the response from the Bedrock API to match OpenAI's response format."""
+        return self.transformer.convert_response(response)
+
     def _convert_tool_spec(self, kwargs: Dict[str, Any]) -> Optional[Dict[str, Any]]:
         """Convert tool specifications to Bedrock format."""
         if "tools" not in kwargs:
@@ -213,12 +224,16 @@ def _prepare_request_config(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
             if key not in BedrockConfig.INFERENCE_PARAMETERS
         }
 
-        return {
+        request_config = {
             "inferenceConfig": inference_config,
             "additionalModelRequestFields": additional_fields,
-            "toolConfig": tool_config,
         }
 
+        if tool_config is not None:
+            request_config["toolConfig"] = tool_config
+
+        return request_config
+
     def chat_completions_create(
         self, model: str, messages: List[Dict[str, Any]], **kwargs
     ) -> ChatCompletionResponse:
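
The _prepare_request_config change is the behavioral fix in this file: toolConfig is now attached only when _convert_tool_spec actually produced one, instead of always sending "toolConfig": None. A self-contained sketch of the pattern (function and key names taken from the diff; that Bedrock's converse call rejects a None toolConfig is an assumption based on boto3's usual parameter validation):

```python
from typing import Any, Dict, Optional


def prepare_request_config(
    inference_config: Dict[str, Any],
    additional_fields: Dict[str, Any],
    tool_config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build the request payload, omitting optional keys rather than sending None."""
    request_config = {
        "inferenceConfig": inference_config,
        "additionalModelRequestFields": additional_fields,
    }
    # Attach toolConfig only when tools were actually supplied.
    if tool_config is not None:
        request_config["toolConfig"] = tool_config
    return request_config


# No tools supplied: the key is absent instead of None.
assert "toolConfig" not in prepare_request_config({"maxTokens": 256}, {})
```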

examples/client.ipynb

Lines changed: 12 additions & 0 deletions
@@ -231,6 +231,18 @@
     "response = client.chat.completions.create(model=togetherai_model, messages=messages, temperature=0.75, top_p=0.7, top_k=50)\n",
     "print(response.choices[0].message.content)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dcf63a11",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "gemini_15_flash = \"google:gemini-1.5-flash\"\n",
+    "response = client.chat.completions.create(model=gemini_15_flash, messages=messages, temperature=0.75)\n",
+    "print(response.choices[0].message.content)"
+   ]
   }
  ],
 "metadata": {
