
Commit af07385

sjrl, anakin87, and julian-risch authored
feat: Add usage when using HuggingFaceAPIChatGenerator with streaming (#9371)
* Small fix and update tests
* Add usage support to streaming for HuggingFaceAPIChatGenerator
* Add reno
* try using provider='auto'
* Undo provider
* Fix unit tests
* Update releasenotes/notes/add-usage-hf-api-chat-streaming-91fd04705f45d5b3.yaml

Co-authored-by: Julian Risch <[email protected]>

---------

Co-authored-by: anakin87 <[email protected]>
Co-authored-by: Julian Risch <[email protected]>
1 parent 9ae76e1 commit af07385

File tree: 3 files changed, +84 -28 lines changed

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 57 additions & 23 deletions
@@ -27,6 +27,7 @@
 from huggingface_hub import (
     AsyncInferenceClient,
     ChatCompletionInputFunctionDefinition,
+    ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
     ChatCompletionOutput,
     ChatCompletionOutputToolCall,
@@ -396,37 +397,52 @@ def _run_streaming(
         self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
     ):
         api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
-            messages, stream=True, **generation_kwargs
+            messages,
+            stream=True,
+            stream_options=ChatCompletionInputStreamOptions(include_usage=True),
+            **generation_kwargs,
         )
 
         generated_text = ""
         first_chunk_time = None
+        finish_reason = None
+        usage = None
         meta: Dict[str, Any] = {}
 
         for chunk in api_output:
-            # n is unused, so the API always returns only one choice
-            # the argument is probably allowed for compatibility with OpenAI
-            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
-            choice = chunk.choices[0]
+            # The chunk with usage returns an empty array for choices
+            if len(chunk.choices) > 0:
+                # n is unused, so the API always returns only one choice
+                # the argument is probably allowed for compatibility with OpenAI
+                # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+                choice = chunk.choices[0]
 
-            text = choice.delta.content or ""
-            generated_text += text
+                text = choice.delta.content or ""
+                generated_text += text
 
-            finish_reason = choice.finish_reason
-            if finish_reason:
-                meta["finish_reason"] = finish_reason
+                if choice.finish_reason:
+                    finish_reason = choice.finish_reason
+
+                stream_chunk = StreamingChunk(text, meta)
+                streaming_callback(stream_chunk)
+
+            if chunk.usage:
+                usage = chunk.usage
 
             if first_chunk_time is None:
                 first_chunk_time = datetime.now().isoformat()
 
-            stream_chunk = StreamingChunk(text, meta)
-            streaming_callback(stream_chunk)
+        if usage:
+            usage_dict = {"prompt_tokens": usage.prompt_tokens, "completion_tokens": usage.completion_tokens}
+        else:
+            usage_dict = {"prompt_tokens": 0, "completion_tokens": 0}
 
         meta.update(
             {
                 "model": self._client.model,
                 "index": 0,
-                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
+                "finish_reason": finish_reason,
+                "usage": usage_dict,
                 "completion_start_time": first_chunk_time,
             }
         )
@@ -477,34 +493,52 @@ async def _run_streaming_async(
         self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
     ):
         api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
-            messages, stream=True, **generation_kwargs
+            messages,
+            stream=True,
+            stream_options=ChatCompletionInputStreamOptions(include_usage=True),
+            **generation_kwargs,
        )
 
         generated_text = ""
         first_chunk_time = None
+        finish_reason = None
+        usage = None
         meta: Dict[str, Any] = {}
 
         async for chunk in api_output:
-            choice = chunk.choices[0]
+            # The chunk with usage returns an empty array for choices
+            if len(chunk.choices) > 0:
+                # n is unused, so the API always returns only one choice
+                # the argument is probably allowed for compatibility with OpenAI
+                # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+                choice = chunk.choices[0]
 
-            text = choice.delta.content or ""
-            generated_text += text
+                text = choice.delta.content or ""
+                generated_text += text
 
-            finish_reason = choice.finish_reason
-            if finish_reason:
-                meta["finish_reason"] = finish_reason
+                if choice.finish_reason:
+                    finish_reason = choice.finish_reason
+
+                stream_chunk = StreamingChunk(text, meta)
+                await streaming_callback(stream_chunk)  # type: ignore
+
+            if chunk.usage:
+                usage = chunk.usage
 
             if first_chunk_time is None:
                 first_chunk_time = datetime.now().isoformat()
 
-            stream_chunk = StreamingChunk(text, meta)
-            await streaming_callback(stream_chunk)  # type: ignore
+        if usage:
+            usage_dict = {"prompt_tokens": usage.prompt_tokens, "completion_tokens": usage.completion_tokens}
+        else:
+            usage_dict = {"prompt_tokens": 0, "completion_tokens": 0}
 
         meta.update(
             {
                 "model": self._async_client.model,
                 "index": 0,
-                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
+                "finish_reason": finish_reason,
+                "usage": usage_dict,
                 "completion_start_time": first_chunk_time,
             }
         )
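
For context, here is a minimal sketch of the streaming pattern this change builds on, calling huggingface_hub directly rather than through Haystack; the model name, prompt, and token limit are illustrative, not taken from this commit:

from huggingface_hub import ChatCompletionInputStreamOptions, InferenceClient

# Assumes a valid HF token in the environment; model choice is illustrative.
client = InferenceClient(model="microsoft/Phi-3.5-mini-instruct")

chunks = client.chat_completion(
    [{"role": "user", "content": "Say hello."}],
    stream=True,
    stream_options=ChatCompletionInputStreamOptions(include_usage=True),
    max_tokens=32,
)

generated_text = ""
usage = None
for chunk in chunks:
    # Content chunks carry one choice each; the final usage chunk carries none.
    if chunk.choices:
        generated_text += chunk.choices[0].delta.content or ""
    if chunk.usage:
        usage = chunk.usage

print(generated_text)
if usage:
    print(usage.prompt_tokens, usage.completion_tokens)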
releasenotes/notes/add-usage-hf-api-chat-streaming-91fd04705f45d5b3.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+---
+enhancements:
+  - |
+    When using HuggingFaceAPIChatGenerator with streaming, the returned ChatMessage now contains the number of prompt tokens and completion tokens in its meta data.
+    Internally, the HuggingFaceAPIChatGenerator requests an additional streaming chunk that contains usage data.
+    It then processes the usage streaming chunk to add usage meta data to the returned ChatMessage.
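
From the caller's side, a hedged sketch of how the new usage metadata surfaces after streaming; the api_type and api_params values below are illustrative and assume an HF token is configured in the environment:

from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack.dataclasses import ChatMessage

generator = HuggingFaceAPIChatGenerator(
    api_type="serverless_inference_api",  # illustrative deployment choice
    api_params={"model": "microsoft/Phi-3.5-mini-instruct"},
    streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
)

result = generator.run(messages=[ChatMessage.from_user("Hello!")])

# With this commit, usage holds real token counts instead of zeros.
print(result["replies"][0].meta["usage"])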

test/components/generators/chat/test_hugging_face_api.py

Lines changed: 21 additions & 5 deletions
@@ -22,6 +22,7 @@
     ChatCompletionStreamOutput,
     ChatCompletionStreamOutputChoice,
     ChatCompletionStreamOutputDelta,
+    ChatCompletionInputStreamOptions,
 )
 from huggingface_hub.errors import RepositoryNotFoundError
 
@@ -441,7 +442,12 @@ def mock_iter(self):
 
         # check kwargs passed to text_generation
         _, kwargs = mock_chat_completion.call_args
-        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
+        assert kwargs == {
+            "stop": [],
+            "stream": True,
+            "max_tokens": 512,
+            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
+        }
 
         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2
@@ -505,7 +511,12 @@ def mock_iter(self):
 
         # check kwargs passed to text_generation
         _, kwargs = mock_chat_completion.call_args
-        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
+        assert kwargs == {
+            "stop": [],
+            "stream": True,
+            "max_tokens": 512,
+            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
+        }
 
         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2
@@ -717,9 +728,9 @@ def test_live_run_serverless_streaming(self):
         assert datetime.fromisoformat(response_meta["completion_start_time"]) <= datetime.now()
         assert "usage" in response_meta
         assert "prompt_tokens" in response_meta["usage"]
-        assert response_meta["usage"]["prompt_tokens"] == 0
+        assert response_meta["usage"]["prompt_tokens"] > 0
         assert "completion_tokens" in response_meta["usage"]
-        assert response_meta["usage"]["completion_tokens"] == 0
+        assert response_meta["usage"]["completion_tokens"] > 0
         assert response_meta["model"] == "microsoft/Phi-3.5-mini-instruct"
         assert response_meta["finish_reason"] is not None
 
@@ -848,7 +859,12 @@ async def mock_aiter(self):
 
         # check kwargs passed to chat_completion
         _, kwargs = mock_chat_completion_async.call_args
-        assert kwargs == {"stop": [], "stream": True, "max_tokens": 512}
+        assert kwargs == {
+            "stop": [],
+            "stream": True,
+            "max_tokens": 512,
+            "stream_options": ChatCompletionInputStreamOptions(include_usage=True),
+        }
 
         # Assert that the streaming callback was called twice
         assert streaming_call_count == 2
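
For reference, a hedged sketch of the chunk shapes such a mocked stream might end with; the field values are illustrative and the exact constructor fields can vary across huggingface_hub versions:

from huggingface_hub import (
    ChatCompletionStreamOutput,
    ChatCompletionStreamOutputChoice,
    ChatCompletionStreamOutputDelta,
    ChatCompletionStreamOutputUsage,
)

# An ordinary content chunk: one choice, no usage attached.
content_chunk = ChatCompletionStreamOutput(
    id="",
    model="microsoft/Phi-3.5-mini-instruct",
    system_fingerprint="",
    created=1710498504,
    choices=[
        ChatCompletionStreamOutputChoice(
            delta=ChatCompletionStreamOutputDelta(role="assistant", content="Hello"),
            index=0,
            finish_reason=None,
        )
    ],
)

# The extra chunk requested via include_usage: empty choices, aggregated counts.
usage_chunk = ChatCompletionStreamOutput(
    id="",
    model="microsoft/Phi-3.5-mini-instruct",
    system_fingerprint="",
    created=1710498505,
    choices=[],
    usage=ChatCompletionStreamOutputUsage(prompt_tokens=17, completion_tokens=8, total_tokens=25),
)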
