
Commit 3ef8c08

fix: OpenAIChatGenerator and OpenAIGenerator crashing when streaming with usage tracking (#8558)
* Fix OpenAIGenerator crashing when tracking usage with streaming enabled
* Fix OpenAIChatGenerator crashing when tracking usage with streaming enabled
* Add release notes
* Fix linting
1 parent 3a30ee3 commit 3ef8c08
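
The user-facing trigger is the combination of a streaming callback with usage reporting. A minimal sketch of the configuration that used to crash and now works (model and prompt are illustrative, mirroring the integration tests added in this commit):

```python
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage

# Streaming plus usage tracking: before this fix, the final usage-only
# chunk sent by OpenAI crashed the generator while merging chunks.
generator = OpenAIChatGenerator(
    streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
    generation_kwargs={"stream_options": {"include_usage": True}},
)
result = generator.run([ChatMessage.from_user("What's the capital of France?")])

# After the fix, the reply's usage metadata is populated instead of {}.
print(result["replies"][0].meta["usage"])
```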

File tree

5 files changed: +113 -29 lines changed


haystack/components/generators/chat/openai.py

Lines changed: 21 additions & 17 deletions
````diff
@@ -222,19 +222,19 @@ def run(
             if num_responses > 1:
                 raise ValueError("Cannot stream multiple responses, please set n=1.")
             chunks: List[StreamingChunk] = []
-            chunk = None
+            completion_chunk = None
             _first_token = True

             # pylint: disable=not-an-iterable
-            for chunk in chat_completion:
-                if chunk.choices and streaming_callback:
-                    chunk_delta: StreamingChunk = self._build_chunk(chunk)
+            for completion_chunk in chat_completion:
+                if completion_chunk.choices and streaming_callback:
+                    chunk_delta: StreamingChunk = self._build_chunk(completion_chunk)
                     if _first_token:
                         _first_token = False
                         chunk_delta.meta["completion_start_time"] = datetime.now().isoformat()
                     chunks.append(chunk_delta)
                     streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
-            completions = [self._connect_chunks(chunk, chunks)]
+            completions = [self._create_message_from_chunks(completion_chunk, chunks)]
         # if streaming is disabled, the completion is a ChatCompletion
         elif isinstance(chat_completion, ChatCompletion):
             completions = [self._build_message(chat_completion, choice) for choice in chat_completion.choices]
@@ -245,18 +245,20 @@ def run(

         return {"replies": completions}

-    def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
+    def _create_message_from_chunks(
+        self, completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
+    ) -> ChatMessage:
         """
-        Connects the streaming chunks into a single ChatMessage.
+        Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.

-        :param chunk: The last chunk returned by the OpenAI API.
-        :param chunks: The list of all chunks returned by the OpenAI API.
+        :param completion_chunk: The last completion chunk returned by the OpenAI API.
+        :param streamed_chunks: The list of all chunks returned by the OpenAI API.
         """
-        is_tools_call = bool(chunks[0].meta.get("tool_calls"))
-        is_function_call = bool(chunks[0].meta.get("function_call"))
+        is_tools_call = bool(streamed_chunks[0].meta.get("tool_calls"))
+        is_function_call = bool(streamed_chunks[0].meta.get("function_call"))
         # if it's a tool call or function call, we need to build the payload dict from all the chunks
         if is_tools_call or is_function_call:
-            tools_len = 1 if is_function_call else len(chunks[0].meta.get("tool_calls", []))
+            tools_len = 1 if is_function_call else len(streamed_chunks[0].meta.get("tool_calls", []))
             # don't change this approach of building payload dicts, otherwise mypy will complain
             p_def: Dict[str, Any] = {
                 "index": 0,
@@ -265,7 +267,7 @@ def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
                 "type": "function",
             }
             payloads = [copy.deepcopy(p_def) for _ in range(tools_len)]
-            for chunk_payload in chunks:
+            for chunk_payload in streamed_chunks:
                 if is_tools_call:
                     deltas = chunk_payload.meta.get("tool_calls") or []
                 else:
@@ -287,16 +289,18 @@ def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         else:
             total_content = ""
             total_meta = {}
-            for streaming_chunk in chunks:
+            for streaming_chunk in streamed_chunks:
                 total_content += streaming_chunk.content
                 total_meta.update(streaming_chunk.meta)
             complete_response = ChatMessage.from_assistant(total_content, meta=total_meta)
+        finish_reason = streamed_chunks[-1].meta["finish_reason"]
         complete_response.meta.update(
             {
-                "model": chunk.model,
+                "model": completion_chunk.model,
                 "index": 0,
-                "finish_reason": chunk.choices[0].finish_reason,
-                "usage": {},  # we don't have usage data for streaming responses
+                "finish_reason": finish_reason,
+                # Usage is available when streaming only if the user explicitly requests it
+                "usage": dict(completion_chunk.usage or {}),
             }
         )
         return complete_response
````
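
For context on the `usage` change above: when `stream_options.include_usage` is enabled, the OpenAI streaming API terminates the stream with one extra chunk whose `choices` list is empty and whose `usage` field is populated. The old code read `chunk.choices[0].finish_reason` from whatever chunk the loop ended on, which blew up on that final empty list; the fix takes `finish_reason` from the last streamed content chunk and `usage` from the final completion chunk. A rough sketch of the two chunk shapes involved (field values illustrative):

```python
# Sketch of the chunk sequence seen with {"stream_options": {"include_usage": True}}.
# Shapes follow the OpenAI chat-completions streaming API; values are made up.
content_chunk = {
    "model": "gpt-4o-mini",
    "choices": [{"index": 0, "delta": {"content": "Paris"}, "finish_reason": "stop"}],
    "usage": None,
}
final_usage_chunk = {
    "model": "gpt-4o-mini",
    "choices": [],  # empty: indexing choices[0] here caused the crash
    "usage": {"prompt_tokens": 13, "completion_tokens": 5, "total_tokens": 18},
}
```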

haystack/components/generators/openai.py

Lines changed: 18 additions & 12 deletions
````diff
@@ -49,7 +49,7 @@ class OpenAIGenerator:
     ```
     """

-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
         model: str = "gpt-4o-mini",
@@ -222,15 +222,17 @@ def run(
             if num_responses > 1:
                 raise ValueError("Cannot stream multiple responses, please set n=1.")
             chunks: List[StreamingChunk] = []
-            chunk = None
+            completion_chunk: Optional[ChatCompletionChunk] = None

             # pylint: disable=not-an-iterable
-            for chunk in completion:
-                if chunk.choices and streaming_callback:
-                    chunk_delta: StreamingChunk = self._build_chunk(chunk)
+            for completion_chunk in completion:
+                if completion_chunk.choices and streaming_callback:
+                    chunk_delta: StreamingChunk = self._build_chunk(completion_chunk)
                     chunks.append(chunk_delta)
                     streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
-            completions = [self._connect_chunks(chunk, chunks)]
+            # Makes type checkers happy
+            assert completion_chunk is not None
+            completions = [self._create_message_from_chunks(completion_chunk, chunks)]
         elif isinstance(completion, ChatCompletion):
             completions = [self._build_message(completion, choice) for choice in completion.choices]

@@ -244,17 +246,21 @@ def run(
         }

     @staticmethod
-    def _connect_chunks(chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
+    def _create_message_from_chunks(
+        completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
+    ) -> ChatMessage:
         """
-        Connects the streaming chunks into a single ChatMessage.
+        Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.
         """
-        complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks]))
+        complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in streamed_chunks]))
+        finish_reason = streamed_chunks[-1].meta["finish_reason"]
         complete_response.meta.update(
             {
-                "model": chunk.model,
+                "model": completion_chunk.model,
                 "index": 0,
-                "finish_reason": chunk.choices[0].finish_reason,
-                "usage": {},  # we don't have usage data for streaming responses
+                "finish_reason": finish_reason,
+                # Usage is available when streaming only if the user explicitly requests it
+                "usage": dict(completion_chunk.usage or {}),
            }
         )
         return complete_response
````
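
A side note on the `Optional[ChatCompletionChunk]` annotation and the `assert completion_chunk is not None` added here: after a `for` loop, mypy still treats the loop variable as possibly `None`, since the iterable could in principle be empty, so the assert narrows the type before the method call. A minimal standalone sketch of the pattern, not tied to this codebase:

```python
from typing import Iterable, Optional

def last_item(items: Iterable[int]) -> int:
    last: Optional[int] = None
    for last in items:
        pass
    # Without this assert, mypy flags the return as Optional[int];
    # the assert narrows `last` to int (and fails loudly on empty input).
    assert last is not None
    return last
```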
New release note

Lines changed: 4 additions & 0 deletions

````diff
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix `OpenAIChatGenerator` and `OpenAIGenerator` crashing when using a `streaming_callback` and `generation_kwargs` contain `{"stream_options": {"include_usage": True}}`.
````

test/components/generators/chat/test_openai.py

Lines changed: 31 additions & 0 deletions
````diff
@@ -329,3 +329,34 @@ def __call__(self, chunk: StreamingChunk) -> None:

         assert callback.counter > 1
         assert "Paris" in callback.responses
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_streaming_with_include_usage(self):
+        class Callback:
+            def __init__(self):
+                self.responses = ""
+                self.counter = 0
+
+            def __call__(self, chunk: StreamingChunk) -> None:
+                self.counter += 1
+                self.responses += chunk.content if chunk.content else ""
+
+        callback = Callback()
+        component = OpenAIChatGenerator(
+            streaming_callback=callback, generation_kwargs={"stream_options": {"include_usage": True}}
+        )
+        results = component.run([ChatMessage.from_user("What's the capital of France?")])
+
+        assert len(results["replies"]) == 1
+        message: ChatMessage = results["replies"][0]
+        assert "Paris" in message.content
+
+        assert "gpt-4o-mini" in message.meta["model"]
+        assert message.meta["finish_reason"] == "stop"
+
+        assert callback.counter > 1
+        assert "Paris" in callback.responses
````

test/components/generators/test_openai.py

Lines changed: 39 additions & 0 deletions
````diff
@@ -333,3 +333,42 @@ def test_run_with_system_prompt(self):
             system_prompt="You answer in German, regardless of the language on which a question is asked.",
         )
         assert "pythagoras" in result["replies"][0].lower()
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_streaming_with_include_usage(self):
+        class Callback:
+            def __init__(self):
+                self.responses = ""
+                self.counter = 0
+
+            def __call__(self, chunk: StreamingChunk) -> None:
+                self.counter += 1
+                self.responses += chunk.content if chunk.content else ""
+
+        callback = Callback()
+        component = OpenAIGenerator(
+            streaming_callback=callback, generation_kwargs={"stream_options": {"include_usage": True}}
+        )
+        results = component.run("What's the capital of France?")
+
+        assert len(results["replies"]) == 1
+        assert len(results["meta"]) == 1
+        response: str = results["replies"][0]
+        assert "Paris" in response
+
+        metadata = results["meta"][0]
+
+        assert "gpt-4o-mini" in metadata["model"]
+        assert metadata["finish_reason"] == "stop"
+
+        assert "usage" in metadata
+        assert "prompt_tokens" in metadata["usage"] and metadata["usage"]["prompt_tokens"] > 0
+        assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0
+        assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0
+
+        assert callback.counter > 1
+        assert "Paris" in callback.responses
````