Skip to content

Commit f023372

Browse files
committed
feat(drivers-prompt-openai): add audio input/output support to OpenAiChatPromptDriver
1 parent ce01b8b commit f023372

20 files changed

+535
-98
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2828
- Tool streaming support to `OllamaPromptDriver`.
2929
- `DateTimeTool.add_timedelta` and `DateTimeTool.get_datetime_diff` for basic datetime arithmetic.
3030
- Support for `pydantic.BaseModel`s anywhere `schema.Schema` is supported.
31+
- Support for `AudioArtifact` inputs/outputs in `OpenAiChatPromptDriver`.
3132

3233
### Changed
3334

griptape/common/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from .prompt_stack.contents.base_message_content import BaseMessageContent
55
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
66
from .prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent
7+
from .prompt_stack.contents.audio_delta_message_content import AudioDeltaMessageContent
78
from .prompt_stack.contents.text_message_content import TextMessageContent
89
from .prompt_stack.contents.image_message_content import ImageMessageContent
10+
from .prompt_stack.contents.audio_message_content import AudioMessageContent
911
from .prompt_stack.contents.action_call_delta_message_content import ActionCallDeltaMessageContent
1012
from .prompt_stack.contents.action_call_message_content import ActionCallMessageContent
1113
from .prompt_stack.contents.action_result_message_content import ActionResultMessageContent
@@ -30,8 +32,10 @@
3032
"DeltaMessage",
3133
"Message",
3234
"TextDeltaMessageContent",
35+
"AudioDeltaMessageContent",
3336
"TextMessageContent",
3437
"ImageMessageContent",
38+
"AudioMessageContent",
3539
"GenericMessageContent",
3640
"ActionCallDeltaMessageContent",
3741
"ActionCallMessageContent",
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import Optional

from attrs import define, field

from griptape.common import BaseDeltaMessageContent


@define(kw_only=True)
class AudioDeltaMessageContent(BaseDeltaMessageContent):
    """A delta message content for audio data.

    Attributes:
        id: The ID of the audio data.
        data: Base64 encoded audio data.
        transcript: The transcript of the audio data.
        expires_at: The Unix timestamp (in seconds) for when this audio data will no longer be accessible.
    """

    # Every field is optional because a single streaming delta typically carries
    # only a subset of them (e.g. just `data`, or just `transcript`).
    id: Optional[str] = field(default=None, metadata={"serializable": True})
    data: Optional[str] = field(default=None, metadata={"serializable": True})
    transcript: Optional[str] = field(default=None, metadata={"serializable": True})
    expires_at: Optional[int] = field(default=None, metadata={"serializable": True})
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
from typing import TYPE_CHECKING
5+
6+
from attrs import define, field
7+
8+
from griptape.artifacts import AudioArtifact
9+
from griptape.common import (
10+
AudioDeltaMessageContent,
11+
BaseDeltaMessageContent,
12+
BaseMessageContent,
13+
)
14+
15+
if TYPE_CHECKING:
16+
from collections.abc import Sequence
17+
18+
19+
@define
20+
class AudioMessageContent(BaseMessageContent):
21+
artifact: AudioArtifact = field(metadata={"serializable": True})
22+
23+
@classmethod
24+
def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> AudioMessageContent:
25+
audio_deltas = [delta for delta in deltas if isinstance(delta, AudioDeltaMessageContent)]
26+
audio_data = [delta.data for delta in audio_deltas if delta.data is not None]
27+
transcript_data = [delta.transcript for delta in audio_deltas if delta.transcript is not None]
28+
expires_at = next(delta.expires_at for delta in audio_deltas if delta.expires_at is not None)
29+
audio_id = next(delta.id for delta in audio_deltas if delta.id is not None)
30+
31+
audio_transcript = "".join(data for data in transcript_data)
32+
33+
artifact = AudioArtifact(
34+
value=b"".join(base64.b64decode(data) for data in audio_data),
35+
format="wav",
36+
meta={
37+
"audio_id": audio_id,
38+
"expires_at": expires_at,
39+
"transcript": audio_transcript,
40+
},
41+
)
42+
43+
return cls(artifact=artifact)

