
Commit 188b2a7

feat: support for tools in OpenAIChatGenerator (#8666)
* move chatmsg>openai conversion to chatmsg dataclass
* implementation and tests cleanup
* release note
* try fixing azure chat generator
* add serde test for toolinvoker
* small fix
1 parent 7dcbf25 commit 188b2a7

File tree

17 files changed: +720 −305 lines

haystack/components/generators/chat/azure.py

Lines changed: 5 additions & 0 deletions
@@ -142,6 +142,11 @@ def __init__( # pylint: disable=too-many-positional-arguments
         self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
         self.default_headers = default_headers or {}

+        # This ChatGenerator does not yet support tools. The following workaround ensures that we do not
+        # get an error when invoking the run method of the parent class (OpenAIChatGenerator).
+        self.tools = None
+        self.tools_strict = False
+
         self.client = AzureOpenAI(
             api_version=api_version,
             azure_endpoint=azure_endpoint,
haystack/components/generators/chat/hugging_face_api.py

Lines changed: 6 additions & 8 deletions
@@ -163,10 +163,9 @@ def __init__( # pylint: disable=too-many-positional-arguments
             msg = f"Unknown api_type {api_type}"
             raise ValueError(msg)

-        if tools:
-            if streaming_callback is not None:
-                raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-            _check_duplicate_tool_names(tools)
+        if tools and streaming_callback is not None:
+            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+        _check_duplicate_tool_names(tools)

         # handle generation kwargs setup
         generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
@@ -241,10 +240,9 @@ def run(
         formatted_messages = [convert_message_to_hf_format(message) for message in messages]

         tools = tools or self.tools
-        if tools:
-            if self.streaming_callback:
-                raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
-            _check_duplicate_tool_names(tools)
+        if tools and self.streaming_callback:
+            raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
+        _check_duplicate_tool_names(tools)

         if self.streaming_callback:
             return self._run_streaming(formatted_messages, generation_kwargs)
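
A hedged sketch of the behavior these guard clauses enforce: passing both tools and a streaming callback raises a ValueError. The constructor arguments shown (api_type, api_params) and the Tool fields are assumptions about the surrounding Haystack API, not part of this hunk.

import pytest

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses.tool import Tool

# Hypothetical tool; the Tool constructor fields are assumed from haystack/dataclasses/tool.py.
echo_tool = Tool(
    name="echo",
    description="Echo the input back.",
    parameters={"type": "object", "properties": {"text": {"type": "string"}}},
    function=lambda text: text,
)


def test_tools_and_streaming_are_mutually_exclusive():
    # Assumed constructor arguments for the serverless Inference API.
    with pytest.raises(ValueError):
        HuggingFaceAPIChatGenerator(
            api_type="serverless_inference_api",
            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
            tools=[echo_tool],
            streaming_callback=print,
        )
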

haystack/components/generators/chat/openai.py

Lines changed: 209 additions & 141 deletions
Large diffs are not rendered by default.
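
Since the diff for this file is collapsed, here is a usage-level sketch of the feature it introduces. It assumes the new tools constructor parameter on OpenAIChatGenerator (consistent with the Azure workaround above) and the Tool dataclass from haystack/dataclasses/tool.py; treat the exact argument names as assumptions rather than the committed API, and note that OPENAI_API_KEY must be set.

from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
from haystack.dataclasses.tool import Tool


def get_weather(city: str) -> str:
    """Toy function backing the tool; for illustration only."""
    return f"The weather in {city} is sunny."


# Assumed Tool fields (name, description, parameters, function), matching haystack/dataclasses/tool.py.
weather_tool = Tool(
    name="get_weather",
    description="Get the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    function=get_weather,
)

# `tools` is the parameter this commit adds (assumption: keyword name and list form).
generator = OpenAIChatGenerator(model="gpt-4o-mini", tools=[weather_tool])
result = generator.run(messages=[ChatMessage.from_user("What is the weather in Berlin?")])

reply = result["replies"][0]
# When the model decides to call the tool, the reply carries ToolCall objects instead of plain text.
print(reply.tool_calls)
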

haystack/components/generators/openai.py

Lines changed: 1 addition & 2 deletions
@@ -9,7 +9,6 @@
 from openai.types.chat import ChatCompletion, ChatCompletionChunk

 from haystack import component, default_from_dict, default_to_dict, logging
-from haystack.components.generators.openai_utils import _convert_message_to_openai_format
 from haystack.dataclasses import ChatMessage, StreamingChunk
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable

@@ -207,7 +206,7 @@ def run(
         streaming_callback = streaming_callback or self.streaming_callback

         # adapt ChatMessage(s) to the format expected by the OpenAI API
-        openai_formatted_messages = [_convert_message_to_openai_format(message) for message in messages]
+        openai_formatted_messages = [message.to_openai_dict_format() for message in messages]

         completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
             model=self.model,

haystack/components/generators/openai_utils.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

haystack/dataclasses/chat_message.py

Lines changed: 45 additions & 0 deletions
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+import json
 import warnings
 from dataclasses import asdict, dataclass, field
 from enum import Enum
@@ -381,3 +382,47 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage":
         data["_content"] = content

         return cls(**data)
+
+    def to_openai_dict_format(self) -> Dict[str, Any]:
+        """
+        Convert a ChatMessage to the dictionary format expected by OpenAI's Chat API.
+        """
+        text_contents = self.texts
+        tool_calls = self.tool_calls
+        tool_call_results = self.tool_call_results
+
+        if not text_contents and not tool_calls and not tool_call_results:
+            raise ValueError(
+                "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
+            )
+        if len(text_contents) + len(tool_call_results) > 1:
+            raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")
+
+        openai_msg: Dict[str, Any] = {"role": self._role.value}
+
+        if tool_call_results:
+            result = tool_call_results[0]
+            if result.origin.id is None:
+                raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.")
+            openai_msg["content"] = result.result
+            openai_msg["tool_call_id"] = result.origin.id
+            # OpenAI does not provide a way to communicate errors in tool invocations, so we ignore the error field
+            return openai_msg
+
+        if text_contents:
+            openai_msg["content"] = text_contents[0]
+        if tool_calls:
+            openai_tool_calls = []
+            for tc in tool_calls:
+                if tc.id is None:
+                    raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.")
+                openai_tool_calls.append(
+                    {
+                        "id": tc.id,
+                        "type": "function",
+                        # We disable ensure_ascii so special chars like emojis are not converted
+                        "function": {"name": tc.tool_name, "arguments": json.dumps(tc.arguments, ensure_ascii=False)},
+                    }
+                )
+            openai_msg["tool_calls"] = openai_tool_calls
+        return openai_msg
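
A short sketch of how the new method behaves for tool-related messages. The ToolCall constructor and the from_assistant/from_tool factory methods are assumptions inferred from the attributes used above (tc.id, tc.tool_name, tc.arguments, result.origin, result.result); they are not part of this diff.

from haystack.dataclasses import ChatMessage
from haystack.dataclasses.chat_message import ToolCall

# Assumed ToolCall fields, inferred from the attribute accesses in to_openai_dict_format.
tool_call = ToolCall(id="call_123", tool_name="get_weather", arguments={"city": "Berlin"})

# Assistant message that requests a tool call (factory name assumed).
assistant_msg = ChatMessage.from_assistant(tool_calls=[tool_call])
print(assistant_msg.to_openai_dict_format())
# {'role': 'assistant', 'tool_calls': [{'id': 'call_123', 'type': 'function',
#   'function': {'name': 'get_weather', 'arguments': '{"city": "Berlin"}'}}]}

# Tool message carrying the result of that call (factory name assumed).
# A missing tool_call.id would raise the ValueError added above.
tool_msg = ChatMessage.from_tool(tool_result="Sunny, 22°C", origin=tool_call)
print(tool_msg.to_openai_dict_format())
# {'role': 'tool', 'content': 'Sunny, 22°C', 'tool_call_id': 'call_123'}
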

haystack/dataclasses/tool.py

Lines changed: 4 additions & 2 deletions
@@ -216,13 +216,15 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
                 del property_schema[key]


-def _check_duplicate_tool_names(tools: List[Tool]) -> None:
+def _check_duplicate_tool_names(tools: Optional[List[Tool]]) -> None:
     """
-    Check for duplicate tool names and raises a ValueError if they are found.
+    Checks for duplicate tool names and raises a ValueError if they are found.

     :param tools: The list of tools to check.
     :raises ValueError: If duplicate tool names are found.
     """
+    if tools is None:
+        return
     tool_names = [tool.name for tool in tools]
     duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
     if duplicate_tool_names:
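
A small sketch of the new behavior: None is now accepted as a no-op, while duplicate names still raise. The Tool constructor arguments shown are assumptions about its fields, not part of this hunk.

from haystack.dataclasses.tool import Tool, _check_duplicate_tool_names

# None is now tolerated (previously this would fail when iterating over the list).
_check_duplicate_tool_names(None)


def noop(**kwargs):
    return None


# Assumed Tool fields; two tools sharing a name should trigger the ValueError.
duplicated = [
    Tool(name="search", description="demo", parameters={"type": "object", "properties": {}}, function=noop),
    Tool(name="search", description="demo", parameters={"type": "object", "properties": {}}, function=noop),
]
try:
    _check_duplicate_tool_names(duplicated)
except ValueError as err:
    print(err)  # the message mentions the duplicated name "search"
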
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+---
+features:
+  - |
+    Add support for Tools in the OpenAI Chat Generator.

test/components/generators/chat/conftest.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

test/components/generators/chat/test_hugging_face_local.py

Lines changed: 8 additions & 0 deletions
@@ -17,6 +17,14 @@ def streaming_callback_handler(x):
     return x


+@pytest.fixture
+def chat_messages():
+    return [
+        ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"),
+        ChatMessage.from_user("Tell me about Berlin"),
+    ]
+
+
 @pytest.fixture
 def model_info_mock():
     with patch(

0 commit comments