Skip to content

Commit 342481e

Browse files
committed
Adding tool calling support for Azure Inference API.
1 parent f4e524c commit 342481e

File tree

1 file changed

+53
-9
lines changed

1 file changed

+53
-9
lines changed

aisuite/providers/azure_provider.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from aisuite.provider import Provider
66
from aisuite.framework import ChatCompletionResponse
7+
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function
78

89

910
class AzureProvider(Provider):
@@ -18,14 +19,35 @@ def __init__(self, **config):
1819
)
1920

2021
def chat_completions_create(self, model, messages, **kwargs):
21-
url = f"https://{model}.westus3.models.ai.azure.com/v1/chat/completions"
22-
url = f"https://{self.base_url}/chat/completions"
23-
if self.base_url:
24-
url = f"{self.base_url}/chat/completions"
22+
url = f"{self.base_url}/chat/completions"
2523

2624
# Remove 'stream' from kwargs if present
2725
kwargs.pop("stream", None)
28-
data = {"messages": messages, **kwargs}
26+
27+
# Transform messages if they are Message objects
28+
transformed_messages = []
29+
for message in messages:
30+
if isinstance(message, Message):
31+
transformed_messages.append(message.model_dump(mode="json"))
32+
else:
33+
transformed_messages.append(message)
34+
35+
# Prepare the request payload with transformed messages
36+
data = {"messages": transformed_messages}
37+
38+
# Add tools if provided
39+
if "tools" in kwargs:
40+
data["tools"] = kwargs["tools"]
41+
# Remove from kwargs to avoid duplication
42+
kwargs.pop("tools")
43+
44+
# Add tool_choice if provided
45+
if "tool_choice" in kwargs:
46+
data["tool_choice"] = kwargs["tool_choice"]
47+
kwargs.pop("tool_choice")
48+
49+
# Add remaining kwargs
50+
data.update(kwargs)
2951

3052
body = json.dumps(data).encode("utf-8")
3153
headers = {"Content-Type": "application/json", "Authorization": self.api_key}
@@ -36,10 +58,32 @@ def chat_completions_create(self, model, messages, **kwargs):
3658
result = response.read()
3759
resp_json = json.loads(result)
3860
completion_response = ChatCompletionResponse()
39-
# TODO: Add checks for fields being present in resp_json.
40-
completion_response.choices[0].message.content = resp_json["choices"][
41-
0
42-
]["message"]["content"]
61+
62+
# Process the response
63+
choice = resp_json["choices"][0]
64+
message = choice["message"]
65+
66+
# Set basic message content
67+
completion_response.choices[0].message.content = message.get("content")
68+
completion_response.choices[0].message.role = message.get(
69+
"role", "assistant"
70+
)
71+
72+
# Handle tool calls if present
73+
if "tool_calls" in message and message["tool_calls"] is not None:
74+
tool_calls = []
75+
for tool_call in message["tool_calls"]:
76+
new_tool_call = ChatCompletionMessageToolCall(
77+
id=tool_call["id"],
78+
type=tool_call["type"],
79+
function={
80+
"name": tool_call["function"]["name"],
81+
"arguments": tool_call["function"]["arguments"],
82+
},
83+
)
84+
tool_calls.append(new_tool_call)
85+
completion_response.choices[0].message.tool_calls = tool_calls
86+
4387
return completion_response
4488

4589
except urllib.error.HTTPError as error:

0 commit comments

Comments
 (0)