Skip to content

Commit efa0532

Browse files
committed
Add audio input/output support to OpenAiChatPromptDriver
1 parent c8a17b8 commit efa0532

13 files changed

+217
-39
lines changed

griptape/common/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
from .prompt_stack.contents.base_message_content import BaseMessageContent
55
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
66
from .prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent
7+
from .prompt_stack.contents.audio_delta_message_content import AudioDeltaMessageContent
8+
from .prompt_stack.contents.audio_transcript_delta_message_content import AudioTranscriptDeltaMessageContent
79
from .prompt_stack.contents.text_message_content import TextMessageContent
810
from .prompt_stack.contents.image_message_content import ImageMessageContent
11+
from .prompt_stack.contents.audio_message_content import AudioMessageContent
912
from .prompt_stack.contents.action_call_delta_message_content import ActionCallDeltaMessageContent
1013
from .prompt_stack.contents.action_call_message_content import ActionCallMessageContent
1114
from .prompt_stack.contents.action_result_message_content import ActionResultMessageContent
@@ -30,8 +33,11 @@
3033
"DeltaMessage",
3134
"Message",
3235
"TextDeltaMessageContent",
36+
"AudioDeltaMessageContent",
37+
"AudioTranscriptDeltaMessageContent",
3338
"TextMessageContent",
3439
"ImageMessageContent",
40+
"AudioMessageContent",
3541
"GenericMessageContent",
3642
"ActionCallDeltaMessageContent",
3743
"ActionCallMessageContent",
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from __future__ import annotations
2+
3+
from typing import Optional
4+
5+
from attrs import define, field
6+
7+
from griptape.common import BaseDeltaMessageContent
8+
9+
10+
@define
11+
class AudioDeltaMessageContent(BaseDeltaMessageContent):
12+
"""A delta message content for audio data.
13+
14+
Attributes:
15+
id: The ID of the audio data.
16+
data: Base64 encoded audio data.
17+
transcript: The transcript of the audio data.
18+
expires_at: The Unix timestamp (in seconds) for when this audio data will no longer be accessible.
19+
"""
20+
21+
id: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
22+
data: Optional[str] = field(kw_only=True, metadata={"serializable": True})
23+
transcript: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
24+
expires_at: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
from typing import TYPE_CHECKING
5+
6+
from attrs import define, field
7+
8+
from griptape.artifacts import AudioArtifact
9+
from griptape.common import (
10+
AudioDeltaMessageContent,
11+
BaseDeltaMessageContent,
12+
BaseMessageContent,
13+
)
14+
15+
if TYPE_CHECKING:
16+
from collections.abc import Sequence
17+
18+
19+
@define
20+
class AudioMessageContent(BaseMessageContent):
21+
artifact: AudioArtifact = field(metadata={"serializable": True})
22+
23+
@classmethod
24+
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> AudioMessageContent:
25+
audio_deltas = [delta for delta in deltas if isinstance(delta, AudioDeltaMessageContent)]
26+
audio_data = [delta.data for delta in audio_deltas if delta.data is not None]
27+
transcript_data = [delta.transcript for delta in audio_deltas if delta.transcript is not None]
28+
expires_at = next(delta.expires_at for delta in audio_deltas if delta.expires_at is not None)
29+
audio_id = next(delta.id for delta in audio_deltas if delta.id is not None)
30+
31+
audio_transcript = "".join(data for data in transcript_data)
32+
33+
artifact = AudioArtifact(
34+
value=b"".join(base64.b64decode(data) for data in audio_data),
35+
format="wav",
36+
meta={
37+
"audio_id": audio_id,
38+
"expires_at": expires_at,
39+
"transcript": audio_transcript,
40+
},
41+
)
42+
43+
return cls(artifact=artifact)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from __future__ import annotations
2+
3+
from attrs import define, field
4+
5+
from griptape.common import BaseDeltaMessageContent
6+
7+
8+
@define
9+
class AudioTranscriptDeltaMessageContent(BaseDeltaMessageContent):
10+
text: str = field(metadata={"serializable": True})

griptape/common/prompt_stack/messages/message.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ def has_any_content_type(self, content_type: type[T]) -> bool:
3737
def get_content_type(self, content_type: type[T]) -> list[T]:
3838
return [content for content in self.content if isinstance(content, content_type)]
3939

40+
def exclude_content_type(self, content_type: type[T] | tuple[type[T]]) -> list[BaseMessageContent]:
41+
return [content for content in self.content if not isinstance(content, content_type)]
42+
4043
def is_text(self) -> bool:
4144
return all(isinstance(content, TextMessageContent) for content in self.content)
4245

