Commit 9a15467

fix: fix llama.cpp types; add py.typed; Toolset support (#1973)
* fix: fix llama.cpp types; add py.typed; Toolset support
* missing check

1 parent 6fe3cf7 · commit 9a15467

File tree: 6 files changed (+158 −99 lines)

.github/workflows/llama_cpp.yml (1 addition, 3 deletions)

```diff
@@ -50,11 +50,9 @@ jobs:
       - name: Install Hatch
         run: pip install --upgrade hatch
 
-      # TODO: Once this integration is properly typed, use hatch run test:types
-      # https://github.com/deepset-ai/haystack-core-integrations/issues/1771
       - name: Lint
         if: matrix.python-version == '3.9' && runner.os == 'Linux'
-        run: hatch run fmt-check && hatch run lint:typing
+        run: hatch run fmt-check && hatch run test:types
 
       - name: Generate docs
         if: matrix.python-version == '3.9' && runner.os == 'Linux'
```

integrations/llama_cpp/pyproject.toml (7 additions, 16 deletions)

```diff
@@ -26,7 +26,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai>=2.9.0", "llama-cpp-python>=0.2.87"]
+dependencies = ["haystack-ai>=2.13.0", "llama-cpp-python>=0.2.87"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/llama_cpp#readme"
@@ -68,18 +68,13 @@ unit = 'pytest -m "not integration" {args:tests}'
 integration = 'pytest -m "integration" {args:tests}'
 all = 'pytest {args:tests}'
 cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
-types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
+types = "mypy -p haystack_integrations.components.generators.llama_cpp {args}"
 
-# TODO: remove lint environment once this integration is properly typed
-# test environment should be used instead
-# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
-[tool.hatch.envs.lint]
-installer = "uv"
-detached = true
-dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
-
-[tool.hatch.envs.lint.scripts]
-typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
+[tool.mypy]
+install_types = true
+non_interactive = true
+check_untyped_defs = true
+disallow_incomplete_defs = true
 
 [tool.hatch.metadata]
 allow-direct-references = true
@@ -169,7 +164,3 @@ markers = [
   "integration: marks tests as slow (deselect with '-m \"not integration\"')",
 ]
 addopts = ["--import-mode=importlib"]
-
-[[tool.mypy.overrides]]
-module = ["haystack.*", "haystack_integrations.*", "pytest.*", "llama_cpp.*"]
-ignore_missing_imports = true
```
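The old configuration silenced mypy with a blanket `ignore_missing_imports` override; the new `[tool.mypy]` block instead enables `check_untyped_defs` and `disallow_incomplete_defs`, and the `types` script (now also wired into the Lint CI step above) type-checks the installed package. As a rough illustration, a hypothetical snippet (not from this repo) showing what the two new flags catch:

```python
# Hypothetical examples, not part of this commit: what the stricter mypy
# settings report when running `hatch run test:types`.


def render(template, name: str) -> str:  # disallow_incomplete_defs: `template`
    return template.format(name=name)    # lacks an annotation on an otherwise
                                         # annotated function


def shout(text):
    # check_untyped_defs: bodies of fully unannotated functions are now
    # checked too, so this mismatch is reported instead of being skipped.
    return len(text) + "!"  # error: unsupported operand types ("int" and "str")
```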

integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py (95 additions, 58 deletions)

```diff
@@ -1,25 +1,32 @@
 import json
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
 from haystack.dataclasses import ChatMessage, ToolCall
-from haystack.tools import Tool, _check_duplicate_tool_names
-
-# Compatibility with Haystack 2.12.0 and 2.13.0 - remove after 2.13.0 is released
-try:
-    from haystack.tools import deserialize_tools_or_toolset_inplace
-except ImportError:
-    from haystack.tools import deserialize_tools_inplace as deserialize_tools_or_toolset_inplace
-
-from llama_cpp import ChatCompletionResponseChoice, CreateChatCompletionResponse, Llama
+from haystack.tools import (
+    Tool,
+    Toolset,
+    _check_duplicate_tool_names,
+    deserialize_tools_or_toolset_inplace,
+    serialize_tools_or_toolset,
+)
+from llama_cpp import (
+    ChatCompletionMessageToolCall,
+    ChatCompletionRequestAssistantMessage,
+    ChatCompletionRequestMessage,
+    ChatCompletionResponseChoice,
+    ChatCompletionTool,
+    CreateChatCompletionResponse,
+    Llama,
+)
 from llama_cpp.llama_tokenizer import LlamaHFTokenizer
 
 logger = logging.getLogger(__name__)
 
 
-def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, Any]:
+def _convert_message_to_llamacpp_format(message: ChatMessage) -> ChatCompletionRequestMessage:
     """
-    Convert a ChatMessage to the format expected by Ollama Chat API.
+    Convert a ChatMessage to the format expected by llama.cpp Chat API.
     """
     text_contents = message.texts
     tool_calls = message.tool_calls
@@ -33,38 +40,51 @@ def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, Any]:
         raise ValueError(msg)
 
     role = message._role.value
-    if role == "tool":
-        role = "function"
-
-    llamacpp_msg: Dict[str, Any] = {"role": role}
 
-    if tool_call_results:
+    if role == "tool" and tool_call_results:
         if tool_call_results[0].origin.id is None:
             msg = "`ToolCall` must have a non-null `id` attribute to be used with llama.cpp."
             raise ValueError(msg)
-        llamacpp_msg["content"] = tool_call_results[0].result
-        llamacpp_msg["tool_call_id"] = tool_call_results[0].origin.id
-        # Llama.cpp does not provide a way to communicate errors in tool invocations, so we ignore the error field
-        return llamacpp_msg
-
-    if text_contents:
-        llamacpp_msg["content"] = text_contents[0]
-    if tool_calls:
-        llamacpp_tool_calls = []
-        for tc in tool_calls:
-            if tc.id is None:
-                msg = "`ToolCall` must have a non-null `id` attribute to be used with llama.cpp."
-                raise ValueError(msg)
-            llamacpp_tool_calls.append(
-                {
-                    "id": tc.id,
-                    "type": "function",
-                    # We disable ensure_ascii so special chars like emojis are not converted
-                    "function": {"name": tc.tool_name, "arguments": json.dumps(tc.arguments, ensure_ascii=False)},
-                }
-            )
-        llamacpp_msg["tool_calls"] = llamacpp_tool_calls
-    return llamacpp_msg
+        return {
+            "role": "function",
+            "content": tool_call_results[0].result,
+            "name": tool_call_results[0].origin.tool_name,
+        }
+
+    if role == "system":
+        content = text_contents[0] if text_contents else None
+        return {"role": "system", "content": content}
+
+    if role == "user":
+        content = text_contents[0] if text_contents else None
+        return {"role": "user", "content": content}
+
+    if role == "assistant":
+        result: ChatCompletionRequestAssistantMessage = {"role": "assistant"}
+
+        if text_contents:
+            result["content"] = text_contents[0]
+
+        if tool_calls:
+            llamacpp_tool_calls: List[ChatCompletionMessageToolCall] = []
+            for tc in tool_calls:
+                if tc.id is None:
+                    msg = "`ToolCall` must have a non-null `id` attribute to be used with llama.cpp."
+                    raise ValueError(msg)
+                llamacpp_tool_calls.append(
+                    {
+                        "id": tc.id,
+                        "type": "function",
+                        # We disable ensure_ascii so special chars like emojis are not converted
+                        "function": {"name": tc.tool_name, "arguments": json.dumps(tc.arguments, ensure_ascii=False)},
+                    }
+                )
+            result["tool_calls"] = llamacpp_tool_calls
+
+        return result
+
+    error_msg = f"Unknown role: {role}"
+    raise ValueError(error_msg)
 
 
 @component
```
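The rewritten converter returns llama.cpp's per-role TypedDicts instead of a loose `Dict[str, Any]`. A small sketch of the mapping it performs for an assistant message carrying a tool call (illustrative IDs and values; `ChatMessage` and `ToolCall` are the Haystack dataclasses imported above):

```python
from haystack.dataclasses import ChatMessage, ToolCall

# An assistant message that carries a tool call instead of text...
msg = ChatMessage.from_assistant(
    tool_calls=[ToolCall(id="call_1", tool_name="weather", arguments={"city": "Rome"})]
)

# ..._convert_message_to_llamacpp_format(msg) now produces a typed
# ChatCompletionRequestAssistantMessage:
# {
#     "role": "assistant",
#     "tool_calls": [
#         {
#             "id": "call_1",
#             "type": "function",
#             "function": {"name": "weather", "arguments": '{"city": "Rome"}'},
#         }
#     ],
# }
```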
```diff
@@ -94,7 +114,7 @@ def __init__(
         model_kwargs: Optional[Dict[str, Any]] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
         *,
-        tools: Optional[List[Tool]] = None,
+        tools: Optional[Union[List[Tool], Toolset]] = None,
     ):
         """
         :param model: The path of a quantized model for text generation, for example, "zephyr-7b-beta.Q4_0.gguf".
@@ -110,7 +130,8 @@ def __init__(
             For more information on the available kwargs, see
             [llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
         :param tools:
-            A list of tools for which the model can prepare calls.
+            A list of tools or a Toolset for which the model can prepare calls.
+            This parameter can accept either a list of `Tool` objects or a `Toolset` instance.
         """
 
         model_kwargs = model_kwargs or {}
@@ -122,14 +143,14 @@ def __init__(
         model_kwargs.setdefault("n_ctx", n_ctx)
         model_kwargs.setdefault("n_batch", n_batch)
 
-        _check_duplicate_tool_names(tools)
+        _check_duplicate_tool_names(list(tools or []))
 
         self.model_path = model
         self.n_ctx = n_ctx
         self.n_batch = n_batch
         self.model_kwargs = model_kwargs
         self.generation_kwargs = generation_kwargs
-        self._model = None
+        self._model: Optional[Llama] = None
         self.tools = tools
 
     def warm_up(self):
@@ -147,15 +168,14 @@ def to_dict(self) -> Dict[str, Any]:
         :returns:
             Dictionary with serialized data.
         """
-        serialized_tools = [tool.to_dict() for tool in self.tools] if self.tools else None
         return default_to_dict(
             self,
             model=self.model_path,
             n_ctx=self.n_ctx,
             n_batch=self.n_batch,
             model_kwargs=self.model_kwargs,
             generation_kwargs=self.generation_kwargs,
-            tools=serialized_tools,
+            tools=serialize_tools_or_toolset(self.tools),
         )
 
     @classmethod
@@ -177,8 +197,8 @@ def run(
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
         *,
-        tools: Optional[List[Tool]] = None,
-    ):
+        tools: Optional[Union[List[Tool], Toolset]] = None,
+    ) -> Dict[str, List[ChatMessage]]:
         """
         Run the text generation model on the given list of ChatMessages.
 
@@ -188,8 +208,8 @@ def run(
            For more information on the available kwargs, see
            [llama.cpp documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_chat_completion).
         :param tools:
-            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
-            during component initialization.
+            A list of tools or a Toolset for which the model can prepare calls. If set, it will override the `tools`
+            parameter set during component initialization.
         :returns: A dictionary with the following keys:
             - `replies`: The responses from the model
         """
@@ -204,16 +224,33 @@ def run(
         formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]
 
         tools = tools or self.tools
-        llamacpp_tools = {}
+        if isinstance(tools, Toolset):
+            tools = list(tools)
+        _check_duplicate_tool_names(tools)
+
+        llamacpp_tools: List[ChatCompletionTool] = []
         if tools:
-            tool_definitions = [{"type": "function", "function": {**t.tool_spec}} for t in tools]
-            llamacpp_tools = {"tools": tool_definitions}
+            for t in tools:
+                llamacpp_tools.append(
+                    {
+                        "type": "function",
+                        "function": {
+                            "name": t.tool_spec["name"],
+                            "description": t.tool_spec.get("description", ""),
+                            "parameters": t.tool_spec.get("parameters", {}),
+                        },
+                    }
+                )
 
         response = self._model.create_chat_completion(
-            messages=formatted_messages, **updated_generation_kwargs, **llamacpp_tools
+            messages=formatted_messages, tools=llamacpp_tools, **updated_generation_kwargs
         )
 
         replies = []
+        if not isinstance(response, dict):
+            msg = f"Expected a dictionary response, got a different object: {response}"
+            raise ValueError(msg)
+
         for choice in response["choices"]:
             chat_message = self._convert_chat_completion_choice_to_chat_message(choice, response)
             replies.append(chat_message)
@@ -239,10 +276,10 @@ def _convert_chat_completion_choice_to_chat_message(
         except json.JSONDecodeError:
             logger.warning(
                 "Llama.cpp returned a malformed JSON string for tool call arguments. This tool call "
-                "will be skipped. Tool call ID: %s, Tool name: %s, Arguments: %s",
-                llamacpp_tc["id"],
-                llamacpp_tc["function"]["name"],
-                arguments_str,
+                "will be skipped. Tool call ID: {tc_id}, Tool name: {tc_name}, Arguments: {tc_args}",
+                tc_id=llamacpp_tc["id"],
+                tc_name=llamacpp_tc["function"]["name"],
+                tc_args=arguments_str,
             )
 
         meta = {
```
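Together these changes let a `Toolset` be passed anywhere a `List[Tool]` was accepted before, both at construction and per `run()` call. A minimal usage sketch (the model path and the weather tool are illustrative, not from this commit):

```python
from haystack.dataclasses import ChatMessage
from haystack.tools import Tool, Toolset
from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator


def get_weather(city: str) -> str:
    """Toy tool function returning canned weather data."""
    return f"Sunny in {city}"


weather_tool = Tool(
    name="weather",
    description="Get the current weather for a city",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
    function=get_weather,
)

# A Toolset is now accepted directly, both in __init__ and in run().
generator = LlamaCppChatGenerator(model="model.gguf", tools=Toolset([weather_tool]))
generator.warm_up()

result = generator.run(messages=[ChatMessage.from_user("What's the weather in Rome?")])
print(result["replies"][0])  # a ChatMessage, possibly carrying a ToolCall
```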

integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/generator.py (10 additions, 4 deletions)

```diff
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from haystack import component, logging
 
@@ -62,14 +62,16 @@ def __init__(
         self.n_batch = n_batch
         self.model_kwargs = model_kwargs
         self.generation_kwargs = generation_kwargs
-        self.model = None
+        self.model: Optional[Llama] = None
 
     def warm_up(self):
         if self.model is None:
             self.model = Llama(**self.model_kwargs)
 
     @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
-    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Union[List[str], List[Dict[str, Any]]]]:
         """
         Run the text generation model on the given prompt.
 
@@ -92,6 +94,10 @@ def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
 
         output = self.model.create_completion(prompt=prompt, **updated_generation_kwargs)
+        if not isinstance(output, dict):
+            msg = f"Expected a dictionary response, got a different object: {output}"
+            raise ValueError(msg)
+
         replies = [output["choices"][0]["text"]]
 
-        return {"replies": replies, "meta": [output]}
+        return {"replies": replies, "meta": [dict(output.items())]}
```

integrations/llama_cpp/src/haystack_integrations/components/generators/py.typed

New, empty marker file (whitespace-only changes).
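`py.typed` is the PEP 561 marker: its presence tells type checkers that this package portion ships inline annotations, so downstream projects see real types instead of `Any`. A sketch of the effect (the `reveal_type` call is evaluated by mypy, not at runtime):

```python
from haystack_integrations.components.generators.llama_cpp import LlamaCppGenerator

generator = LlamaCppGenerator(model="model.gguf")
result = generator.run(prompt="hello")

# Before this commit mypy treated the package as untyped; with py.typed it
# resolves roughly to Dict[str, Union[List[str], List[Dict[str, Any]]]]
reveal_type(result)
```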
