Skip to content

feat(drivers-prompt-openai): add audio input/output support to OpenAiChatPromptDriver #1617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for enums in `GriptapeCloudToolTool`.
- `LocalRerankDriver` for reranking locally.
- `griptape.utils.griptape_cloud.GriptapeCloudStructure` for automatically configuring Cloud-specific Drivers when in the Griptape Cloud Structures Runtime.
- Support for `AudioArtifact` inputs/outputs in `OpenAiChatPromptDriver`.

### Changed

Expand Down
4 changes: 4 additions & 0 deletions griptape/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from .prompt_stack.contents.base_message_content import BaseMessageContent
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
from .prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent
from .prompt_stack.contents.audio_delta_message_content import AudioDeltaMessageContent
from .prompt_stack.contents.text_message_content import TextMessageContent
from .prompt_stack.contents.image_message_content import ImageMessageContent
from .prompt_stack.contents.audio_message_content import AudioMessageContent
from .prompt_stack.contents.action_call_delta_message_content import ActionCallDeltaMessageContent
from .prompt_stack.contents.action_call_message_content import ActionCallMessageContent
from .prompt_stack.contents.action_result_message_content import ActionResultMessageContent
Expand All @@ -30,8 +32,10 @@
"DeltaMessage",
"Message",
"TextDeltaMessageContent",
"AudioDeltaMessageContent",
"TextMessageContent",
"ImageMessageContent",
"AudioMessageContent",
"GenericMessageContent",
"ActionCallDeltaMessageContent",
"ActionCallMessageContent",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import Optional

from attrs import define, field

from griptape.common import BaseDeltaMessageContent


@define(kw_only=True)
class AudioDeltaMessageContent(BaseDeltaMessageContent):
    """A delta message content for a streamed chunk of audio data.

    All fields are optional because a given stream chunk may carry only some of
    them (e.g. audio bytes without metadata, or metadata without bytes).

    Attributes:
        id: The provider-assigned ID of the audio data.
        data: Base64 encoded audio data for this chunk.
        transcript: The transcript text corresponding to this chunk of audio.
        expires_at: The Unix timestamp (in seconds) for when this audio data will no longer be accessible.
    """

    id: Optional[str] = field(default=None, metadata={"serializable": True})
    data: Optional[str] = field(default=None, metadata={"serializable": True})
    transcript: Optional[str] = field(default=None, metadata={"serializable": True})
    expires_at: Optional[int] = field(default=None, metadata={"serializable": True})
43 changes: 43 additions & 0 deletions griptape/common/prompt_stack/contents/audio_message_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import base64
from typing import TYPE_CHECKING

from attrs import define, field

from griptape.artifacts import AudioArtifact
from griptape.common import (
AudioDeltaMessageContent,
BaseDeltaMessageContent,
BaseMessageContent,
)

if TYPE_CHECKING:
from collections.abc import Sequence


@define
class AudioMessageContent(BaseMessageContent):
    """Message content wrapping an `AudioArtifact`."""

    artifact: AudioArtifact = field(metadata={"serializable": True})

    @classmethod
    def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> AudioMessageContent:
        """Assemble a complete audio message content from streamed deltas.

        Args:
            deltas: Streamed delta contents; anything that is not an
                `AudioDeltaMessageContent` is ignored.

        Returns:
            An `AudioMessageContent` whose artifact holds the decoded audio
            bytes and the provider metadata (audio id, expiration, transcript).

        Raises:
            ValueError: If no delta supplies an audio id or an expiration
                timestamp.
        """
        audio_deltas = [delta for delta in deltas if isinstance(delta, AudioDeltaMessageContent)]
        audio_data = [delta.data for delta in audio_deltas if delta.data is not None]
        transcript_data = [delta.transcript for delta in audio_deltas if delta.transcript is not None]

        # Metadata typically arrives on only one chunk of the stream. Use a
        # defaulted next() so a missing value raises a clear ValueError instead
        # of leaking a bare StopIteration out of this method.
        expires_at = next((delta.expires_at for delta in audio_deltas if delta.expires_at is not None), None)
        audio_id = next((delta.id for delta in audio_deltas if delta.id is not None), None)
        if audio_id is None or expires_at is None:
            raise ValueError("Audio deltas must include an audio id and an expiration timestamp.")

        artifact = AudioArtifact(
            value=b"".join(base64.b64decode(data) for data in audio_data),
            format="wav",
            meta={
                "audio_id": audio_id,
                "expires_at": expires_at,
                "transcript": "".join(transcript_data),
            },
        )

        return cls(artifact=artifact)