griptape/common/prompt_stack/prompt_stack.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from griptape.artifacts import (
88
ActionArtifact,
9+
AudioArtifact,
910
BaseArtifact,
1011
GenericArtifact,
1112
ImageArtifact,
@@ -15,6 +16,7 @@
1516
from griptape.common import (
1617
ActionCallMessageContent,
1718
ActionResultMessageContent,
19+
AudioMessageContent,
1820
BaseMessageContent,
1921
GenericMessageContent,
2022
ImageMessageContent,
@@ -77,6 +79,8 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage
7779
return [TextMessageContent(artifact)]
7880
elif isinstance(artifact, ImageArtifact):
7981
return [ImageMessageContent(artifact)]
82+
elif isinstance(artifact, AudioArtifact):
83+
return [AudioMessageContent(artifact)]
8084
elif isinstance(artifact, GenericArtifact):
8185
return [GenericMessageContent(artifact)]
8286
elif isinstance(artifact, ActionArtifact):

griptape/drivers/prompt/base_prompt_driver.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from griptape.common import (
1010
ActionCallDeltaMessageContent,
1111
ActionCallMessageContent,
12+
AudioDeltaMessageContent,
13+
AudioMessageContent,
1214
BaseDeltaMessageContent,
1315
DeltaMessage,
1416
Message,
@@ -19,6 +21,7 @@
1921
)
2022
from griptape.events import (
2123
ActionChunkEvent,
24+
AudioChunkEvent,
2225
EventBus,
2326
FinishPromptEvent,
2427
StartPromptEvent,
@@ -177,6 +180,8 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message:
177180
delta_contents[content.index] = [content]
178181
if isinstance(content, TextDeltaMessageContent):
179182
EventBus.publish_event(TextChunkEvent(token=content.text, index=content.index))
183+
elif isinstance(content, AudioDeltaMessageContent) and content.data is not None:
184+
EventBus.publish_event(AudioChunkEvent(data=content.data))
180185
elif isinstance(content, ActionCallDeltaMessageContent):
181186
EventBus.publish_event(
182187
ActionChunkEvent(
@@ -197,10 +202,13 @@ def __build_message(
197202
content = []
198203
for delta_content in delta_contents:
199204
text_deltas = [delta for delta in delta_content if isinstance(delta, TextDeltaMessageContent)]
205+
audio_deltas = [delta for delta in delta_content if isinstance(delta, AudioDeltaMessageContent)]
200206
action_deltas = [delta for delta in delta_content if isinstance(delta, ActionCallDeltaMessageContent)]
201207

202208
if text_deltas:
203209
content.append(TextMessageContent.from_deltas(text_deltas))
210+
if audio_deltas:
211+
content.append(AudioMessageContent.from_deltas(audio_deltas))
204212
if action_deltas:
205213
content.append(ActionCallMessageContent.from_deltas(action_deltas))
206214

griptape/drivers/prompt/openai_chat_prompt_driver.py

Lines changed: 78 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,22 @@
11
from __future__ import annotations
22

3+
import base64
34
import json
45
import logging
6+
import time
57
from typing import TYPE_CHECKING, Optional
68

79
import openai
810
from attrs import Factory, define, field
911
from schema import Schema
1012

11-
from griptape.artifacts import ActionArtifact, TextArtifact
13+
from griptape.artifacts import ActionArtifact, AudioArtifact, TextArtifact
1214
from griptape.common import (
1315
ActionCallDeltaMessageContent,
1416
ActionCallMessageContent,
1517
ActionResultMessageContent,
18+
AudioDeltaMessageContent,
19+
AudioMessageContent,
1620
BaseDeltaMessageContent,
1721
BaseMessageContent,
1822
DeltaMessage,
@@ -94,6 +98,10 @@ class OpenAiChatPromptDriver(BasePromptDriver):
9498
),
9599
kw_only=True,
96100
)
101+
modalities: list[str] = field(default=Factory(lambda: ["text"]), kw_only=True, metadata={"serializable": True})
102+
audio: dict = field(
103+
default=Factory(lambda: {"voice": "alloy", "format": "pcm16"}), kw_only=True, metadata={"serializable": True}
104+
)
97105
_client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
98106

99107
@lazy_property()
@@ -144,14 +152,18 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
144152
choice = chunk.choices[0]
145153
delta = choice.delta
146154

147-
yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(delta))
155+
content = self.__to_prompt_stack_delta_message_content(delta)
156+
if content is not None:
157+
yield DeltaMessage(content=content)
148158

149159
def _base_params(self, prompt_stack: PromptStack) -> dict:
150160
params = {
151161
"model": self.model,
152162
"temperature": self.temperature,
153163
"user": self.user,
154164
"seed": self.seed,
165+
"modalities": self.modalities,
166+
"audio": self.audio,
155167
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
156168
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
157169
**({"stream_options": {"include_usage": True}} if self.stream else {}),
@@ -196,45 +208,44 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
196208
openai_messages = []
197209

198210
for message in messages:
199-
# If the message only contains textual content we can send it as a single content.
200-
if message.is_text():
201-
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
202211
# Action results must be sent as separate messages.
203-
elif message.has_any_content_type(ActionResultMessageContent):
212+
213+
action_result_contents = message.get_content_type(ActionResultMessageContent)
214+
# Action results must be sent as separate messages.
215+
if action_result_contents:
204216
openai_messages.extend(
205217
{
206-
"role": self.__to_openai_role(message, action_result),
207-
"content": self.__to_openai_message_content(action_result),
208-
"tool_call_id": action_result.action.tag,
218+
"role": self.__to_openai_role(message, action_result_content),
219+
"content": self.__to_openai_message_content(action_result_content),
220+
"tool_call_id": action_result_content.action.tag,
209221
}
210-
for action_result in message.get_content_type(ActionResultMessageContent)
222+
for action_result_content in action_result_contents
211223
)
212224

213225
if message.has_any_content_type(TextMessageContent):
214226
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
215227
else:
216228
openai_message = {
217229
"role": self.__to_openai_role(message),
218-
"content": [
219-
self.__to_openai_message_content(content)
220-
for content in [
221-
content for content in message.content if not isinstance(content, ActionCallMessageContent)
222-
]
223-
],
230+
"content": [],
224231
}
232+
233+
for content in message.content:
234+
if isinstance(content, ActionCallMessageContent):
235+
if "tool_calls" not in openai_message:
236+
openai_message["tool_calls"] = []
237+
openai_message["tool_calls"].append(self.__to_openai_message_content(content))
238+
elif isinstance(content, AudioMessageContent) and message.is_assistant():
239+
openai_message["audio"] = {
240+
"id": content.artifact.meta["audio_id"],
241+
}
242+
else:
243+
openai_message["content"].append(self.__to_openai_message_content(content))
244+
225245
# Some OpenAi-compatible services don't accept an empty array for content
226246
if not openai_message["content"]:
227247
openai_message["content"] = ""
228248

229-
# Action calls must be attached to the message, not sent as content.
230-
action_call_content = [
231-
content for content in message.content if isinstance(content, ActionCallMessageContent)
232-
]
233-
if action_call_content:
234-
openai_message["tool_calls"] = [
235-
self.__to_openai_message_content(action_call) for action_call in action_call_content
236-
]
237-
238249
openai_messages.append(openai_message)
239250

240251
return openai_messages
@@ -272,6 +283,23 @@ def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict
272283
"type": "image_url",
273284
"image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"},
274285
}
286+
elif isinstance(content, AudioMessageContent):
287+
artifact = content.artifact
288+
289+
# We can't send the audio if it's expired.
290+
if int(time.time()) > artifact.meta.get("expires_at", float("inf")):
291+
return {
292+
"type": "text",
293+
"text": artifact.meta.get("transcript"),
294+
}
295+
else:
296+
return {
297+
"type": "input_audio",
298+
"input_audio": {
299+
"data": base64.b64encode(artifact.value).decode("utf-8"),
300+
"format": artifact.format,
301+
},
302+
}
275303
elif isinstance(content, ActionCallMessageContent):
276304
action = content.artifact.value
277305

@@ -290,6 +318,20 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
290318

291319
if response.content is not None:
292320
content.append(TextMessageContent(TextArtifact(response.content)))
321+
if response.audio is not None:
322+
content.append(
323+
AudioMessageContent(
324+
AudioArtifact(
325+
value=base64.b64decode(response.audio.data),
326+
format="wav",
327+
meta={
328+
"audio_id": response.audio.id,
329+
"transcript": response.audio.transcript,
330+
"expires_at": response.audio.expires_at,
331+
},
332+
)
333+
)
334+
)
293335
if response.tool_calls is not None:
294336
content.extend(
295337
[
@@ -309,7 +351,7 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
309351

310352
return content
311353

312-
def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> BaseDeltaMessageContent:
354+
def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> Optional[BaseDeltaMessageContent]:
313355
if content_delta.content is not None:
314356
return TextDeltaMessageContent(content_delta.content)
315357
elif content_delta.tool_calls is not None:
@@ -334,5 +376,13 @@ def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) ->
334376
raise ValueError(f"Unsupported tool call delta: {tool_call}")
335377
else:
336378
raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}")
337-
else:
338-
return TextDeltaMessageContent("")
379+
# OpenAi doesn't have types for audio deltas so we need to use hasattr and getattr.
380+
elif hasattr(content_delta, "audio") and getattr(content_delta, "audio") is not None:
381+
audio_chunk: dict = getattr(content_delta, "audio")
382+
return AudioDeltaMessageContent(
383+
id=audio_chunk.get("id"),
384+
data=audio_chunk.get("data"),
385+
expires_at=audio_chunk.get("expires_at"),
386+
transcript=audio_chunk.get("transcript"),
387+
)
388+
return None

griptape/events/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .finish_structure_run_event import FinishStructureRunEvent
1313
from .base_chunk_event import BaseChunkEvent
1414
from .text_chunk_event import TextChunkEvent
15+
from .audio_chunk_event import AudioChunkEvent
1516
from .action_chunk_event import ActionChunkEvent
1617
from .event_listener import EventListener
1718
from .start_image_generation_event import StartImageGenerationEvent
@@ -41,6 +42,7 @@
4142
"FinishStructureRunEvent",
4243
"BaseChunkEvent",
4344
"TextChunkEvent",
45+
"AudioChunkEvent",
4446
"ActionChunkEvent",
4547
"EventListener",
4648
"StartImageGenerationEvent",

0 commit comments

Comments
 (0)