Skip to content

Commit 9f2c067 — "Small fix and update tests (#9370)"
Parent commit: da5fc0f

File tree

2 files changed: +26 additions, -13 deletions

haystack/components/generators/chat/hugging_face_api.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ def _run_streaming(
401401

402402
generated_text = ""
403403
first_chunk_time = None
404+
meta: Dict[str, Any] = {}
404405

405406
for chunk in api_output:
406407
# n is unused, so the API always returns only one choice
@@ -412,8 +413,6 @@ def _run_streaming(
412413
generated_text += text
413414

414415
finish_reason = choice.finish_reason
415-
416-
meta: Dict[str, Any] = {}
417416
if finish_reason:
418417
meta["finish_reason"] = finish_reason
419418

@@ -426,15 +425,13 @@ def _run_streaming(
426425
meta.update(
427426
{
428427
"model": self._client.model,
429-
"finish_reason": finish_reason,
430428
"index": 0,
431429
"usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming
432430
"completion_start_time": first_chunk_time,
433431
}
434432
)
435433

436434
message = ChatMessage.from_assistant(text=generated_text, meta=meta)
437-
438435
return {"replies": [message]}
439436

440437
def _run_non_streaming(
@@ -485,6 +482,7 @@ async def _run_streaming_async(
485482

486483
generated_text = ""
487484
first_chunk_time = None
485+
meta: Dict[str, Any] = {}
488486

489487
async for chunk in api_output:
490488
choice = chunk.choices[0]
@@ -493,8 +491,6 @@ async def _run_streaming_async(
493491
generated_text += text
494492

495493
finish_reason = choice.finish_reason
496-
497-
meta: Dict[str, Any] = {}
498494
if finish_reason:
499495
meta["finish_reason"] = finish_reason
500496

@@ -507,7 +503,6 @@ async def _run_streaming_async(
507503
meta.update(
508504
{
509505
"model": self._async_client.model,
510-
"finish_reason": finish_reason,
511506
"index": 0,
512507
"usage": {"prompt_tokens": 0, "completion_tokens": 0},
513508
"completion_start_time": first_chunk_time,

test/components/generators/chat/test_hugging_face_api.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -671,9 +671,15 @@ def test_live_run_serverless(self):
671671
assert isinstance(response["replies"], list)
672672
assert len(response["replies"]) > 0
673673
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
674-
assert "usage" in response["replies"][0].meta
675-
assert "prompt_tokens" in response["replies"][0].meta["usage"]
676-
assert "completion_tokens" in response["replies"][0].meta["usage"]
674+
assert response["replies"][0].text is not None
675+
meta = response["replies"][0].meta
676+
assert "usage" in meta
677+
assert "prompt_tokens" in meta["usage"]
678+
assert meta["usage"]["prompt_tokens"] > 0
679+
assert "completion_tokens" in meta["usage"]
680+
assert meta["usage"]["completion_tokens"] > 0
681+
assert meta["model"] == "microsoft/Phi-3.5-mini-instruct"
682+
assert meta["finish_reason"] is not None
677683

678684
@pytest.mark.integration
679685
@pytest.mark.slow
@@ -701,13 +707,18 @@ def test_live_run_serverless_streaming(self):
701707
assert isinstance(response["replies"], list)
702708
assert len(response["replies"]) > 0
703709
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
710+
assert response["replies"][0].text is not None
704711

705712
response_meta = response["replies"][0].meta
706713
assert "completion_start_time" in response_meta
707714
assert datetime.fromisoformat(response_meta["completion_start_time"]) <= datetime.now()
708715
assert "usage" in response_meta
709716
assert "prompt_tokens" in response_meta["usage"]
717+
assert response_meta["usage"]["prompt_tokens"] == 0
710718
assert "completion_tokens" in response_meta["usage"]
719+
assert response_meta["usage"]["completion_tokens"] == 0
720+
assert response_meta["model"] == "microsoft/Phi-3.5-mini-instruct"
721+
assert response_meta["finish_reason"] is not None
711722

712723
@pytest.mark.integration
713724
@pytest.mark.slow
@@ -926,9 +937,16 @@ async def test_live_run_async_serverless(self):
926937
assert isinstance(response["replies"], list)
927938
assert len(response["replies"]) > 0
928939
assert [isinstance(reply, ChatMessage) for reply in response["replies"]]
929-
assert "usage" in response["replies"][0].meta
930-
assert "prompt_tokens" in response["replies"][0].meta["usage"]
931-
assert "completion_tokens" in response["replies"][0].meta["usage"]
940+
assert response["replies"][0].text is not None
941+
942+
meta = response["replies"][0].meta
943+
assert "usage" in meta
944+
assert "prompt_tokens" in meta["usage"]
945+
assert meta["usage"]["prompt_tokens"] > 0
946+
assert "completion_tokens" in meta["usage"]
947+
assert meta["usage"]["completion_tokens"] > 0
948+
assert meta["model"] == "microsoft/Phi-3.5-mini-instruct"
949+
assert meta["finish_reason"] is not None
932950
finally:
933951
await generator._async_client.close()
934952

Comments (0)