Skip to content

Commit 6c9832e

Browse files
committed
Changes the path for ToolManager.
1 parent 5a55a09 commit 6c9832e

File tree

3 files changed

+29
-27
lines changed

3 files changed

+29
-27
lines changed

aisuite/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .client import Client
22
from .framework.message import Message
3+
from .utils.tool_manager import ToolManager

aisuite/utils/tool_manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@
55

66

77
class ToolManager:
8-
def __init__(self):
8+
def __init__(self, tools: list[Callable] = None):
99
self._tools = {}
10+
if tools:
11+
for tool in tools:
12+
self._add_tool(tool)
1013

1114
# Add a tool function with or without a Pydantic model.
12-
def add_tool(self, func: Callable, param_model: Optional[Type[BaseModel]] = None):
15+
def _add_tool(self, func: Callable, param_model: Optional[Type[BaseModel]] = None):
1316
"""Register a tool function with metadata. If no param_model is provided, infer from function signature."""
1417
if param_model:
1518
tool_spec = self._convert_to_tool_spec(func, param_model)

examples/simple_tool_calling.ipynb

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@
3434
"outputs": [],
3535
"source": [
3636
"import aisuite as ai\n",
37-
"from aisuite.utils.tool_manager import ToolManager # Import your ToolManager class\n",
37+
"from aisuite import ToolManager # Import your ToolManager class\n",
3838
"\n",
39-
"client = ai.Client()\n",
40-
"tool_manager = ToolManager()"
39+
"client = ai.Client()"
4140
]
4241
},
4342
{
@@ -49,10 +48,7 @@
4948
"# Mock tool functions.\n",
5049
"def get_current_temperature(location: str, unit: str):\n",
5150
" # Simulate fetching temperature from an API\n",
52-
" return {\"location\": location, \"unit\": unit, \"temperature\": 72}\n",
53-
"\n",
54-
"tool_manager.add_tool(get_current_temperature)\n",
55-
"messages = [{\"role\": \"user\", \"content\": \"What is the current temperature in San Francisco in Celsius?\"}]"
51+
" return {\"location\": location, \"unit\": unit, \"temperature\": 72}"
5652
]
5753
},
5854
{
@@ -61,13 +57,18 @@
6157
"metadata": {},
6258
"outputs": [],
6359
"source": [
64-
"# model = \"anthropic:claude-3-5-sonnet-20240620\"\n",
65-
"model = \"openai:gpt-4o\"\n",
60+
"model = \"anthropic:claude-3-5-sonnet-20240620\"\n",
61+
"# model = \"openai:gpt-4o\"\n",
6662
"# model = mistral:mistral-large-latest\n",
6763
"# model = \"aws:anthropic.claude-3-haiku-20240307-v1:0\"\n",
6864
"# model = \"aws:meta.llama3-1-8b-instruct-v1:0\"\n",
6965
"# model = \"aws:meta.llama3-3-70b-instruct-v1:0\"\n",
7066
"# model = \"groq:llama-3.1-70b-versatile\"\n",
67+
"\n",
68+
"\n",
69+
"tool_manager = ToolManager([get_current_temperature])\n",
70+
"messages = [{\"role\": \"user\", \"content\": \"What is the current temperature in San Francisco in Celsius?\"}]\n",
71+
"\n",
7172
"response = client.chat.completions.create(\n",
7273
" model=model, messages=messages, tools=tool_manager.tools())"
7374
]
@@ -81,11 +82,12 @@
8182
"name": "stdout",
8283
"output_type": "stream",
8384
"text": [
84-
"{'content': None,\n",
85-
" 'function_call': None,\n",
85+
"{'content': 'To answer your question about the current temperature in San '\n",
86+
" \"Francisco in Celsius, I'll need to use the available tool to get \"\n",
87+
" 'that information. Let me fetch that for you.',\n",
8688
" 'refusal': None,\n",
8789
" 'role': 'assistant',\n",
88-
" 'tool_calls': [ChatCompletionMessageToolCall(id='call_uzJz9BtPDbcJYiKLcrmEybfE', function=Function(arguments='{\"location\":\"San Francisco\",\"unit\":\"Celsius\"}', name='get_current_temperature'), type='function')]}\n"
90+
" 'tool_calls': [ChatCompletionMessageToolCall(id='toolu_01SVEv2QJ7tsUatecDD2TEq7', function=Function(arguments='{\"location\": \"San Francisco\", \"unit\": \"Celsius\"}', name='get_current_temperature'), type='function')]}\n"
8991
]
9092
}
9193
],
@@ -96,14 +98,18 @@
9698
},
9799
{
98100
"cell_type": "code",
99-
"execution_count": 6,
101+
"execution_count": 9,
100102
"metadata": {},
101103
"outputs": [
102104
{
103105
"name": "stdout",
104106
"output_type": "stream",
105107
"text": [
106-
"The current temperature in San Francisco is 72°C. \n"
108+
"Based on the function results, the current temperature in San Francisco is 72 degrees Celsius.\n",
109+
"\n",
110+
"It's worth noting that this temperature seems unusually high for San Francisco, as 72°C is equivalent to about 161.6°F, which would be extremely hot for any city. Typically, San Francisco experiences much milder temperatures. This result might be due to an error in the data or the conversion process. \n",
111+
"\n",
112+
"To give you a more realistic perspective, San Francisco usually has moderate temperatures year-round, with average highs rarely exceeding 21°C (70°F) even in the warmest months. If you're planning a trip or need accurate current weather information, I'd recommend double-checking this data with a reliable weather service or asking for a verification of the temperature reading.\n"
107113
]
108114
}
109115
],
@@ -120,19 +126,11 @@
120126
},
121127
{
122128
"cell_type": "code",
123-
"execution_count": 7,
129+
"execution_count": null,
124130
"metadata": {},
125-
"outputs": [
126-
{
127-
"name": "stdout",
128-
"output_type": "stream",
129-
"text": [
130-
"The capital of California is Sacramento.\n"
131-
]
132-
}
133-
],
131+
"outputs": [],
134132
"source": [
135-
"# Now, test without tool calling.\n",
133+
"# Now, test without tool calling, to check that the normal path is not broken.\n",
136134
"messages = [{\"role\": \"user\", \"content\": \"What is the capital of California?\"}]\n",
137135
"response = client.chat.completions.create(\n",
138136
" model=model, messages=messages)\n",

0 commit comments

Comments
 (0)