2
2
import base64
3
3
import json
4
4
import threading
5
- from typing import Callable , Optional
5
+ from typing import Callable , Optional , Awaitable , Union , Any
6
+ import asyncio
7
+ from concurrent .futures import ThreadPoolExecutor
6
8
7
9
from websockets .sync .client import connect
8
10
@@ -52,22 +54,133 @@ def interrupt(self):
52
54
"""
53
55
pass
54
56
57
+
58
class ClientTools:
    """Registry and executor for client-side tools that the agent can call.

    Tools are registered by name with a handler taking a single ``dict`` of
    parameters. Handlers flagged ``is_async`` are awaited on a dedicated
    event loop that runs in its own daemon thread; all other handlers are
    run on a thread pool so they cannot block that loop. Results are
    delivered through a caller-supplied callback.
    """

    def __init__(self):
        # tool name -> (handler, is_async)
        self.tools: dict[str, tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {}
        self.lock = threading.Lock()
        self._loop = None
        self._thread = None
        self._running = threading.Event()
        self.thread_pool = ThreadPoolExecutor()

    def start(self):
        """Start the event loop in a separate daemon thread.

        Does nothing when the loop is already running. Blocks until the loop
        thread has published the loop and set the running flag.
        """
        if self._running.is_set():
            return

        def run_event_loop():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            self._loop = loop
            self._running.set()
            try:
                loop.run_forever()
            finally:
                self._running.clear()
                self._loop = None
                loop.close()

        self._thread = threading.Thread(target=run_event_loop, daemon=True, name="ClientTools-EventLoop")
        self._thread.start()
        # Wait for loop to be ready
        self._running.wait()

    def stop(self):
        """Stop the event loop and release the worker threads.

        Safe to call when the loop is not running. The instance can be
        started again afterwards.
        """
        # Snapshot both attributes: the loop thread resets self._loop to None
        # on its way out, so re-reading it after the check could yield None.
        loop = self._loop
        thread = self._thread
        if loop is not None and self._running.is_set():
            loop.call_soon_threadsafe(loop.stop)
            # A tool running on the loop thread may trigger stop() itself;
            # joining the current thread would raise RuntimeError.
            if thread is not None and thread is not threading.current_thread():
                thread.join()
        self.thread_pool.shutdown(wait=False)
        # A shut-down executor rejects new work, so install a fresh one to
        # keep sync tools usable after a later start(). Worker threads are
        # only created on first submit, so an idle pool costs nothing.
        self.thread_pool = ThreadPoolExecutor()

    def register(
        self,
        tool_name: str,
        handler: Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]],
        is_async: bool = False,
    ) -> None:
        """Register a new tool that can be called by the AI agent.

        Args:
            tool_name: Unique identifier for the tool
            handler: Function that implements the tool's logic
            is_async: Whether the handler is an async function

        Raises:
            ValueError: If ``handler`` is not callable or ``tool_name`` is
                already registered.
        """
        with self.lock:
            if not callable(handler):
                raise ValueError("Handler must be callable")
            if tool_name in self.tools:
                raise ValueError(f"Tool '{tool_name}' is already registered")
            self.tools[tool_name] = (handler, is_async)

    async def handle(self, tool_name: str, parameters: dict) -> Any:
        """Execute a registered tool with the given parameters.

        Returns the result of the tool execution.

        Raises:
            ValueError: If ``tool_name`` is not registered.
        """
        with self.lock:
            if tool_name not in self.tools:
                raise ValueError(f"Tool '{tool_name}' is not registered")
            handler, is_async = self.tools[tool_name]

        if is_async:
            return await handler(parameters)
        # Inside a coroutine the running loop is the one to schedule on.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.thread_pool, handler, parameters)

    def execute_tool(self, tool_name: str, parameters: dict, callback: Callable[[dict], None]):
        """Execute a tool and send its result via the provided callback.

        The call returns immediately; the tool runs on the event loop thread
        (or the thread pool for sync handlers) and ``callback`` is invoked
        from the event loop thread with a ``client_tool_result`` dict.

        Raises:
            RuntimeError: If the event loop is not running.
        """
        loop = self._loop
        if not self._running.is_set() or loop is None:
            raise RuntimeError("ClientTools event loop is not running")

        async def _execute_and_callback():
            try:
                result = await self.handle(tool_name, parameters)
                # Only a missing result gets the default message; falsy
                # values such as 0, False or "" are real tool results.
                if result is None:
                    result = f"Client tool: {tool_name} called successfully."
                response = {
                    "type": "client_tool_result",
                    "tool_call_id": parameters.get("tool_call_id"),
                    "result": result,
                    "is_error": False,
                }
            except Exception as e:
                # Tool failures are reported back to the agent, not raised.
                response = {
                    "type": "client_tool_result",
                    "tool_call_id": parameters.get("tool_call_id"),
                    "result": str(e),
                    "is_error": True,
                }
            callback(response)

        asyncio.run_coroutine_threadsafe(_execute_and_callback(), loop)
163
+
164
+
55
165
class ConversationConfig:
    """Configuration options for the Conversation.

    Args:
        extra_body: Optional dict stored as ``extra_body``; an empty dict is
            used when omitted.
        conversation_config_override: Optional dict stored as
            ``conversation_config_override``; an empty dict is used when
            omitted.
    """

    def __init__(
        self,
        extra_body: Optional[dict] = None,
        conversation_config_override: Optional[dict] = None,
    ):
        # Compare against None instead of truthiness so an empty dict passed
        # by the caller is kept as the same object, not swapped for a new one.
        self.extra_body = {} if extra_body is None else extra_body
        self.conversation_config_override = (
            {} if conversation_config_override is None else conversation_config_override
        )
64
-
175
+
176
+
65
177
class Conversation :
66
178
client : BaseElevenLabs
67
179
agent_id : str
68
180
requires_auth : bool
69
181
config : ConversationConfig
70
182
audio_interface : AudioInterface
183
+ client_tools : Optional [ClientTools ]
71
184
callback_agent_response : Optional [Callable [[str ], None ]]
72
185
callback_agent_response_correction : Optional [Callable [[str , str ], None ]]
73
186
callback_user_transcript : Optional [Callable [[str ], None ]]
@@ -86,7 +199,7 @@ def __init__(
86
199
requires_auth : bool ,
87
200
audio_interface : AudioInterface ,
88
201
config : Optional [ConversationConfig ] = None ,
89
-
202
+ client_tools : Optional [ ClientTools ] = None ,
90
203
callback_agent_response : Optional [Callable [[str ], None ]] = None ,
91
204
callback_agent_response_correction : Optional [Callable [[str , str ], None ]] = None ,
92
205
callback_user_transcript : Optional [Callable [[str ], None ]] = None ,
@@ -101,6 +214,7 @@ def __init__(
101
214
agent_id: The ID of the agent to converse with.
102
215
requires_auth: Whether the agent requires authentication.
103
216
audio_interface: The audio interface to use for input and output.
217
+ client_tools: The client tools to use for the conversation.
104
218
callback_agent_response: Callback for agent responses.
105
219
callback_agent_response_correction: Callback for agent response corrections.
106
220
First argument is the original response (previously given to
@@ -112,14 +226,16 @@ def __init__(
112
226
self .client = client
113
227
self .agent_id = agent_id
114
228
self .requires_auth = requires_auth
115
-
116
229
self .audio_interface = audio_interface
117
230
self .callback_agent_response = callback_agent_response
118
231
self .config = config or ConversationConfig ()
232
+ self .client_tools = client_tools or ClientTools ()
119
233
self .callback_agent_response_correction = callback_agent_response_correction
120
234
self .callback_user_transcript = callback_user_transcript
121
235
self .callback_latency_measurement = callback_latency_measurement
122
236
237
+ self .client_tools .start ()
238
+
123
239
self ._thread = None
124
240
self ._should_stop = threading .Event ()
125
241
self ._conversation_id = None
@@ -135,8 +251,9 @@ def start_session(self):
135
251
self ._thread .start ()
136
252
137
253
def end_session (self ):
138
- """Ends the conversation session."""
254
+ """Ends the conversation session and cleans up resources ."""
139
255
self .audio_interface .stop ()
256
+ self .client_tools .stop ()
140
257
self ._should_stop .set ()
141
258
142
259
def wait_for_session_end (self ) -> Optional [str ]:
@@ -155,10 +272,10 @@ def _run(self, ws_url: str):
155
272
with connect (ws_url ) as ws :
156
273
ws .send (
157
274
json .dumps (
158
- {
159
- "type" : "conversation_initiation_client_data" ,
160
- "custom_llm_extra_body" : self .config .extra_body ,
161
- "conversation_config_override" : self .config .conversation_config_override ,
275
+ {
276
+ "type" : "conversation_initiation_client_data" ,
277
+ "custom_llm_extra_body" : self .config .extra_body ,
278
+ "conversation_config_override" : self .config .conversation_config_override ,
162
279
}
163
280
)
164
281
)
@@ -210,7 +327,7 @@ def _handle_message(self, message, ws):
210
327
self .callback_user_transcript (event ["user_transcript" ].strip ())
211
328
elif message ["type" ] == "interruption" :
212
329
event = message ["interruption_event" ]
213
- self .last_interrupt_id = int (event ["event_id" ])
330
+ self ._last_interrupt_id = int (event ["event_id" ])
214
331
self .audio_interface .interrupt ()
215
332
elif message ["type" ] == "ping" :
216
333
event = message ["ping_event" ]
@@ -224,6 +341,16 @@ def _handle_message(self, message, ws):
224
341
)
225
342
if self .callback_latency_measurement and event ["ping_ms" ]:
226
343
self .callback_latency_measurement (int (event ["ping_ms" ]))
344
+ elif message ["type" ] == "client_tool_call" :
345
+ tool_call = message .get ("client_tool_call" , {})
346
+ tool_name = tool_call .get ("tool_name" )
347
+ parameters = {"tool_call_id" : tool_call ["tool_call_id" ], ** tool_call .get ("parameters" , {})}
348
+
349
+ def send_response (response ):
350
+ if not self ._should_stop .is_set ():
351
+ ws .send (json .dumps (response ))
352
+
353
+ self .client_tools .execute_tool (tool_name , parameters , send_response )
227
354
else :
228
355
pass # Ignore all other message types.
229
356
0 commit comments