From 342481ea096cdc4dddf01adb71fa2c4a1b89b0b4 Mon Sep 17 00:00:00 2001 From: Rohit Prsad Date: Fri, 29 Nov 2024 02:23:34 -0800 Subject: [PATCH] Adding tool calling support for Azure Inference API. --- aisuite/providers/azure_provider.py | 62 ++++++++++++++++++++++++----- 1 file changed, 53 insertions(+), 9 deletions(-) diff --git a/aisuite/providers/azure_provider.py b/aisuite/providers/azure_provider.py index 7e0233ad..a08329d0 100644 --- a/aisuite/providers/azure_provider.py +++ b/aisuite/providers/azure_provider.py @@ -4,6 +4,7 @@ from aisuite.provider import Provider from aisuite.framework import ChatCompletionResponse +from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function class AzureProvider(Provider): @@ -18,14 +19,35 @@ def __init__(self, **config): ) def chat_completions_create(self, model, messages, **kwargs): - url = f"https://{model}.westus3.models.ai.azure.com/v1/chat/completions" - url = f"https://{self.base_url}/chat/completions" - if self.base_url: - url = f"{self.base_url}/chat/completions" + url = f"{self.base_url}/chat/completions" # Remove 'stream' from kwargs if present kwargs.pop("stream", None) - data = {"messages": messages, **kwargs} + + # Transform messages if they are Message objects + transformed_messages = [] + for message in messages: + if isinstance(message, Message): + transformed_messages.append(message.model_dump(mode="json")) + else: + transformed_messages.append(message) + + # Prepare the request payload with transformed messages + data = {"messages": transformed_messages} + + # Add tools if provided + if "tools" in kwargs: + data["tools"] = kwargs["tools"] + # Remove from kwargs to avoid duplication + kwargs.pop("tools") + + # Add tool_choice if provided + if "tool_choice" in kwargs: + data["tool_choice"] = kwargs["tool_choice"] + kwargs.pop("tool_choice") + + # Add remaining kwargs + data.update(kwargs) body = json.dumps(data).encode("utf-8") headers = {"Content-Type": "application/json", "Authorization": self.api_key} @@ -36,10 +58,32 @@ def chat_completions_create(self, model, messages, **kwargs): result = response.read() resp_json = json.loads(result) completion_response = ChatCompletionResponse() - # TODO: Add checks for fields being present in resp_json. - completion_response.choices[0].message.content = resp_json["choices"][ - 0 - ]["message"]["content"] + + # Process the response + choice = resp_json["choices"][0] + message = choice["message"] + + # Set basic message content + completion_response.choices[0].message.content = message.get("content") + completion_response.choices[0].message.role = message.get( + "role", "assistant" + ) + + # Handle tool calls if present + if "tool_calls" in message and message["tool_calls"] is not None: + tool_calls = [] + for tool_call in message["tool_calls"]: + new_tool_call = ChatCompletionMessageToolCall( + id=tool_call["id"], + type=tool_call["type"], + function={ + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + }, + ) + tool_calls.append(new_tool_call) + completion_response.choices[0].message.tool_calls = tool_calls + return completion_response except urllib.error.HTTPError as error: