Skip to content

Commit

Permalink
Adding tool calling support for Azure Inference API.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitprasad15 committed Nov 29, 2024
1 parent f4e524c commit 342481e
Showing 1 changed file with 53 additions and 9 deletions.
62 changes: 53 additions & 9 deletions aisuite/providers/azure_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from aisuite.provider import Provider
from aisuite.framework import ChatCompletionResponse
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function


class AzureProvider(Provider):
Expand All @@ -18,14 +19,35 @@ def __init__(self, **config):
)

def chat_completions_create(self, model, messages, **kwargs):
url = f"https://{model}.westus3.models.ai.azure.com/v1/chat/completions"
url = f"https://{self.base_url}/chat/completions"
if self.base_url:
url = f"{self.base_url}/chat/completions"
url = f"{self.base_url}/chat/completions"

# Remove 'stream' from kwargs if present
kwargs.pop("stream", None)
data = {"messages": messages, **kwargs}

# Transform messages if they are Message objects
transformed_messages = []
for message in messages:
if isinstance(message, Message):
transformed_messages.append(message.model_dump(mode="json"))
else:
transformed_messages.append(message)

# Prepare the request payload with transformed messages
data = {"messages": transformed_messages}

# Add tools if provided
if "tools" in kwargs:
data["tools"] = kwargs["tools"]
# Remove from kwargs to avoid duplication
kwargs.pop("tools")

# Add tool_choice if provided
if "tool_choice" in kwargs:
data["tool_choice"] = kwargs["tool_choice"]
kwargs.pop("tool_choice")

# Add remaining kwargs
data.update(kwargs)

body = json.dumps(data).encode("utf-8")
headers = {"Content-Type": "application/json", "Authorization": self.api_key}
Expand All @@ -36,10 +58,32 @@ def chat_completions_create(self, model, messages, **kwargs):
result = response.read()
resp_json = json.loads(result)
completion_response = ChatCompletionResponse()
# TODO: Add checks for fields being present in resp_json.
completion_response.choices[0].message.content = resp_json["choices"][
0
]["message"]["content"]

# Process the response
choice = resp_json["choices"][0]
message = choice["message"]

# Set basic message content
completion_response.choices[0].message.content = message.get("content")
completion_response.choices[0].message.role = message.get(
"role", "assistant"
)

# Handle tool calls if present
if "tool_calls" in message and message["tool_calls"] is not None:
tool_calls = []
for tool_call in message["tool_calls"]:
new_tool_call = ChatCompletionMessageToolCall(
id=tool_call["id"],
type=tool_call["type"],
function={
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
},
)
tool_calls.append(new_tool_call)
completion_response.choices[0].message.tool_calls = tool_calls

return completion_response

except urllib.error.HTTPError as error:
Expand Down

0 comments on commit 342481e

Please sign in to comment.