griptape/common/prompt_stack/prompt_stack.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from griptape.artifacts import (
1010
ActionArtifact,
11+
AudioArtifact,
1112
BaseArtifact,
1213
GenericArtifact,
1314
ImageArtifact,
@@ -17,6 +18,7 @@
1718
from griptape.common import (
1819
ActionCallMessageContent,
1920
ActionResultMessageContent,
21+
AudioMessageContent,
2022
BaseMessageContent,
2123
GenericMessageContent,
2224
ImageMessageContent,
@@ -91,6 +93,8 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage
9193
return [TextMessageContent(artifact)]
9294
elif isinstance(artifact, ImageArtifact):
9395
return [ImageMessageContent(artifact)]
96+
elif isinstance(artifact, AudioArtifact):
97+
return [AudioMessageContent(artifact)]
9498
elif isinstance(artifact, GenericArtifact):
9599
return [GenericMessageContent(artifact)]
96100
elif isinstance(artifact, ActionArtifact):

griptape/drivers/prompt/base_prompt_driver.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from griptape.common import (
1010
ActionCallDeltaMessageContent,
1111
ActionCallMessageContent,
12+
AudioDeltaMessageContent,
13+
AudioMessageContent,
1214
BaseDeltaMessageContent,
1315
DeltaMessage,
1416
Message,
@@ -19,6 +21,7 @@
1921
)
2022
from griptape.events import (
2123
ActionChunkEvent,
24+
AudioChunkEvent,
2225
EventBus,
2326
FinishPromptEvent,
2427
StartPromptEvent,
@@ -177,6 +180,8 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message:
177180
delta_contents[content.index] = [content]
178181
if isinstance(content, TextDeltaMessageContent):
179182
EventBus.publish_event(TextChunkEvent(token=content.text, index=content.index))
183+
elif isinstance(content, AudioDeltaMessageContent) and content.data is not None:
184+
EventBus.publish_event(AudioChunkEvent(data=content.data))
180185
elif isinstance(content, ActionCallDeltaMessageContent):
181186
EventBus.publish_event(
182187
ActionChunkEvent(
@@ -197,10 +202,13 @@ def __build_message(
197202
content = []
198203
for delta_content in delta_contents:
199204
text_deltas = [delta for delta in delta_content if isinstance(delta, TextDeltaMessageContent)]
205+
audio_deltas = [delta for delta in delta_content if isinstance(delta, AudioDeltaMessageContent)]
200206
action_deltas = [delta for delta in delta_content if isinstance(delta, ActionCallDeltaMessageContent)]
201207

202208
if text_deltas:
203209
content.append(TextMessageContent.from_deltas(text_deltas))
210+
if audio_deltas:
211+
content.append(AudioMessageContent.from_deltas(audio_deltas))
204212
if action_deltas:
205213
content.append(ActionCallMessageContent.from_deltas(action_deltas))
206214

griptape/drivers/prompt/openai_chat_prompt_driver.py

Lines changed: 81 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
from __future__ import annotations
22

3+
import base64
34
import json
45
import logging
6+
import time
57
from typing import TYPE_CHECKING, Optional
68

79
import openai
810
from attrs import Factory, define, field
911

10-
from griptape.artifacts import ActionArtifact, TextArtifact
12+
from griptape.artifacts import ActionArtifact, AudioArtifact, TextArtifact
1113
from griptape.common import (
1214
ActionCallDeltaMessageContent,
1315
ActionCallMessageContent,
1416
ActionResultMessageContent,
17+
AudioDeltaMessageContent,
18+
AudioMessageContent,
1519
BaseDeltaMessageContent,
1620
BaseMessageContent,
1721
DeltaMessage,
@@ -93,6 +97,10 @@ class OpenAiChatPromptDriver(BasePromptDriver):
9397
),
9498
kw_only=True,
9599
)
100+
modalities: list[str] = field(default=Factory(lambda: ["text"]), kw_only=True, metadata={"serializable": True})
101+
audio: dict = field(
102+
default=Factory(lambda: {"voice": "alloy", "format": "pcm16"}), kw_only=True, metadata={"serializable": True}
103+
)
96104
_client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
97105

98106
@lazy_property()
@@ -143,14 +151,19 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
143151
choice = chunk.choices[0]
144152
delta = choice.delta
145153

146-
yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(delta))
154+
content = self.__to_prompt_stack_delta_message_content(delta)
155+
156+
if content is not None:
157+
yield DeltaMessage(content=content)
147158

148159
def _base_params(self, prompt_stack: PromptStack) -> dict:
149160
params = {
150161
"model": self.model,
151162
"temperature": self.temperature,
152163
"user": self.user,
153164
"seed": self.seed,
165+
"modalities": self.modalities,
166+
"audio": self.audio,
154167
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
155168
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
156169
**({"stream_options": {"include_usage": True}} if self.stream else {}),
@@ -196,44 +209,47 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
196209

197210
for message in messages:
198211
# If the message only contains textual content we can send it as a single content.
199-
if message.is_text():
212+
if message.has_all_content_type(TextMessageContent):
200213
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
201214
# Action results must be sent as separate messages.
202-
elif message.has_any_content_type(ActionResultMessageContent):
215+
elif action_result_contents := message.get_content_type(ActionResultMessageContent):
203216
openai_messages.extend(
204217
{
205-
"role": self.__to_openai_role(message, action_result),
206-
"content": self.__to_openai_message_content(action_result),
207-
"tool_call_id": action_result.action.tag,
218+
"role": self.__to_openai_role(message, action_result_content),
219+
"content": self.__to_openai_message_content(action_result_content),
220+
"tool_call_id": action_result_content.action.tag,
208221
}
209-
for action_result in message.get_content_type(ActionResultMessageContent)
222+
for action_result_content in action_result_contents
210223
)
211224

212225
if message.has_any_content_type(TextMessageContent):
213226
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
214227
else:
215228
openai_message = {
216229
"role": self.__to_openai_role(message),
217-
"content": [
218-
self.__to_openai_message_content(content)
219-
for content in [
220-
content for content in message.content if not isinstance(content, ActionCallMessageContent)
221-
]
222-
],
230+
"content": [],
223231
}
232+
233+
for content in message.content:
234+
if isinstance(content, ActionCallMessageContent):
235+
if "tool_calls" not in openai_message:
236+
openai_message["tool_calls"] = []
237+
openai_message["tool_calls"].append(self.__to_openai_message_content(content))
238+
elif (
239+
isinstance(content, AudioMessageContent)
240+
and message.is_assistant()
241+
and time.time() < content.artifact.meta.get("expires_at", float("inf"))
242+
):
243+
openai_message["audio"] = {
244+
"id": content.artifact.meta["audio_id"],
245+
}
246+
else:
247+
openai_message["content"].append(self.__to_openai_message_content(content))
248+
224249
# Some OpenAi-compatible services don't accept an empty array for content
225250
if not openai_message["content"]:
226251
openai_message["content"] = ""
227252

228-
# Action calls must be attached to the message, not sent as content.
229-
action_call_content = [
230-
content for content in message.content if isinstance(content, ActionCallMessageContent)
231-
]
232-
if action_call_content:
233-
openai_message["tool_calls"] = [
234-
self.__to_openai_message_content(action_call) for action_call in action_call_content
235-
]
236-
237253
openai_messages.append(openai_message)
238254

239255
return openai_messages
@@ -271,6 +287,23 @@ def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict
271287
"type": "image_url",
272288
"image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"},
273289
}
290+
elif isinstance(content, AudioMessageContent):
291+
artifact = content.artifact
292+
293+
# We can't send the audio if it's expired.
294+
if int(time.time()) > artifact.meta.get("expires_at", float("inf")):
295+
return {
296+
"type": "text",
297+
"text": artifact.meta.get("transcript"),
298+
}
299+
else:
300+
return {
301+
"type": "input_audio",
302+
"input_audio": {
303+
"data": base64.b64encode(artifact.value).decode("utf-8"),
304+
"format": artifact.format,
305+
},
306+
}
274307
elif isinstance(content, ActionCallMessageContent):
275308
action = content.artifact.value
276309

@@ -289,6 +322,20 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
289322

290323
if response.content is not None:
291324
content.append(TextMessageContent(TextArtifact(response.content)))
325+
if response.audio is not None:
326+
content.append(
327+
AudioMessageContent(
328+
AudioArtifact(
329+
value=base64.b64decode(response.audio.data),
330+
format="wav",
331+
meta={
332+
"audio_id": response.audio.id,
333+
"transcript": response.audio.transcript,
334+
"expires_at": response.audio.expires_at,
335+
},
336+
)
337+
)
338+
)
292339
if response.tool_calls is not None:
293340
content.extend(
294341
[
@@ -308,7 +355,7 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
308355

309356
return content
310357

311-
def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> BaseDeltaMessageContent:
358+
def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> Optional[BaseDeltaMessageContent]:
312359
if content_delta.content is not None:
313360
return TextDeltaMessageContent(content_delta.content)
314361
elif content_delta.tool_calls is not None:
@@ -333,5 +380,13 @@ def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) ->
333380
raise ValueError(f"Unsupported tool call delta: {tool_call}")
334381
else:
335382
raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}")
336-
else:
337-
return TextDeltaMessageContent("")
383+
# OpenAi doesn't have types for audio deltas so we need to use hasattr and getattr.
384+
elif hasattr(content_delta, "audio") and getattr(content_delta, "audio") is not None:
385+
audio_chunk: dict = getattr(content_delta, "audio")
386+
return AudioDeltaMessageContent(
387+
id=audio_chunk.get("id"),
388+
data=audio_chunk.get("data"),
389+
expires_at=audio_chunk.get("expires_at"),
390+
transcript=audio_chunk.get("transcript"),
391+
)
392+
return None

griptape/events/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .finish_structure_run_event import FinishStructureRunEvent
1313
from .base_chunk_event import BaseChunkEvent
1414
from .text_chunk_event import TextChunkEvent
15+
from .audio_chunk_event import AudioChunkEvent
1516
from .action_chunk_event import ActionChunkEvent
1617
from .event_listener import EventListener
1718
from .start_image_generation_event import StartImageGenerationEvent
@@ -41,6 +42,7 @@
4142
"FinishStructureRunEvent",
4243
"BaseChunkEvent",
4344
"TextChunkEvent",
45+
"AudioChunkEvent",
4446
"ActionChunkEvent",
4547
"EventListener",
4648
"StartImageGenerationEvent",

griptape/events/audio_chunk_event.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
from attrs import define, field

from griptape.events.base_chunk_event import BaseChunkEvent


@define
class AudioChunkEvent(BaseChunkEvent):
    """Stores a chunk of audio data.

    Attributes:
        data: Base64 encoded audio data.
    """

    data: str = field(kw_only=True, metadata={"serializable": True})

    def __str__(self) -> str:
        # Return the raw base64 payload so stream handlers can print or
        # accumulate the chunk directly.
        return self.data

0 commit comments

Comments
 (0)