Skip to content

Commit 57f85b6

Browse files
committed
feat: add support for client-tools (Python SDK)
1 parent 97d77ed commit 57f85b6

File tree

2 files changed

+216
-10
lines changed

2 files changed

+216
-10
lines changed

src/elevenlabs/conversational_ai/conversation.py

+137-10
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import base64
33
import json
44
import threading
5-
from typing import Callable, Optional
5+
from typing import Callable, Optional, Awaitable, Union, Any
6+
import asyncio
7+
from concurrent.futures import ThreadPoolExecutor
68

79
from websockets.sync.client import connect
810

@@ -52,22 +54,133 @@ def interrupt(self):
5254
"""
5355
pass
5456

57+
58+
class ClientTools:
    """Registry and executor for client-side tools that the agent can call.

    A tool is either a plain callable (executed on a thread pool) or a
    coroutine function (awaited directly). Every execution is scheduled on a
    private asyncio event loop that runs in its own daemon thread, so the
    thread calling ``execute_tool`` is never blocked by a slow tool.

    The instance can be restarted: ``start()`` after ``stop()`` spins up a new
    event loop and a new thread pool.
    """

    def __init__(self):
        # tool name -> (handler, is_async)
        self.tools: dict[str, tuple[Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]], bool]] = {}
        self.lock = threading.Lock()
        self._loop = None
        self._thread = None
        self._running = threading.Event()
        self.thread_pool = ThreadPoolExecutor()
        # Set by stop(); tells start() that the pool above can no longer be used.
        self._pool_closed = False

    def start(self):
        """Start the event loop in a separate thread for handling async operations.

        Does nothing if the loop is already running. Blocks until the loop
        thread has published its loop and marked itself as running.
        """
        if self._running.is_set():
            return

        # A pool that was shut down by stop() rejects new work, so sync tools
        # would fail after a restart unless a fresh pool is created here.
        if self._pool_closed:
            self.thread_pool = ThreadPoolExecutor()
            self._pool_closed = False

        def run_event_loop():
            loop = asyncio.new_event_loop()
            self._loop = loop
            asyncio.set_event_loop(loop)
            self._running.set()
            try:
                loop.run_forever()
            finally:
                self._running.clear()
                loop.close()
                self._loop = None

        self._thread = threading.Thread(target=run_event_loop, daemon=True, name="ClientTools-EventLoop")
        self._thread.start()
        # Wait for loop to be ready
        self._running.wait()

    def stop(self):
        """Stop the event loop thread and shut down the thread pool.

        The pool is shut down without waiting for sync tools that are still
        executing. Safe to call when the loop was never started.
        """
        loop = self._loop
        if loop and self._running.is_set():
            loop.call_soon_threadsafe(loop.stop)
            self._thread.join()
        self.thread_pool.shutdown(wait=False)
        self._pool_closed = True

    def register(
        self,
        tool_name: str,
        handler: Union[Callable[[dict], Any], Callable[[dict], Awaitable[Any]]],
        is_async: bool = False,
    ) -> None:
        """Register a new tool that can be called by the AI agent.

        Args:
            tool_name: Unique identifier for the tool.
            handler: Function that implements the tool's logic. It receives the
                parameters dict and returns the tool result.
            is_async: Whether the handler is an async function.

        Raises:
            ValueError: If ``handler`` is not callable or ``tool_name`` is
                already registered.
        """
        if not callable(handler):
            raise ValueError("Handler must be callable")
        with self.lock:
            if tool_name in self.tools:
                raise ValueError(f"Tool '{tool_name}' is already registered")
            self.tools[tool_name] = (handler, is_async)

    async def handle(self, tool_name: str, parameters: dict) -> Any:
        """Execute a registered tool with the given parameters.

        Async tools are awaited on the current loop; sync tools run on the
        thread pool so they cannot stall the loop.

        Returns:
            The result of the tool execution.

        Raises:
            ValueError: If ``tool_name`` is not registered.
        """
        with self.lock:
            if tool_name not in self.tools:
                raise ValueError(f"Tool '{tool_name}' is not registered")
            handler, is_async = self.tools[tool_name]

        if is_async:
            return await handler(parameters)

        result = await asyncio.get_running_loop().run_in_executor(self.thread_pool, handler, parameters)
        # A coroutine function registered without is_async=True hands back an
        # un-awaited coroutine object; await it instead of returning it as the
        # tool result.
        if asyncio.iscoroutine(result):
            result = await result
        return result

    def execute_tool(self, tool_name: str, parameters: dict, callback: Callable[[dict], None]):
        """Execute a tool and send its result via the provided callback.

        This method is non-blocking and handles both sync and async tools. The
        callback is invoked on the event loop thread with a
        ``client_tool_result`` dict; tool failures are reported through that
        dict (``is_error`` set to True) rather than raised.

        Args:
            tool_name: Name the tool was registered under.
            parameters: Parameters passed to the handler; ``tool_call_id`` is
                read from here and echoed back in the response.
            callback: Called with the response dict once the tool finishes.

        Returns:
            The ``concurrent.futures.Future`` of the scheduled execution. It
            can be used to wait for completion or to observe an exception
            raised by ``callback``.

        Raises:
            RuntimeError: If the event loop is not running.
        """
        if not self._running.is_set():
            raise RuntimeError("ClientTools event loop is not running")

        async def _execute_and_callback():
            try:
                result = await self.handle(tool_name, parameters)
                # Only a missing result (None) gets the default message; falsy
                # values such as 0, False or [] are real results.
                if result is None:
                    result = f"Client tool: {tool_name} called successfully."
                response = {
                    "type": "client_tool_result",
                    "tool_call_id": parameters.get("tool_call_id"),
                    "result": result,
                    "is_error": False,
                }
            except Exception as e:
                response = {
                    "type": "client_tool_result",
                    "tool_call_id": parameters.get("tool_call_id"),
                    "result": str(e),
                    "is_error": True,
                }
            callback(response)

        return asyncio.run_coroutine_threadsafe(_execute_and_callback(), self._loop)
163+
164+
55165
class ConversationConfig:
    """Optional settings sent to the server when a conversation is initiated.

    Attributes:
        extra_body: Sent as ``custom_llm_extra_body`` in the initiation message.
        conversation_config_override: Sent as ``conversation_config_override``
            in the initiation message.
    """

    def __init__(
        self,
        extra_body: Optional[dict] = None,
        conversation_config_override: Optional[dict] = None,
    ):
        # A falsy argument (None or an empty dict) is replaced by a fresh dict,
        # so instances never share a default.
        if extra_body:
            self.extra_body = extra_body
        else:
            self.extra_body = {}

        if conversation_config_override:
            self.conversation_config_override = conversation_config_override
        else:
            self.conversation_config_override = {}
64-
175+
176+
65177
class Conversation:
66178
client: BaseElevenLabs
67179
agent_id: str
68180
requires_auth: bool
69181
config: ConversationConfig
70182
audio_interface: AudioInterface
183+
client_tools: Optional[ClientTools]
71184
callback_agent_response: Optional[Callable[[str], None]]
72185
callback_agent_response_correction: Optional[Callable[[str, str], None]]
73186
callback_user_transcript: Optional[Callable[[str], None]]
@@ -86,7 +199,7 @@ def __init__(
86199
requires_auth: bool,
87200
audio_interface: AudioInterface,
88201
config: Optional[ConversationConfig] = None,
89-
202+
client_tools: Optional[ClientTools] = None,
90203
callback_agent_response: Optional[Callable[[str], None]] = None,
91204
callback_agent_response_correction: Optional[Callable[[str, str], None]] = None,
92205
callback_user_transcript: Optional[Callable[[str], None]] = None,
@@ -101,6 +214,7 @@ def __init__(
101214
agent_id: The ID of the agent to converse with.
102215
requires_auth: Whether the agent requires authentication.
103216
audio_interface: The audio interface to use for input and output.
217+
client_tools: The client tools to use for the conversation.
104218
callback_agent_response: Callback for agent responses.
105219
callback_agent_response_correction: Callback for agent response corrections.
106220
First argument is the original response (previously given to
@@ -112,14 +226,16 @@ def __init__(
112226
self.client = client
113227
self.agent_id = agent_id
114228
self.requires_auth = requires_auth
115-
116229
self.audio_interface = audio_interface
117230
self.callback_agent_response = callback_agent_response
118231
self.config = config or ConversationConfig()
232+
self.client_tools = client_tools or ClientTools()
119233
self.callback_agent_response_correction = callback_agent_response_correction
120234
self.callback_user_transcript = callback_user_transcript
121235
self.callback_latency_measurement = callback_latency_measurement
122236

237+
self.client_tools.start()
238+
123239
self._thread = None
124240
self._should_stop = threading.Event()
125241
self._conversation_id = None
@@ -135,8 +251,9 @@ def start_session(self):
135251
self._thread.start()
136252

137253
def end_session(self):
138-
"""Ends the conversation session."""
254+
"""Ends the conversation session and cleans up resources."""
139255
self.audio_interface.stop()
256+
self.client_tools.stop()
140257
self._should_stop.set()
141258

142259
def wait_for_session_end(self) -> Optional[str]:
@@ -155,10 +272,10 @@ def _run(self, ws_url: str):
155272
with connect(ws_url) as ws:
156273
ws.send(
157274
json.dumps(
158-
{
159-
"type": "conversation_initiation_client_data",
160-
"custom_llm_extra_body": self.config.extra_body,
161-
"conversation_config_override": self.config.conversation_config_override,
275+
{
276+
"type": "conversation_initiation_client_data",
277+
"custom_llm_extra_body": self.config.extra_body,
278+
"conversation_config_override": self.config.conversation_config_override,
162279
}
163280
)
164281
)
@@ -210,7 +327,7 @@ def _handle_message(self, message, ws):
210327
self.callback_user_transcript(event["user_transcript"].strip())
211328
elif message["type"] == "interruption":
212329
event = message["interruption_event"]
213-
self.last_interrupt_id = int(event["event_id"])
330+
self._last_interrupt_id = int(event["event_id"])
214331
self.audio_interface.interrupt()
215332
elif message["type"] == "ping":
216333
event = message["ping_event"]
@@ -224,6 +341,16 @@ def _handle_message(self, message, ws):
224341
)
225342
if self.callback_latency_measurement and event["ping_ms"]:
226343
self.callback_latency_measurement(int(event["ping_ms"]))
344+
elif message["type"] == "client_tool_call":
345+
tool_call = message.get("client_tool_call", {})
346+
tool_name = tool_call.get("tool_name")
347+
parameters = {"tool_call_id": tool_call["tool_call_id"], **tool_call.get("parameters", {})}
348+
349+
def send_response(response):
350+
if not self._should_stop.is_set():
351+
ws.send(json.dumps(response))
352+
353+
self.client_tools.execute_tool(tool_name, parameters, send_response)
227354
else:
228355
pass # Ignore all other message types.
229356

tests/e2e_test_convai.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import os
2+
import time
3+
import asyncio
4+
5+
import pytest
6+
from elevenlabs import ElevenLabs
7+
from elevenlabs.conversational_ai.conversation import Conversation, ClientTools
8+
from elevenlabs.conversational_ai.default_audio_interface import DefaultAudioInterface
9+
10+
11+
@pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip live conversation test in CI environment")
def test_live_conversation():
    """Run a live conversation with actual audio I/O and two client tools.

    Requires the ELEVENLABS_API_KEY and AGENT_ID environment variables. One
    sync and one async tool are registered, the session is kept open for
    100 seconds, and the conversation ID is printed afterwards.

    Raises:
        ValueError: If either required environment variable is missing.
    """
    api_key = os.getenv("ELEVENLABS_API_KEY")
    if not api_key:
        raise ValueError("ELEVENLABS_API_KEY environment variable missing.")

    agent_id = os.getenv("AGENT_ID")
    if not agent_id:
        raise ValueError("AGENT_ID environment variable missing.")

    client = ElevenLabs(api_key=api_key)

    # Create conversation handlers
    def on_agent_response(text: str):
        print(f"Agent: {text}")

    def on_user_transcript(text: str):
        print(f"You: {text}")

    def on_latency(ms: int):
        print(f"Latency: {ms}ms")

    # Initialize client tools
    client_tools = ClientTools()

    def test(parameters):
        print("Sync tool called with parameters:", parameters)
        return "Tool called successfully"

    async def test_async(parameters):
        # Simulate some async work
        await asyncio.sleep(10)
        print("Async tool called with parameters:", parameters)
        return "Tool called successfully"

    client_tools.register("test", test)
    client_tools.register("test_async", test_async, is_async=True)

    # Initialize conversation
    conversation = Conversation(
        client=client,
        agent_id=agent_id,
        requires_auth=False,
        audio_interface=DefaultAudioInterface(),
        callback_agent_response=on_agent_response,
        callback_user_transcript=on_user_transcript,
        callback_latency_measurement=on_latency,
        client_tools=client_tools,
    )

    # Start the conversation
    conversation.start_session()

    try:
        # Let it run for 100 seconds
        time.sleep(100)
    finally:
        # End the conversation even if the wait is interrupted (e.g. Ctrl-C),
        # so the session is not left running.
        conversation.end_session()
        conversation.wait_for_session_end()

    # Get the conversation ID for reference
    conversation_id = conversation._conversation_id
    print(f"Conversation ID: {conversation_id}")


if __name__ == "__main__":
    test_live_conversation()

0 commit comments

Comments
 (0)