Skip to content

Commit 7b4d9ba

Browse files
authored
feat: introduce class method to create ChatMessage from the OpenAI dictionary format (#8670)
* add ChatMessage.from_openai_dict_format * remove print * release note * improve docstring * separate validation logic * rm obvious comment
1 parent 3ea128c commit 7b4d9ba

File tree

4 files changed

+163
-4
lines changed

4 files changed

+163
-4
lines changed

Diff for: haystack/components/generators/chat/openai.py

-4
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,6 @@ def _prepare_api_call( # noqa: PLR0913
308308
}
309309

310310
def _handle_stream_response(self, chat_completion: Stream, callback: StreamingCallbackT) -> List[ChatMessage]:
311-
print("callback")
312-
print(callback)
313-
print("-" * 100)
314-
315311
chunks: List[StreamingChunk] = []
316312
chunk = None
317313

Diff for: haystack/dataclasses/chat_message.py

+78
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,81 @@ def to_openai_dict_format(self) -> Dict[str, Any]:
426426
)
427427
openai_msg["tool_calls"] = openai_tool_calls
428428
return openai_msg
429+
430+
@staticmethod
431+
def _validate_openai_message(message: Dict[str, Any]) -> None:
432+
"""
433+
Validate that a message dictionary follows OpenAI's Chat API format.
434+
435+
:param message: The message dictionary to validate
436+
:raises ValueError: If the message format is invalid
437+
"""
438+
if "role" not in message:
439+
raise ValueError("The `role` field is required in the message dictionary.")
440+
441+
role = message["role"]
442+
content = message.get("content")
443+
tool_calls = message.get("tool_calls")
444+
445+
if role not in ["assistant", "user", "system", "developer", "tool"]:
446+
raise ValueError(f"Unsupported role: {role}")
447+
448+
if role == "assistant":
449+
if not content and not tool_calls:
450+
raise ValueError("For assistant messages, either `content` or `tool_calls` must be present.")
451+
if tool_calls:
452+
for tc in tool_calls:
453+
if "function" not in tc:
454+
raise ValueError("Tool calls must contain the `function` field")
455+
elif not content:
456+
raise ValueError(f"The `content` field is required for {role} messages.")
457+
458+
@classmethod
459+
def from_openai_dict_format(cls, message: Dict[str, Any]) -> "ChatMessage":
460+
"""
461+
Create a ChatMessage from a dictionary in the format expected by OpenAI's Chat API.
462+
463+
NOTE: While OpenAI's API requires `tool_call_id` in both tool calls and tool messages, this method
464+
accepts messages without it to support shallow OpenAI-compatible APIs.
465+
If you plan to use the resulting ChatMessage with OpenAI, you must include `tool_call_id` or you'll
466+
encounter validation errors.
467+
468+
:param message:
469+
The OpenAI dictionary to build the ChatMessage object.
470+
:returns:
471+
The created ChatMessage object.
472+
473+
:raises ValueError:
474+
If the message dictionary is missing required fields.
475+
"""
476+
cls._validate_openai_message(message)
477+
478+
role = message["role"]
479+
content = message.get("content")
480+
name = message.get("name")
481+
tool_calls = message.get("tool_calls")
482+
tool_call_id = message.get("tool_call_id")
483+
484+
if role == "assistant":
485+
haystack_tool_calls = None
486+
if tool_calls:
487+
haystack_tool_calls = []
488+
for tc in tool_calls:
489+
haystack_tc = ToolCall(
490+
id=tc.get("id"),
491+
tool_name=tc["function"]["name"],
492+
arguments=json.loads(tc["function"]["arguments"]),
493+
)
494+
haystack_tool_calls.append(haystack_tc)
495+
return cls.from_assistant(text=content, name=name, tool_calls=haystack_tool_calls)
496+
497+
assert content is not None # ensured by _validate_openai_message, but we need to make mypy happy
498+
499+
if role == "user":
500+
return cls.from_user(text=content, name=name)
501+
if role in ["system", "developer"]:
502+
return cls.from_system(text=content, name=name)
503+
504+
return cls.from_tool(
505+
tool_result=content, origin=ToolCall(id=tool_call_id, tool_name="", arguments={}), error=False
506+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
enhancements:
3+
- |
4+
Add the `from_openai_dict_format` class method to the `ChatMessage` class. It allows you to create a `ChatMessage`
5+
from a dictionary in the format expected by OpenAI's Chat API.

Diff for: test/dataclasses/test_chat_message.py

+80
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,86 @@ def test_to_openai_dict_format_invalid():
288288
message.to_openai_dict_format()
289289

290290

291+
def test_from_openai_dict_format_user_message():
292+
openai_msg = {"role": "user", "content": "Hello, how are you?", "name": "John"}
293+
message = ChatMessage.from_openai_dict_format(openai_msg)
294+
assert message.role.value == "user"
295+
assert message.text == "Hello, how are you?"
296+
assert message.name == "John"
297+
298+
299+
def test_from_openai_dict_format_system_message():
300+
openai_msg = {"role": "system", "content": "You are a helpful assistant"}
301+
message = ChatMessage.from_openai_dict_format(openai_msg)
302+
assert message.role.value == "system"
303+
assert message.text == "You are a helpful assistant"
304+
305+
306+
def test_from_openai_dict_format_assistant_message_with_content():
307+
openai_msg = {"role": "assistant", "content": "I can help with that"}
308+
message = ChatMessage.from_openai_dict_format(openai_msg)
309+
assert message.role.value == "assistant"
310+
assert message.text == "I can help with that"
311+
312+
313+
def test_from_openai_dict_format_assistant_message_with_tool_calls():
314+
openai_msg = {
315+
"role": "assistant",
316+
"content": None,
317+
"tool_calls": [{"id": "call_123", "function": {"name": "get_weather", "arguments": '{"location": "Berlin"}'}}],
318+
}
319+
message = ChatMessage.from_openai_dict_format(openai_msg)
320+
assert message.role.value == "assistant"
321+
assert message.text is None
322+
assert len(message.tool_calls) == 1
323+
tool_call = message.tool_calls[0]
324+
assert tool_call.id == "call_123"
325+
assert tool_call.tool_name == "get_weather"
326+
assert tool_call.arguments == {"location": "Berlin"}
327+
328+
329+
def test_from_openai_dict_format_tool_message():
330+
openai_msg = {"role": "tool", "content": "The weather is sunny", "tool_call_id": "call_123"}
331+
message = ChatMessage.from_openai_dict_format(openai_msg)
332+
assert message.role.value == "tool"
333+
assert message.tool_call_result.result == "The weather is sunny"
334+
assert message.tool_call_result.origin.id == "call_123"
335+
336+
337+
def test_from_openai_dict_format_tool_without_id():
338+
openai_msg = {"role": "tool", "content": "The weather is sunny"}
339+
message = ChatMessage.from_openai_dict_format(openai_msg)
340+
assert message.role.value == "tool"
341+
assert message.tool_call_result.result == "The weather is sunny"
342+
assert message.tool_call_result.origin.id is None
343+
344+
345+
def test_from_openai_dict_format_missing_role():
346+
with pytest.raises(ValueError):
347+
ChatMessage.from_openai_dict_format({"content": "test"})
348+
349+
350+
def test_from_openai_dict_format_missing_content():
351+
with pytest.raises(ValueError):
352+
ChatMessage.from_openai_dict_format({"role": "user"})
353+
354+
355+
def test_from_openai_dict_format_invalid_tool_calls():
356+
openai_msg = {"role": "assistant", "tool_calls": [{"invalid": "format"}]}
357+
with pytest.raises(ValueError):
358+
ChatMessage.from_openai_dict_format(openai_msg)
359+
360+
361+
def test_from_openai_dict_format_unsupported_role():
362+
with pytest.raises(ValueError):
363+
ChatMessage.from_openai_dict_format({"role": "invalid", "content": "test"})
364+
365+
366+
def test_from_openai_dict_format_assistant_missing_content_and_tool_calls():
367+
with pytest.raises(ValueError):
368+
ChatMessage.from_openai_dict_format({"role": "assistant", "irrelevant": "irrelevant"})
369+
370+
291371
@pytest.mark.integration
292372
def test_apply_chat_templating_on_chat_message():
293373
messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]

0 commit comments

Comments
 (0)