4 changes: 4 additions & 0 deletions griptape/common/prompt_stack/prompt_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from griptape.artifacts import (
ActionArtifact,
AudioArtifact,
BaseArtifact,
GenericArtifact,
ImageArtifact,
Expand All @@ -17,6 +18,7 @@
from griptape.common import (
ActionCallMessageContent,
ActionResultMessageContent,
AudioMessageContent,
BaseMessageContent,
GenericMessageContent,
ImageMessageContent,
Expand Down Expand Up @@ -91,6 +93,8 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage
return [TextMessageContent(artifact)]
elif isinstance(artifact, ImageArtifact):
return [ImageMessageContent(artifact)]
elif isinstance(artifact, AudioArtifact):
return [AudioMessageContent(artifact)]
elif isinstance(artifact, GenericArtifact):
return [GenericMessageContent(artifact)]
elif isinstance(artifact, ActionArtifact):
Expand Down
8 changes: 8 additions & 0 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from griptape.common import (
ActionCallDeltaMessageContent,
ActionCallMessageContent,
AudioDeltaMessageContent,
AudioMessageContent,
BaseDeltaMessageContent,
DeltaMessage,
Message,
Expand All @@ -19,6 +21,7 @@
)
from griptape.events import (
ActionChunkEvent,
AudioChunkEvent,
EventBus,
FinishPromptEvent,
StartPromptEvent,
Expand Down Expand Up @@ -177,6 +180,8 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message:
delta_contents[content.index] = [content]
if isinstance(content, TextDeltaMessageContent):
EventBus.publish_event(TextChunkEvent(token=content.text, index=content.index))
elif isinstance(content, AudioDeltaMessageContent) and content.data is not None:
EventBus.publish_event(AudioChunkEvent(data=content.data))
elif isinstance(content, ActionCallDeltaMessageContent):
EventBus.publish_event(
ActionChunkEvent(
Expand All @@ -197,10 +202,13 @@ def __build_message(
content = []
for delta_content in delta_contents:
text_deltas = [delta for delta in delta_content if isinstance(delta, TextDeltaMessageContent)]
audio_deltas = [delta for delta in delta_content if isinstance(delta, AudioDeltaMessageContent)]
action_deltas = [delta for delta in delta_content if isinstance(delta, ActionCallDeltaMessageContent)]

if text_deltas:
content.append(TextMessageContent.from_deltas(text_deltas))
if audio_deltas:
content.append(AudioMessageContent.from_deltas(audio_deltas))
if action_deltas:
content.append(ActionCallMessageContent.from_deltas(action_deltas))

Expand Down
116 changes: 90 additions & 26 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from __future__ import annotations

import base64
import json
import logging
import time
from typing import TYPE_CHECKING, Literal, Optional

import openai
from attrs import Factory, define, field

from griptape.artifacts import ActionArtifact, TextArtifact
from griptape.artifacts import ActionArtifact, AudioArtifact, TextArtifact
from griptape.common import (
ActionCallDeltaMessageContent,
ActionCallMessageContent,
ActionResultMessageContent,
AudioDeltaMessageContent,
AudioMessageContent,
BaseDeltaMessageContent,
BaseMessageContent,
DeltaMessage,
Expand Down Expand Up @@ -96,6 +100,10 @@
),
kw_only=True,
)
modalities: list[str] = field(default=Factory(lambda: ["text"]), kw_only=True, metadata={"serializable": True})
audio: dict = field(
default=Factory(lambda: {"voice": "alloy", "format": "pcm16"}), kw_only=True, metadata={"serializable": True}
)
_client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
Expand Down Expand Up @@ -150,15 +158,20 @@
choice = chunk.choices[0]
delta = choice.delta

yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(delta))
content = self.__to_prompt_stack_delta_message_content(delta)

if content is not None:
yield DeltaMessage(content=content)

def _base_params(self, prompt_stack: PromptStack) -> dict:
params = {
"model": self.model,
"user": self.user,
"seed": self.seed,
"modalities": self.modalities,
**({"reasoning_effort": self.reasoning_effort} if self.is_reasoning_model else {}),
**({"temperature": self.temperature} if not self.is_reasoning_model else {}),
**({"audio": self.audio} if "audio" in self.modalities else {}),
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
**({"stream_options": {"include_usage": True}} if self.stream else {}),
Expand Down Expand Up @@ -204,44 +217,48 @@

for message in messages:
# If the message only contains textual content we can send it as a single content.
if message.is_text():
if message.has_all_content_type(TextMessageContent):
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
# Action results must be sent as separate messages.
elif message.has_any_content_type(ActionResultMessageContent):
elif action_result_contents := message.get_content_type(ActionResultMessageContent):
openai_messages.extend(
{
"role": self.__to_openai_role(message, action_result),
"content": self.__to_openai_message_content(action_result),
"tool_call_id": action_result.action.tag,
"role": self.__to_openai_role(message, action_result_content),
"content": self.__to_openai_message_content(action_result_content),
"tool_call_id": action_result_content.action.tag,
}
for action_result in message.get_content_type(ActionResultMessageContent)
for action_result_content in action_result_contents
)

if message.has_any_content_type(TextMessageContent):
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
else:
openai_message = {
"role": self.__to_openai_role(message),
"content": [
self.__to_openai_message_content(content)
for content in [
content for content in message.content if not isinstance(content, ActionCallMessageContent)
]
],
"content": [],
}

for content in message.content:
if isinstance(content, ActionCallMessageContent):
if "tool_calls" not in openai_message:
openai_message["tool_calls"] = []
openai_message["tool_calls"].append(self.__to_openai_message_content(content))
elif (
isinstance(content, AudioMessageContent)
and message.is_assistant()
and time.time() < content.artifact.meta["expires_at"]
):
# For assistant audio messages, we reference the audio id instead of sending audio message content.
openai_message["audio"] = {
"id": content.artifact.meta["audio_id"],
}
else:
openai_message["content"].append(self.__to_openai_message_content(content))

# Some OpenAi-compatible services don't accept an empty array for content
if not openai_message["content"]:
openai_message["content"] = ""

# Action calls must be attached to the message, not sent as content.
action_call_content = [
content for content in message.content if isinstance(content, ActionCallMessageContent)
]
if action_call_content:
openai_message["tool_calls"] = [
self.__to_openai_message_content(action_call) for action_call in action_call_content
]

openai_messages.append(openai_message)

return openai_messages
Expand Down Expand Up @@ -282,6 +299,31 @@
"type": "image_url",
"image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"},
}
elif isinstance(content, AudioMessageContent):
artifact = content.artifact
metadata = artifact.meta

# If there's an expiration date, we can assume it's an assistant message.
if "expires_at" in metadata:
# If it's expired, we send the transcript instead.
if time.time() >= metadata["expires_at"]:
return {
"type": "text",
"text": artifact.meta.get("transcript"),
}
else:
# This should never occur, since a non-expired audio content
# should have already been referenced by the audio id.
raise ValueError("Assistant audio messages should be sent as audio ids.")

Check warning on line 317 in griptape/drivers/prompt/openai_chat_prompt_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/prompt/openai_chat_prompt_driver.py#L317

Added line #L317 was not covered by tests
else:
# If there's no expiration date, we can assume it's a user message where we send the audio every time.
return {
"type": "input_audio",
"input_audio": {
"data": base64.b64encode(artifact.value).decode("utf-8"),
"format": artifact.format,
},
}
elif isinstance(content, ActionCallMessageContent):
action = content.artifact.value

Expand All @@ -300,6 +342,20 @@

if response.content is not None:
content.append(TextMessageContent(TextArtifact(response.content)))
if response.audio is not None:
content.append(
AudioMessageContent(
AudioArtifact(
value=base64.b64decode(response.audio.data),
format="wav",
meta={
"audio_id": response.audio.id,
"transcript": response.audio.transcript,
"expires_at": response.audio.expires_at,
},
)
)
)
if response.tool_calls is not None:
content.extend(
[
Expand All @@ -319,7 +375,7 @@

return content

def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> BaseDeltaMessageContent:
def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> Optional[BaseDeltaMessageContent]:
if content_delta.content is not None:
return TextDeltaMessageContent(content_delta.content)
elif content_delta.tool_calls is not None:
Expand All @@ -342,5 +398,13 @@
raise ValueError(f"Unsupported tool call delta: {tool_call}")
else:
raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}")
else:
return TextDeltaMessageContent("")
# OpenAi doesn't have types for audio deltas so we need to use hasattr and getattr.
elif hasattr(content_delta, "audio") and getattr(content_delta, "audio") is not None:
audio_chunk: dict = getattr(content_delta, "audio")
return AudioDeltaMessageContent(
id=audio_chunk.get("id"),
data=audio_chunk.get("data"),
expires_at=audio_chunk.get("expires_at"),
transcript=audio_chunk.get("transcript"),
)
return None
2 changes: 2 additions & 0 deletions griptape/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .finish_structure_run_event import FinishStructureRunEvent
from .base_chunk_event import BaseChunkEvent
from .text_chunk_event import TextChunkEvent
from .audio_chunk_event import AudioChunkEvent
from .action_chunk_event import ActionChunkEvent
from .event_listener import EventListener
from .start_image_generation_event import StartImageGenerationEvent
Expand Down Expand Up @@ -41,6 +42,7 @@
"FinishStructureRunEvent",
"BaseChunkEvent",
"TextChunkEvent",
"AudioChunkEvent",
"ActionChunkEvent",
"EventListener",
"StartImageGenerationEvent",
Expand Down
Loading