import json
- from typing import Any, Dict, List, Optional
+ from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import ChatMessage, ToolCall
- from haystack.tools import Tool, _check_duplicate_tool_names
-
- # Compatibility with Haystack 2.12.0 and 2.13.0 - remove after 2.13.0 is released
- try:
-     from haystack.tools import deserialize_tools_or_toolset_inplace
- except ImportError:
-     from haystack.tools import deserialize_tools_inplace as deserialize_tools_or_toolset_inplace
-
- from llama_cpp import ChatCompletionResponseChoice, CreateChatCompletionResponse, Llama
+ from haystack.tools import (
+     Tool,
+     Toolset,
+     _check_duplicate_tool_names,
+     deserialize_tools_or_toolset_inplace,
+     serialize_tools_or_toolset,
+ )
+ from llama_cpp import (
+     ChatCompletionMessageToolCall,
+     ChatCompletionRequestAssistantMessage,
+     ChatCompletionRequestMessage,
+     ChatCompletionResponseChoice,
+     ChatCompletionTool,
+     CreateChatCompletionResponse,
+     Llama,
+ )
from llama_cpp.llama_tokenizer import LlamaHFTokenizer

logger = logging.getLogger(__name__)


- def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, Any]:
+ def _convert_message_to_llamacpp_format(message: ChatMessage) -> ChatCompletionRequestMessage:
    """
-     Convert a ChatMessage to the format expected by Ollama Chat API.
+     Convert a ChatMessage to the format expected by the llama.cpp Chat API.
    """
    text_contents = message.texts
    tool_calls = message.tool_calls
@@ -33,38 +40,51 @@ def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, Any]:
        raise ValueError(msg)

    role = message._role.value
-     if role == "tool":
-         role = "function"
-
-     llamacpp_msg: Dict[str, Any] = {"role": role}

-     if tool_call_results:
+     if role == "tool" and tool_call_results:
        if tool_call_results[0].origin.id is None:
            msg = "`ToolCall` must have a non-null `id` attribute to be used with llama.cpp."
            raise ValueError(msg)
-         llamacpp_msg["content"] = tool_call_results[0].result
-         llamacpp_msg["tool_call_id"] = tool_call_results[0].origin.id
-         # Llama.cpp does not provide a way to communicate errors in tool invocations, so we ignore the error field
-         return llamacpp_msg
-
-     if text_contents:
-         llamacpp_msg["content"] = text_contents[0]
-     if tool_calls:
-         llamacpp_tool_calls = []
-         for tc in tool_calls:
-             if tc.id is None:
-                 msg = "`ToolCall` must have a non-null `id` attribute to be used with llama.cpp."
-                 raise ValueError(msg)
-             llamacpp_tool_calls.append(
-                 {
-                     "id": tc.id,
-                     "type": "function",
-                     # We disable ensure_ascii so special chars like emojis are not converted
-                     "function": {"name": tc.tool_name, "arguments": json.dumps(tc.arguments, ensure_ascii=False)},
-                 }
-             )
-         llamacpp_msg["tool_calls"] = llamacpp_tool_calls
-     return llamacpp_msg
+         return {
+             "role": "function",
+             "content": tool_call_results[0].result,
+             "name": tool_call_results[0].origin.tool_name,
+         }
+
+     if role == "system":
+         content = text_contents[0] if text_contents else None
+         return {"role": "system", "content": content}
+
+     if role == "user":
+         content = text_contents[0] if text_contents else None
+         return {"role": "user", "content": content}
+
+     if role == "assistant":
+         result: ChatCompletionRequestAssistantMessage = {"role": "assistant"}
+
+         if text_contents:
+             result["content"] = text_contents[0]
+
+         if tool_calls:
+             llamacpp_tool_calls: List[ChatCompletionMessageToolCall] = []
+             for tc in tool_calls:
+                 if tc.id is None:
+                     msg = "`ToolCall` must have a non-null `id` attribute to be used with llama.cpp."
+                     raise ValueError(msg)
+                 llamacpp_tool_calls.append(
+                     {
+                         "id": tc.id,
+                         "type": "function",
+                         # We disable ensure_ascii so special chars like emojis are not converted
+                         "function": {"name": tc.tool_name, "arguments": json.dumps(tc.arguments, ensure_ascii=False)},
+                     }
+                 )
+             result["tool_calls"] = llamacpp_tool_calls
+
+         return result
+
+     error_msg = f"Unknown role: {role}"
+     raise ValueError(error_msg)


@component
@@ -94,7 +114,7 @@ def __init__(
        model_kwargs: Optional[Dict[str, Any]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        *,
-         tools: Optional[List[Tool]] = None,
+         tools: Optional[Union[List[Tool], Toolset]] = None,
    ):
        """
        :param model: The path of a quantized model for text generation, for example, "zephyr-7b-beta.Q4_0.gguf".
@@ -110,7 +130,8 @@ def __init__(
            For more information on the available kwargs, see
            [llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
        :param tools:
-             A list of tools for which the model can prepare calls.
+             A list of tools or a Toolset for which the model can prepare calls.
+             This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
        """

        model_kwargs = model_kwargs or {}
@@ -122,14 +143,14 @@ def __init__(
        model_kwargs.setdefault("n_ctx", n_ctx)
        model_kwargs.setdefault("n_batch", n_batch)

-         _check_duplicate_tool_names(tools)
+         _check_duplicate_tool_names(list(tools or []))

        self.model_path = model
        self.n_ctx = n_ctx
        self.n_batch = n_batch
        self.model_kwargs = model_kwargs
        self.generation_kwargs = generation_kwargs
-         self._model = None
+         self._model: Optional[Llama] = None
        self.tools = tools

    def warm_up(self):
@@ -147,15 +168,14 @@ def to_dict(self) -> Dict[str, Any]:
        :returns:
            Dictionary with serialized data.
        """
-         serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
        return default_to_dict(
            self,
            model=self.model_path,
            n_ctx=self.n_ctx,
            n_batch=self.n_batch,
            model_kwargs=self.model_kwargs,
            generation_kwargs=self.generation_kwargs,
-             tools=serialized_tools,
+             tools=serialize_tools_or_toolset(self.tools),
        )

    @classmethod
@@ -177,8 +197,8 @@ def run(
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        *,
-         tools: Optional[List[Tool]] = None,
-     ):
+         tools: Optional[Union[List[Tool], Toolset]] = None,
+     ) -> Dict[str, List[ChatMessage]]:
        """
        Run the text generation model on the given list of ChatMessages.

@@ -188,8 +208,8 @@ def run(
            For more information on the available kwargs, see
            [llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
        :param tools:
-             A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
-             during component initialization.
+             A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
+             parameter set during component initialization.
        :returns: A dictionary with the following keys:
            - `replies`: The responses from the model
        """
@@ -204,16 +224,33 @@ def run(
        formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]

        tools = tools or self.tools
-         llamacpp_tools = {}
+         if isinstance(tools, Toolset):
+             tools = list(tools)
+         _check_duplicate_tool_names(tools)
+
+         llamacpp_tools: List[ChatCompletionTool] = []
        if tools:
-             tool_definitions = [{"type": "function", "function": {**t.tool_spec}} for t in tools]
-             llamacpp_tools = {"tools": tool_definitions}
+             for t in tools:
+                 llamacpp_tools.append(
+                     {
+                         "type": "function",
+                         "function": {
+                             "name": t.tool_spec["name"],
+                             "description": t.tool_spec.get("description", ""),
+                             "parameters": t.tool_spec.get("parameters", {}),
+                         },
+                     }
+                 )

        response = self._model.create_chat_completion(
-             messages=formatted_messages, **updated_generation_kwargs, **llamacpp_tools
+             messages=formatted_messages, tools=llamacpp_tools, **updated_generation_kwargs
        )

        replies = []
+         if not isinstance(response, dict):
+             msg = f"Expected a dictionary response, got a different object: {response}"
+             raise ValueError(msg)
+
        for choice in response["choices"]:
            chat_message = self._convert_chat_completion_choice_to_chat_message(choice, response)
            replies.append(chat_message)
@@ -239,10 +276,10 @@ def _convert_chat_completion_choice_to_chat_message(
                except json.JSONDecodeError:
                    logger.warning(
                        "Llama.cpp returned a malformed JSON string for tool call arguments. This tool call "
-                         "will be skipped. Tool call ID: %s, Tool name: %s, Arguments: %s",
-                         llamacpp_tc["id"],
-                         llamacpp_tc["function"]["name"],
-                         arguments_str,
+                         "will be skipped. Tool call ID: {tc_id}, Tool name: {tc_name}, Arguments: {tc_args}",
+                         tc_id=llamacpp_tc["id"],
+                         tc_name=llamacpp_tc["function"]["name"],
+                         tc_args=arguments_str,
                    )

        meta = {
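
For reference, a minimal usage sketch of the generator with the new tools support follows. It is not part of the diff: the import path is the one commonly used by the llama-cpp-haystack integration, and the GGUF file name and the get_weather tool are illustrative assumptions.

# Usage sketch (hedged): model path and weather tool are illustrative assumptions.
from haystack.dataclasses import ChatMessage
from haystack.tools import Tool

from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator


def get_weather(city: str) -> str:
    # Toy tool function, used only so the model has something to call.
    return f"Sunny, 22°C in {city}"


weather_tool = Tool(
    name="get_weather",
    description="Get the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    function=get_weather,
)

generator = LlamaCppChatGenerator(
    model="model.gguf",  # assumed path to a local quantized model
    n_ctx=8192,
    tools=[weather_tool],  # a Toolset instance is also accepted after this change
)
generator.warm_up()

result = generator.run(messages=[ChatMessage.from_user("What is the weather in Rome?")])
reply = result["replies"][0]
# The reply carries either plain text or prepared tool calls.
print(reply.tool_calls or reply.text)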