 from huggingface_hub import (
     AsyncInferenceClient,
     ChatCompletionInputFunctionDefinition,
+    ChatCompletionInputStreamOptions,
     ChatCompletionInputTool,
     ChatCompletionOutput,
     ChatCompletionOutputToolCall,
@@ -396,37 +397,52 @@ def _run_streaming(
         self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
     ):
         api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion(
-            messages, stream=True, **generation_kwargs
+            messages,
+            stream=True,
+            stream_options=ChatCompletionInputStreamOptions(include_usage=True),
+            **generation_kwargs,
         )
 
         generated_text = ""
         first_chunk_time = None
+        finish_reason = None
+        usage = None
         meta: Dict[str, Any] = {}
 
         for chunk in api_output:
-            # n is unused, so the API always returns only one choice
-            # the argument is probably allowed for compatibility with OpenAI
-            # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
-            choice = chunk.choices[0]
+            # The chunk with usage returns an empty array for choices
+            if len(chunk.choices) > 0:
+                # n is unused, so the API always returns only one choice
+                # the argument is probably allowed for compatibility with OpenAI
+                # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+                choice = chunk.choices[0]
 
-            text = choice.delta.content or ""
-            generated_text += text
+                text = choice.delta.content or ""
+                generated_text += text
 
-            finish_reason = choice.finish_reason
-            if finish_reason:
-                meta["finish_reason"] = finish_reason
+                if choice.finish_reason:
+                    finish_reason = choice.finish_reason
+
+                stream_chunk = StreamingChunk(text, meta)
+                streaming_callback(stream_chunk)
+
+            if chunk.usage:
+                usage = chunk.usage
 
             if first_chunk_time is None:
                 first_chunk_time = datetime.now().isoformat()
 
-            stream_chunk = StreamingChunk(text, meta)
-            streaming_callback(stream_chunk)
+        if usage:
+            usage_dict = {"prompt_tokens": usage.prompt_tokens, "completion_tokens": usage.completion_tokens}
+        else:
+            usage_dict = {"prompt_tokens": 0, "completion_tokens": 0}
 
         meta.update(
             {
                 "model": self._client.model,
                 "index": 0,
-                "usage": {"prompt_tokens": 0, "completion_tokens": 0},  # not available in streaming
+                "finish_reason": finish_reason,
+                "usage": usage_dict,
                 "completion_start_time": first_chunk_time,
             }
         )
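For reference, the streaming-usage behaviour this hunk relies on can be exercised directly against huggingface_hub. The sketch below is illustrative only: the model name and prompt are placeholders, and it assumes a recent huggingface_hub release and an endpoint that honours stream_options. With include_usage=True, the server sends one extra final chunk whose choices list is empty and whose usage field is populated, which is exactly what the len(chunk.choices) > 0 guard above accounts for.

from huggingface_hub import ChatCompletionInputStreamOptions, InferenceClient

# Placeholder model; any chat-capable endpoint that supports stream_options should behave the same way.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

chunks = client.chat_completion(
    [{"role": "user", "content": "Say hello"}],
    stream=True,
    stream_options=ChatCompletionInputStreamOptions(include_usage=True),
    max_tokens=32,
)

text, usage = "", None
for chunk in chunks:
    if chunk.choices:  # regular delta chunks carry the generated text
        text += chunk.choices[0].delta.content or ""
    if chunk.usage:  # the final chunk has an empty choices list and the token counts
        usage = chunk.usage

print(text)
if usage:
    print(usage.prompt_tokens, usage.completion_tokens)
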
@@ -477,34 +493,52 @@ async def _run_streaming_async(
         self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any], streaming_callback: StreamingCallbackT
     ):
         api_output: AsyncIterable[ChatCompletionStreamOutput] = await self._async_client.chat_completion(
-            messages, stream=True, **generation_kwargs
+            messages,
+            stream=True,
+            stream_options=ChatCompletionInputStreamOptions(include_usage=True),
+            **generation_kwargs,
        )
 
         generated_text = ""
         first_chunk_time = None
+        finish_reason = None
+        usage = None
         meta: Dict[str, Any] = {}
 
         async for chunk in api_output:
-            choice = chunk.choices[0]
+            # The chunk with usage returns an empty array for choices
+            if len(chunk.choices) > 0:
+                # n is unused, so the API always returns only one choice
+                # the argument is probably allowed for compatibility with OpenAI
+                # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n
+                choice = chunk.choices[0]
 
-            text = choice.delta.content or ""
-            generated_text += text
+                text = choice.delta.content or ""
+                generated_text += text
 
-            finish_reason = choice.finish_reason
-            if finish_reason:
-                meta["finish_reason"] = finish_reason
+                if choice.finish_reason:
+                    finish_reason = choice.finish_reason
+
+                stream_chunk = StreamingChunk(text, meta)
+                await streaming_callback(stream_chunk)  # type: ignore
+
+            if chunk.usage:
+                usage = chunk.usage
 
             if first_chunk_time is None:
                 first_chunk_time = datetime.now().isoformat()
 
-            stream_chunk = StreamingChunk(text, meta)
-            await streaming_callback(stream_chunk)  # type: ignore
+        if usage:
+            usage_dict = {"prompt_tokens": usage.prompt_tokens, "completion_tokens": usage.completion_tokens}
+        else:
+            usage_dict = {"prompt_tokens": 0, "completion_tokens": 0}
 
         meta.update(
             {
                 "model": self._async_client.model,
                 "index": 0,
-                "usage": {"prompt_tokens": 0, "completion_tokens": 0},
+                "finish_reason": finish_reason,
+                "usage": usage_dict,
                 "completion_start_time": first_chunk_time,
             }
         )
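The async hunk follows the same pattern with AsyncInferenceClient. A minimal standalone sketch, again with a placeholder model and prompt and assuming the endpoint reports usage in streaming mode:

import asyncio

from huggingface_hub import AsyncInferenceClient, ChatCompletionInputStreamOptions


async def main() -> None:
    # Placeholder model; swap in whichever endpoint you are targeting.
    client = AsyncInferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

    stream = await client.chat_completion(
        [{"role": "user", "content": "Say hello"}],
        stream=True,
        stream_options=ChatCompletionInputStreamOptions(include_usage=True),
        max_tokens=32,
    )

    usage = None
    async for chunk in stream:
        if chunk.choices:  # text deltas
            print(chunk.choices[0].delta.content or "", end="")
        if chunk.usage:  # final usage-only chunk
            usage = chunk.usage

    if usage:
        print(f"\nprompt={usage.prompt_tokens} completion={usage.completion_tokens}")


asyncio.run(main())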