Commit f4d9c2b

fix: Make the HuggingFaceLocalChatGenerator compatible with the new ChatMessage; serialize chat_template (#8663)
* message conversion function
* hfapi w tools
* right test file + hf_hub version
* release note
* fix for new chatmessage; serialize chat_template
* feedback
1 parent 2bc58d2 commit f4d9c2b
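
At its core, the fix converts the new ChatMessage objects into the plain role/content dictionaries that tokenizer.apply_chat_template expects before building the prompt, and includes chat_template in the component's serialized form. A minimal sketch of the conversion step, assuming convert_message_to_hf_format returns OpenAI-style role/content dictionaries:

from haystack.dataclasses import ChatMessage
from haystack.utils.hf import convert_message_to_hf_format

messages = [
    ChatMessage.from_system("You are a helpful assistant."),
    ChatMessage.from_user("Hello"),
]

# The generator now performs this conversion internally before applying the chat template.
hf_messages = [convert_message_to_hf_format(m) for m in messages]
# Assumed result: [{"role": "system", "content": "..."}, {"role": "user", "content": "Hello"}]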

3 files changed: +49 -1 lines changed

haystack/components/generators/chat/hugging_face_local.py

+5 -1
@@ -25,6 +25,7 @@
 from haystack.utils.hf import (  # pylint: disable=ungrouped-imports
     HFTokenStreamingHandler,
     StopWordsCriteria,
+    convert_message_to_hf_format,
     deserialize_hf_model_kwargs,
     serialize_hf_model_kwargs,
 )
@@ -201,6 +202,7 @@ def to_dict(self) -> Dict[str, Any]:
             generation_kwargs=self.generation_kwargs,
             streaming_callback=callback_name,
             token=self.token.to_dict() if self.token else None,
+            chat_template=self.chat_template,
         )

         huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
@@ -270,9 +272,11 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
             # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
             generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)

+        hf_messages = [convert_message_to_hf_format(message) for message in messages]
+
         # Prepare the prompt for the model
         prepared_prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
+            hf_messages, tokenize=False, chat_template=self.chat_template, add_generation_prompt=True
         )

         # Avoid some unnecessary warnings in the generation pipeline call
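
Because chat_template is now included in to_dict, a serialization round-trip preserves a custom template. A minimal sketch of that round-trip, assuming the standard to_dict/from_dict pattern exercised in the tests below (the template string is a placeholder; instantiating the component may contact the Hugging Face Hub to resolve the model's task):

from haystack.components.generators.chat import HuggingFaceLocalChatGenerator

generator = HuggingFaceLocalChatGenerator(
    model="NousResearch/Llama-2-7b-chat-hf",
    chat_template="{% for m in messages %}{{ m['content'] }}{% endfor %}",  # placeholder Jinja template
)

data = generator.to_dict()
assert data["init_parameters"]["chat_template"] == generator.chat_template

restored = HuggingFaceLocalChatGenerator.from_dict(data)
assert restored.chat_template == generator.chat_template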
@@ -0,0 +1,7 @@
+---
+fixes:
+  - |
+    Make the HuggingFaceLocalChatGenerator compatible with the new ChatMessage format, by converting the messages to
+    the format expected by Hugging Face.
+
+    Serialize the chat_template parameter.

test/components/generators/chat/test_hugging_face_local.py

+37
@@ -135,6 +135,7 @@ def test_to_dict(self, model_info_mock):
             generation_kwargs={"n": 5},
             stop_words=["stop", "words"],
             streaming_callback=lambda x: x,
+            chat_template="irrelevant",
         )

         # Call the to_dict method
@@ -146,13 +147,15 @@ def test_to_dict(self, model_info_mock):
         assert init_params["huggingface_pipeline_kwargs"]["model"] == "NousResearch/Llama-2-7b-chat-hf"
         assert "token" not in init_params["huggingface_pipeline_kwargs"]
         assert init_params["generation_kwargs"] == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
+        assert init_params["chat_template"] == "irrelevant"

     def test_from_dict(self, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(
             model="NousResearch/Llama-2-7b-chat-hf",
             generation_kwargs={"n": 5},
             stop_words=["stop", "words"],
             streaming_callback=streaming_callback_handler,
+            chat_template="irrelevant",
         )
         # Call the to_dict method
         result = generator.to_dict()
@@ -162,6 +165,7 @@ def test_from_dict(self, model_info_mock):
         assert generator_2.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
         assert generator_2.generation_kwargs == {"max_new_tokens": 512, "n": 5, "stop_sequences": ["stop", "words"]}
         assert generator_2.streaming_callback is streaming_callback_handler
+        assert generator_2.chat_template == "irrelevant"

     @patch("haystack.components.generators.chat.hugging_face_local.pipeline")
     def test_warm_up(self, pipeline_mock, monkeypatch):
@@ -218,3 +222,36 @@ def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipel
         chat_message = results["replies"][0]
         assert chat_message.is_from(ChatRole.ASSISTANT)
         assert chat_message.text == "Berlin is cool"
+
+    @patch("haystack.components.generators.chat.hugging_face_local.convert_message_to_hf_format")
+    def test_messages_conversion_is_called(self, mock_convert, model_info_mock):
+        generator = HuggingFaceLocalChatGenerator(model="fake-model")
+
+        messages = [ChatMessage.from_user("Hello"), ChatMessage.from_assistant("Hi there")]
+
+        with patch.object(generator, "pipeline") as mock_pipeline:
+            mock_pipeline.tokenizer.apply_chat_template.return_value = "test prompt"
+            mock_pipeline.return_value = [{"generated_text": "test response"}]
+
+            generator.warm_up()
+            generator.run(messages)
+
+        assert mock_convert.call_count == 2
+        mock_convert.assert_any_call(messages[0])
+        mock_convert.assert_any_call(messages[1])
+
+    @pytest.mark.integration
+    @pytest.mark.flaky(reruns=3, reruns_delay=10)
+    def test_live_run(self):
+        messages = [ChatMessage.from_user("Please create a summary about the following topic: Climate change")]
+
+        llm = HuggingFaceLocalChatGenerator(
+            model="Qwen/Qwen2.5-0.5B-Instruct", generation_kwargs={"max_new_tokens": 50}
+        )
+        llm.warm_up()
+
+        result = llm.run(messages)
+
+        assert "replies" in result
+        assert isinstance(result["replies"][0], ChatMessage)
+        assert "climate change" in result["replies"][0].text.lower()
