From 1a7be38160dad0818d2341a23fa9d612430ace49 Mon Sep 17 00:00:00 2001 From: Rohit P Date: Wed, 22 Jan 2025 15:59:51 -0800 Subject: [PATCH] Removed debug prints, and other cleanup. --- aisuite/providers/aws_provider.py | 6 - aisuite/providers/fireworks_provider.py | 2 - aisuite/providers/huggingface_provider.py | 3 - examples/simple_tool_calling.ipynb | 266 ++++++++++++++++------ 4 files changed, 191 insertions(+), 86 deletions(-) diff --git a/aisuite/providers/aws_provider.py b/aisuite/providers/aws_provider.py index 01b2526..e370488 100644 --- a/aisuite/providers/aws_provider.py +++ b/aisuite/providers/aws_provider.py @@ -36,10 +36,6 @@ def convert_request( for message in messages ] - import pprint - - pprint.pprint(messages) - # Handle system message system_message = [] if messages and messages[0]["role"] == "system": @@ -68,8 +64,6 @@ def convert_request( } ) - pprint.pprint(formatted_messages) - return system_message, formatted_messages @staticmethod diff --git a/aisuite/providers/fireworks_provider.py b/aisuite/providers/fireworks_provider.py index 878eb1e..10bea19 100644 --- a/aisuite/providers/fireworks_provider.py +++ b/aisuite/providers/fireworks_provider.py @@ -113,8 +113,6 @@ def chat_completions_create(self, model, messages, **kwargs): "Content-Type": "application/json", } - print(data) - try: # Make the request to Fireworks AI endpoint. response = httpx.post( diff --git a/aisuite/providers/huggingface_provider.py b/aisuite/providers/huggingface_provider.py index ac8af9c..97b1dbd 100644 --- a/aisuite/providers/huggingface_provider.py +++ b/aisuite/providers/huggingface_provider.py @@ -64,12 +64,9 @@ def chat_completions_create(self, model, messages, **kwargs): **kwargs, # Include other parameters like temperature, max_tokens, etc. } - print(transformed_messages) - # Make the API call using the client response = self.client.chat_completion(model=model, **payload) - print(response) return self._normalize_response(response) except Exception as e: diff --git a/examples/simple_tool_calling.ipynb b/examples/simple_tool_calling.ipynb index 4bd35d0..4487072 100644 --- a/examples/simple_tool_calling.ipynb +++ b/examples/simple_tool_calling.ipynb @@ -2,13 +2,25 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import json\n", "import sys\n", "from dotenv import load_dotenv, find_dotenv\n", + "import os\n", "\n", "sys.path.append('../../aisuite')\n", "\n", @@ -17,48 +29,77 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "# Mock tool functions.\n", - "def get_current_temperature(location: str, unit: str):\n", - " # Simulate fetching temperature from an API\n", - " return {\"location\": location, \"unit\": unit, \"temperature\": 72}" + "## Make a request to model without tools" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "For model: xai:grok-2-latest\n", + "I'm here to help with your query! To provide the most accurate and up-to-date temperature for San Francisco, I would need to access real-time weather data. However, I don't have direct access to real-time data sources. \n", + "\n", + "I recommend checking a reliable weather website or app, such as AccuWeather, Weather.com, or the National Weather Service, to get the current temperature in San Francisco in Celsius. These platforms update frequently and will give you the most current information.\n" + ] + } + ], "source": [ - "def is_it_raining(location: str):\n", - " \"\"\"Check if it is raining at a location.\n", + "from aisuite import Client\n", + "\n", + "client = Client()\n", + "# Configuring Azure. Rest all providers use environment variables for their parameters.\n", + "client.configure({\"azure\" : {\n", + " \"api_key\": os.environ[\"AZURE_API_KEY\"],\n", + " \"base_url\": \"https://aisuite-mistral-large-2407.westus3.models.ai.azure.com/v1/\",\n", + "}})\n", + "# model = \"anthropic:claude-3-5-sonnet-20241022\"\n", + "# model = \"aws:mistral.mistral-7b-instruct-v0:2\"\n", + "# model = \"azure:aisuite-mistral-large\"\n", + "# model = \"cohere:command-r-plus\"\n", + "# model = \"deepseek:deepseek-chat\"\n", + "# model = \"fireworks:accounts/fireworks/models/llama-v3p1-405b-instruct\"\n", + "# model = \"google:gemini-1.5-pro-002\"\n", + "# model = \"groq:llama-3.3-70b-versatile\"\n", + "# model = \"huggingface:meta-llama/Llama-3.1-8B-Instruct\"\n", + "# model = \"mistral:mistral-large-latest\"\n", + "# model = \"nebius:\"\n", + "# model = \"ollama:\"\n", + "# model = \"sambanova:Meta-Llama-3.3-70B-Instruct\"\n", + "# model = \"together:meta-llama/Llama-3.3-70B-Instruct-Turbo\"\n", + "# model = \"watsonx:\"\n", + "model = \"xai:grok-2-latest\"\n", + "\n", + "messages = [{\n", + " \"role\": \"user\",\n", + " \"content\": \"What is the current temperature in San Francisco in Celsius?\"}]\n", "\n", - " Args:\n", - " location (str): Name of the Place.\n", + "response = client.chat.completions.create(\n", + " model=model, messages=messages)\n", "\n", - " Returns:\n", - " bool: Whether it is raining in that place.\n", - " \"\"\"\n", - " return True" + "print(\"For model: \" + model)\n", + "print(response.choices[0].message.content)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "model = \"anthropic:claude-3-5-sonnet-20240620\"\n", - "# model = \"openai:gpt-4o\"\n", - "# model = mistral:mistral-large-latest\n", - "# model = \"aws:anthropic.claude-3-haiku-20240307-v1:0\"\n", - "# model = \"aws:meta.llama3-1-8b-instruct-v1:0\"\n", - "# model = \"aws:meta.llama3-3-70b-instruct-v1:0\"\n", - "# model = \"groq:llama-3.1-70b-versatile\"" + "## Equip model with tools" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define the functions" ] }, { @@ -67,15 +108,58 @@ "metadata": {}, "outputs": [], "source": [ - "from aisuite import Client, ToolManager # Import your ToolManager class\n", - "\n", - "client = Client()\n", - "tool_manager = ToolManager([get_current_temperature, is_it_raining])\n", + "# Mock tool functions.\n", + "def get_current_temperature(location: str, unit: str):\n", + " # Simulate fetching temperature from an API\n", + " return {\"location\": location, \"unit\": unit, \"temperature\": 72}\n", "\n", - "messages = [{\"role\": \"user\", \"content\": \"What is the current temperature in San Francisco in Celsius, and is it raining?\"}]\n", + "def get_rain_probability(location: str):\n", + " # Simulate fetching rain probability\n", + " return {\"location\": location, \"probability\": 40}\n", "\n", - "response = client.chat.completions.create(\n", - " model=model, messages=messages, tools=tool_manager.tools()) # tool_choice=\"auto\", parallel_tool_calls=True)" + "# Function to get the available tools (functions) to provide to the model\n", + "# Note: we could use decorators or utils from OpenAI to generate this.\n", + "def get_available_tools():\n", + " return [\n", + " { \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_current_temperature\",\n", + " \"description\": \"Get the current temperature for a specific location\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city and state, e.g., San Francisco, CA\"\n", + " },\n", + " \"unit\": {\n", + " \"type\": \"string\",\n", + " \"enum\": [\"Celsius\", \"Fahrenheit\"],\n", + " \"description\": \"The temperature unit to use.\"\n", + " }\n", + " },\n", + " \"required\": [\"location\", \"unit\"]\n", + " }\n", + " }\n", + " },\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_rain_probability\",\n", + " \"description\": \"Get the probability of rain for a specific location\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city and state, e.g., San Francisco, CA\"\n", + " }\n", + " },\n", + " \"required\": [\"location\"]\n", + " }\n", + " }\n", + " }\n", + " ]" ] }, { @@ -84,8 +168,17 @@ "metadata": {}, "outputs": [], "source": [ - "from pprint import pprint\n", - "pprint(response.choices[0].message)" + "# Function to process tool calls and get the result\n", + "def handle_tool_call(tool_call):\n", + " function_name = tool_call.function.name\n", + " arguments = json.loads(tool_call.function.arguments)\n", + "\n", + " # Map function names to actual tool function implementations\n", + " tools_map = {\n", + " \"get_current_temperature\": get_current_temperature,\n", + " \"get_rain_probability\": get_rain_probability,\n", + " }\n", + " return tools_map[function_name](**arguments)" ] }, { @@ -94,17 +187,14 @@ "metadata": {}, "outputs": [], "source": [ - "if response.choices[0].message.tool_calls:\n", - " tool_results, result_as_message = tool_manager.execute_tool(response.choices[0].message.tool_calls)\n", - " messages.append(response.choices[0].message) # Model's function call message\n", - " messages.extend(result_as_message)\n", - "\n", - " response = client.chat.completions.create(\n", - " model=model, messages=messages, tools=tool_manager.tools())\n", - " print(response.choices[0].message.content)\n", - " pprint(response.choices[0].message)\n", - "else:\n", - " pprint(response.choices[0].message)" + "# Function to format tool response as a message\n", + "def create_tool_response_message(tool_call, tool_result):\n", + " return {\n", + " \"role\": \"tool\",\n", + " \"tool_call_id\": tool_call.id,\n", + " \"name\": tool_call.function.name,\n", + " \"content\": json.dumps(tool_result)\n", + " }" ] }, { @@ -113,29 +203,22 @@ "metadata": {}, "outputs": [], "source": [ - "if response.choices[0].message.tool_calls:\n", - " tool_results, result_as_message = tool_manager.execute_tool(response.choices[0].message.tool_calls)\n", - " messages.append(response.choices[0].message) # Model's function call message\n", - " messages.extend(result_as_message)\n", - "\n", - " response = client.chat.completions.create(\n", - " model=model, messages=messages, tools=tool_manager.tools())\n", - " print(response.choices[0].message.content)\n", - "else:\n", - " pprint(response.choices[0].message)" + "model = \"anthropic:claude-3-5-sonnet-20240620\"\n", + "model = \"huggingface:meta-llama/Llama-3.1-8B-Instruct\"\n", + "model = \"huggingface:meta-llama/Llama-3.3-70B-Instruct\"\n", + "# model = \"openai:gpt-4o\"\n", + "# model = mistral:mistral-large-latest\n", + "# model = \"aws:anthropic.claude-3-haiku-20240307-v1:0\"\n", + "# model = \"aws:meta.llama3-1-8b-instruct-v1:0\"\n", + "# model = \"aws:meta.llama3-3-70b-instruct-v1:0\"\n", + "# model = \"groq:llama-3.1-70b-versatile\"" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "# Now, test without tool calling, to check that the normal path is not broken.\n", - "messages = [{\"role\": \"user\", \"content\": \"What is the capital of California?\"}]\n", - "response = client.chat.completions.create(\n", - " model=model, messages=messages)\n", - "print(response.choices[0].message.content)" + "### Call the model with tools" ] }, { @@ -144,17 +227,37 @@ "metadata": {}, "outputs": [], "source": [ - "pprint(tool_manager.tools())" + "from aisuite import Client\n", + "\n", + "client = Client()\n", + "client.configure({\"azure\" : {\n", + " \"api_key\": os.environ[\"AZURE_API_KEY\"],\n", + " \"base_url\": \"https://aisuite-mistral-large-2407.westus3.models.ai.azure.com/v1/\",\n", + "}})\n", + "model = \"azure:aisuite-mistral-large-2407\"\n", + "model = \"fireworks:accounts/fireworks/models/llama-v3p1-405b-instruct\"\n", + "model = \"mistral:mistral-large-latest\"\n", + "model = \"together:meta-llama/Llama-3.3-70B-Instruct-Turbo\"\n", + "\n", + "messages = [{\n", + " \"role\": \"user\",\n", + " \"content\": \"What is the current temperature in San Francisco in Celsius?\"}]\n", + "\n", + "tools = get_available_tools()\n", + "\n", + "# Make the initial request to OpenAI API\n", + "response = client.chat.completions.create(\n", + " model=model, messages=messages, tools=tools)\n", + "\n", + "print(response)\n", + "print(response.choices[0].message)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "from aisuite import Client, ToolManager\n", - "ToolManager([get_current_temperature, is_it_raining]).tools()" + "### Process tool calls - Parse tool name, args, and call the function. Pass the result to the model." ] }, { @@ -162,7 +265,20 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "if response.choices[0].message.tool_calls:\n", + " for tool_call in response.choices[0].message.tool_calls:\n", + " tool_result = handle_tool_call(tool_call)\n", + "\n", + " messages.append(response.choices[0].message) # Model's function call message\n", + " messages.append(create_tool_response_message(tool_call, tool_result))\n", + " # Send the tool response back to the model\n", + " final_response = client.chat.completions.create(\n", + " model=model, messages=messages, tools=tools)\n", + "\n", + " # Output the final response from the model\n", + " print(final_response.choices[0].message.content)" + ] } ], "metadata": { @@ -181,7 +297,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.12.8" } }, "nbformat": 4,