Commit bf79f04

feat: support streaming_callback as run param for HF Chat generators (#8763)
* feat: support streaming_callback as run param for HF Chat generators
* add tests
Parent: c3d0643

File tree: 5 files changed, +133 -7 lines changed

- haystack/components/generators/chat/hugging_face_api.py
- haystack/components/generators/chat/hugging_face_local.py
- [new release note YAML; path not shown in this capture]
- test/components/generators/chat/test_hugging_face_api.py
- test/components/generators/chat/test_hugging_face_local.py
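To orient the reader, here is a minimal usage sketch of what this commit enables: passing the streaming callback at call time rather than at construction time. The model name and prompt are illustrative; import paths and the `StreamingChunk.content` attribute follow Haystack 2.x as of this commit, and a valid HF API token is assumed to be available in the environment.

```python
from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils.hf import HFGenerationAPIType

def print_chunk(chunk: StreamingChunk):
    # StreamingChunk carries the newly generated text in `content`.
    print(chunk.content, end="", flush=True)

# Note: no streaming_callback at init time.
generator = HuggingFaceAPIChatGenerator(
    api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
    api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},  # illustrative model
)

# New in this commit: the callback can be supplied per run() call.
result = generator.run(
    [ChatMessage.from_user("Explain token streaming in one sentence.")],
    streaming_callback=print_chunk,
)
print(result["replies"][0].text)
```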

Diff for: haystack/components/generators/chat/hugging_face_api.py

+14 -4

@@ -220,6 +220,7 @@ def run(
         messages: List[ChatMessage],
         generation_kwargs: Optional[Dict[str, Any]] = None,
         tools: Optional[List[Tool]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
     ):
         """
         Invoke the text generation inference based on the provided messages and generation parameters.
@@ -231,6 +232,9 @@ def run(
         :param tools:
             A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
             during component initialization.
+        :param streaming_callback:
+            An optional callable for handling streaming responses. If set, it will override the `streaming_callback`
+            parameter set during component initialization.
         :returns: A dictionary with the following keys:
             - `replies`: A list containing the generated responses as ChatMessage objects.
         """
@@ -245,16 +249,22 @@ def run(
             raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
         _check_duplicate_tool_names(tools)

-        if self.streaming_callback:
-            return self._run_streaming(formatted_messages, generation_kwargs)
+        streaming_callback = streaming_callback or self.streaming_callback
+        if streaming_callback:
+            return self._run_streaming(formatted_messages, generation_kwargs, streaming_callback)

         hf_tools = None
         if tools:
             hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools]

         return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools)

-    def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]):
+    def _run_streaming(
+        self,
+        messages: List[Dict[str, str]],
+        generation_kwargs: Dict[str, Any],
+        streaming_callback: Callable[[StreamingChunk], None],
+    ):
         api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
             messages, stream=True, **generation_kwargs
         )
@@ -282,7 +292,7 @@ def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict
                 first_chunk_time = datetime.now().isoformat()

             stream_chunk = StreamingChunk(text, meta)
-            self.streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
+            streaming_callback(stream_chunk)

         meta.update(
             {
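The heart of this change is the resolution line `streaming_callback = streaming_callback or self.streaming_callback`: a callback passed to `run()` takes precedence, and the one set at init time serves only as a fallback. A self-contained sketch of that precedence rule follows; the class and names are illustrative, not Haystack APIs.

```python
from typing import Callable, Optional

Callback = Callable[[str], None]

class StreamingComponent:
    """Toy component mirroring the callback-resolution pattern from the diff."""

    def __init__(self, streaming_callback: Optional[Callback] = None):
        self.streaming_callback = streaming_callback

    def resolve_callback(self, streaming_callback: Optional[Callback] = None) -> Optional[Callback]:
        # Run-time argument wins; the init-time value is only the fallback.
        # (Safe with `or` because callables are always truthy.)
        return streaming_callback or self.streaming_callback

init_cb: Callback = lambda text: print("init:", text)
run_cb: Callback = lambda text: print("run:", text)

assert StreamingComponent(init_cb).resolve_callback() is init_cb       # fallback to init-time
assert StreamingComponent().resolve_callback(run_cb) is run_cb         # run-time only
assert StreamingComponent(init_cb).resolve_callback(run_cb) is run_cb  # run-time wins
assert StreamingComponent().resolve_callback() is None                 # streaming disabled
```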

Diff for: haystack/components/generators/chat/hugging_face_local.py

+10 -3

@@ -233,12 +233,18 @@ def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceLocalChatGenerator":
         return default_from_dict(cls, data)

     @component.output_types(replies=List[ChatMessage])
-    def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None):
+    def run(
+        self,
+        messages: List[ChatMessage],
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+    ):
         """
         Invoke text generation inference based on the provided messages and generation parameters.

         :param messages: A list of ChatMessage objects representing the input messages.
         :param generation_kwargs: Additional keyword arguments for text generation.
+        :param streaming_callback: An optional callable for handling streaming responses.
         :returns:
             A list containing the generated responses as ChatMessage instances.
         """
@@ -259,7 +265,8 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
         if stop_words_criteria:
             generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stop_words_criteria])

-        if self.streaming_callback:
+        streaming_callback = streaming_callback or self.streaming_callback
+        if streaming_callback:
             num_responses = generation_kwargs.get("num_return_sequences", 1)
             if num_responses > 1:
                 msg = (
@@ -270,7 +277,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
             logger.warning(msg, num_responses=num_responses)
             generation_kwargs["num_return_sequences"] = 1
         # streamer parameter hooks into HF streaming, HFTokenStreamingHandler is an adapter to our streaming
-        generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, self.streaming_callback, stop_words)
+        generation_kwargs["streamer"] = HFTokenStreamingHandler(tokenizer, streaming_callback, stop_words)

         hf_messages = [convert_message_to_hf_format(message) for message in messages]
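The same run-time override applies to the local generator; here is a minimal sketch. The model name is illustrative, loading it requires `transformers` plus the model weights, and `warm_up()` must run before `run()`.

```python
from haystack.components.generators.chat import HuggingFaceLocalChatGenerator
from haystack.dataclasses import ChatMessage, StreamingChunk

collected: list = []

def collect_chunk(chunk: StreamingChunk):
    # Accumulate the streamed fragments instead of printing them.
    collected.append(chunk.content)

generator = HuggingFaceLocalChatGenerator(model="HuggingFaceH4/zephyr-7b-beta")  # illustrative
generator.warm_up()  # builds the local text-generation pipeline

# New in this commit: per-call streaming for the local generator too.
result = generator.run(
    messages=[ChatMessage.from_user("Summarize token streaming in one line.")],
    streaming_callback=collect_chunk,
)
print("".join(collected))        # the text as it was streamed
print(result["replies"][0].text) # the final reply
```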

Diff for: [new release note YAML; path not shown in this capture]

+4

@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Streaming callback run param support for HF chat generators.

Diff for: test/components/generators/chat/test_hugging_face_api.py

+64

@@ -395,6 +395,70 @@ def mock_iter(self):
         assert len(response["replies"]) > 0
         assert [isinstance(reply, ChatMessage) for reply in response["replies"]]

+    def test_run_with_streaming_callback_in_run_method(
+        self, mock_check_valid_model, mock_chat_completion, chat_messages
+    ):
+        streaming_call_count = 0
+
+        # Define the streaming callback function
+        def streaming_callback_fn(chunk: StreamingChunk):
+            nonlocal streaming_call_count
+            streaming_call_count += 1
+            assert isinstance(chunk, StreamingChunk)
+
+        generator = HuggingFaceAPIChatGenerator(
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "meta-llama/Llama-2-13b-chat-hf"},
+        )
+
+        # Create a fake streamed response
+        # self needed here, don't remove
+        def mock_iter(self):
+            yield ChatCompletionStreamOutput(
+                choices=[
+                    ChatCompletionStreamOutputChoice(
+                        delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"),
+                        index=0,
+                        finish_reason=None,
+                    )
+                ],
+                id="some_id",
+                model="some_model",
+                system_fingerprint="some_fingerprint",
+                created=1710498504,
+            )
+
+            yield ChatCompletionStreamOutput(
+                choices=[
+                    ChatCompletionStreamOutputChoice(
+                        delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
+                    )
+                ],
+                id="some_id",
+                model="some_model",
+                system_fingerprint="some_fingerprint",
+                created=1710498504,
+            )
+
+        mock_response = Mock(**{"__iter__": mock_iter})
+        mock_chat_completion.return_value = mock_response
+
+        # Generate text response with streaming callback
+        response = generator.run(chat_messages, streaming_callback=streaming_callback_fn)
+
+        # Check kwargs passed to chat_completion
+        _, kwargs = mock_chat_completion.call_args
+        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
+
+        # Assert that the streaming callback was called twice
+        assert streaming_call_count == 2
+
+        # Assert that the response contains the generated replies
+        assert "replies" in response
+        assert isinstance(response["replies"], list)
+        assert len(response["replies"]) > 0
+        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])
+
     def test_run_fail_with_tools_and_streaming(self, tools, mock_check_valid_model):
         component = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
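A side note on the test technique above: `Mock(**{"__iter__": mock_iter})` is how you make a plain Mock iterable. `unittest.mock` attaches magic methods to the mock's per-instance class, because Python looks dunder methods up on the type rather than the instance; that is also why `mock_iter` needs an explicit `self` parameter (it is invoked as a bound method). A self-contained sketch of the same pattern:

```python
from unittest.mock import Mock

# Attached to the mock's type, so it is called as a bound method:
# the mock instance arrives as `self`.
def fake_stream(self):
    yield "first chunk"
    yield "second chunk"

stream = Mock(**{"__iter__": fake_stream})
assert list(stream) == ["first chunk", "second chunk"]
```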

Diff for: test/components/generators/chat/test_hugging_face_local.py

+41

@@ -3,6 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0
 from unittest.mock import Mock, patch

+from haystack.dataclasses.streaming_chunk import StreamingChunk
 import pytest
 from transformers import PreTrainedTokenizer

@@ -233,6 +234,46 @@ def test_run_with_custom_generation_parameters(self, model_info_mock, mock_pipel
         assert chat_message.is_from(ChatRole.ASSISTANT)
         assert chat_message.text == "Berlin is cool"

+    def test_run_with_streaming_callback(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
+        # Define the streaming callback function
+        def streaming_callback_fn(chunk: StreamingChunk): ...
+
+        generator = HuggingFaceLocalChatGenerator(
+            model="meta-llama/Llama-2-13b-chat-hf", streaming_callback=streaming_callback_fn
+        )
+
+        # Use the mocked pipeline from the fixture and simulate warm_up
+        generator.pipeline = mock_pipeline_tokenizer
+
+        results = generator.run(messages=chat_messages)
+
+        assert "replies" in results
+        assert isinstance(results["replies"][0], ChatMessage)
+        chat_message = results["replies"][0]
+        assert chat_message.is_from(ChatRole.ASSISTANT)
+        assert chat_message.text == "Berlin is cool"
+        generator.pipeline.assert_called_once()
+        assert generator.pipeline.call_args[1]["streamer"].token_handler == streaming_callback_fn
+
+    def test_run_with_streaming_callback_in_run_method(self, model_info_mock, mock_pipeline_tokenizer, chat_messages):
+        # Define the streaming callback function
+        def streaming_callback_fn(chunk: StreamingChunk): ...
+
+        generator = HuggingFaceLocalChatGenerator(model="meta-llama/Llama-2-13b-chat-hf")
+
+        # Use the mocked pipeline from the fixture and simulate warm_up
+        generator.pipeline = mock_pipeline_tokenizer
+
+        results = generator.run(messages=chat_messages, streaming_callback=streaming_callback_fn)
+
+        assert "replies" in results
+        assert isinstance(results["replies"][0], ChatMessage)
+        chat_message = results["replies"][0]
+        assert chat_message.is_from(ChatRole.ASSISTANT)
+        assert chat_message.text == "Berlin is cool"
+        generator.pipeline.assert_called_once()
+        assert generator.pipeline.call_args[1]["streamer"].token_handler == streaming_callback_fn
+
     @patch("haystack.components.generators.chat.hugging_face_local.convert_message_to_hf_format")
     def test_messages_conversion_is_called(self, mock_convert, model_info_mock):
         generator = HuggingFaceLocalChatGenerator(model="fake-model")
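For readers unfamiliar with the `streamer` hook asserted on above: `transformers` generation accepts a streamer object, and Haystack's `HFTokenStreamingHandler` adapts it to the `StreamingChunk` callback. Below is a simplified sketch of such an adapter built on transformers' `TextStreamer`; it illustrates the mechanism and is not the actual Haystack class.

```python
from typing import Callable
from transformers import TextStreamer

class CallbackTokenStreamer(TextStreamer):
    """Forwards each decoded text fragment to a callback instead of stdout (sketch)."""

    def __init__(self, tokenizer, token_handler: Callable[[str], None]):
        super().__init__(tokenizer, skip_prompt=True)
        self.token_handler = token_handler

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # TextStreamer invokes this hook with each decoded span of text.
        self.token_handler(text)

# Usage sketch, mirroring generation_kwargs["streamer"] in the diff above:
# tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
# generation_kwargs = {"streamer": CallbackTokenStreamer(tokenizer, print)}
```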
