From 44e0dd6a3cb6efb352507bdf6810d1632bae0fcc Mon Sep 17 00:00:00 2001 From: Rohit P Date: Sat, 4 Jan 2025 15:00:39 -0800 Subject: [PATCH] Improvements to notebook. --- aisuite/utils/tool_manager.py | 33 +++++++-- examples/simple_tool_calling.ipynb | 111 +++++++++++++++-------------- pyproject.toml | 1 + 3 files changed, 87 insertions(+), 58 deletions(-) diff --git a/aisuite/utils/tool_manager.py b/aisuite/utils/tool_manager.py index af88771f..69f98ca1 100644 --- a/aisuite/utils/tool_manager.py +++ b/aisuite/utils/tool_manager.py @@ -2,6 +2,7 @@ from pydantic import BaseModel, create_model, Field, ValidationError import inspect import json +from docstring_parser import parse class ToolManager: @@ -82,6 +83,24 @@ def _convert_to_tool_spec( }, } + def _extract_param_descriptions(self, func: Callable) -> dict[str, str]: + """Extract parameter descriptions from function docstring. + + Args: + func: The function to extract parameter descriptions from + + Returns: + Dictionary mapping parameter names to their descriptions + """ + docstring = inspect.getdoc(func) or "" + parsed_docstring = parse(docstring) + + param_descriptions = {} + for param in parsed_docstring.params: + param_descriptions[param.arg_name] = param.description or "" + + return param_descriptions + def _infer_from_signature( self, func: Callable ) -> tuple[Dict[str, Any], Type[BaseModel]]: @@ -90,8 +109,9 @@ def _infer_from_signature( fields = {} required_fields = [] - # Get function's docstring - docstring = inspect.getdoc(func) or " " + # Get function's docstring and parse parameter descriptions + param_descriptions = self._extract_param_descriptions(func) + docstring = inspect.getdoc(func) or "" for param_name, param in signature.parameters.items(): # Check if a type annotation is missing @@ -102,11 +122,16 @@ def _infer_from_signature( # Determine field type and optionality param_type = param.annotation + description = param_descriptions.get(param_name, "") + if param.default == inspect._empty: - fields[param_name] = (param_type, ...) + fields[param_name] = (param_type, Field(..., description=description)) required_fields.append(param_name) else: - fields[param_name] = (param_type, Field(default=param.default)) + fields[param_name] = ( + param_type, + Field(default=param.default, description=description), + ) # Dynamically create a Pydantic model based on inferred fields param_model = create_model(f"{func.__name__.capitalize()}Params", **fields) diff --git a/examples/simple_tool_calling.ipynb b/examples/simple_tool_calling.ipynb index 689b909b..4bd35d01 100644 --- a/examples/simple_tool_calling.ipynb +++ b/examples/simple_tool_calling.ipynb @@ -2,20 +2,9 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import json\n", "import sys\n", @@ -29,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -41,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -59,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -74,7 +63,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -83,62 +72,57 @@ "client = Client()\n", "tool_manager = ToolManager([get_current_temperature, is_it_raining])\n", "\n", - "messages = [{\"role\": \"user\", \"content\": \"What is the current temperature in San Francisco in Celsius?\"}]\n", + "messages = [{\"role\": \"user\", \"content\": \"What is the current temperature in San Francisco in Celsius, and is it raining?\"}]\n", "\n", "response = client.chat.completions.create(\n", - " model=model, messages=messages, tools=tool_manager.tools())" + " model=model, messages=messages, tools=tool_manager.tools()) # tool_choice=\"auto\", parallel_tool_calls=True)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "'module' object is not callable. Did you mean: 'pprint.pprint(...)'?", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpprint\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mpprint\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresponse\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mTypeError\u001b[0m: 'module' object is not callable. Did you mean: 'pprint.pprint(...)'?" - ] - } - ], + "outputs": [], "source": [ - "import pprint\n", - "pprint(response)" + "from pprint import pprint\n", + "pprint(response.choices[0].message)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Based on the function result, the current temperature in San Francisco is 72 degrees Celsius.\n", - "\n", - "However, I must point out that this temperature seems unusually high for San Francisco, especially in Celsius. A temperature of 72°C would be equivalent to about 161.6°F, which is extremely hot and not typical for San Francisco's climate. \n", - "\n", - "It's possible there might be an error in the data or in how the function is interpreting or reporting the temperature. In a normal situation, we would expect San Francisco's temperature to be much lower, typically between 10°C to 25°C (50°F to 77°F) depending on the time of year.\n", - "\n", - "If you'd like, we can double-check this information or try to get the temperature in Fahrenheit to compare. Would you like me to do that?\n" - ] - } - ], + "outputs": [], "source": [ "if response.choices[0].message.tool_calls:\n", " tool_results, result_as_message = tool_manager.execute_tool(response.choices[0].message.tool_calls)\n", " messages.append(response.choices[0].message) # Model's function call message\n", - " messages.append(result_as_message[0])\n", + " messages.extend(result_as_message)\n", "\n", - " final_response = client.chat.completions.create(\n", + " response = client.chat.completions.create(\n", " model=model, messages=messages, tools=tool_manager.tools())\n", - " print(final_response.choices[0].message.content)" + " print(response.choices[0].message.content)\n", + " pprint(response.choices[0].message)\n", + "else:\n", + " pprint(response.choices[0].message)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if response.choices[0].message.tool_calls:\n", + " tool_results, result_as_message = tool_manager.execute_tool(response.choices[0].message.tool_calls)\n", + " messages.append(response.choices[0].message) # Model's function call message\n", + " messages.extend(result_as_message)\n", + "\n", + " response = client.chat.completions.create(\n", + " model=model, messages=messages, tools=tool_manager.tools())\n", + " print(response.choices[0].message.content)\n", + "else:\n", + " pprint(response.choices[0].message)" ] }, { @@ -154,6 +138,25 @@ "print(response.choices[0].message.content)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pprint(tool_manager.tools())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from aisuite import Client, ToolManager\n", + "ToolManager([get_current_temperature, is_it_raining]).tools()" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/pyproject.toml b/pyproject.toml index 8ae9295b..2a009f74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ vertexai = { version = "^1.63.0", optional = true } groq = { version = "^0.9.0", optional = true } mistralai = { version = "^1.0.3", optional = true } openai = { version = "^1.35.8", optional = true } +docstring-parser = { version = "^0.14.0", optional = true } # Optional dependencies for different providers [tool.poetry.extras]