Skip to content

Commit 375628e

Browse files
committed
WIP
1 parent 9162c5e commit 375628e

File tree

11 files changed

+170
-33
lines changed

11 files changed

+170
-33
lines changed

griptape/common/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
from .prompt_stack.contents.base_message_content import BaseMessageContent
55
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
66
from .prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent
7+
from .prompt_stack.contents.audio_delta_message_content import AudioDeltaMessageContent
8+
from .prompt_stack.contents.audio_transcript_delta_message_content import AudioTranscriptDeltaMessageContent
79
from .prompt_stack.contents.text_message_content import TextMessageContent
810
from .prompt_stack.contents.image_message_content import ImageMessageContent
11+
from .prompt_stack.contents.audio_message_content import AudioMessageContent
912
from .prompt_stack.contents.action_call_delta_message_content import ActionCallDeltaMessageContent
1013
from .prompt_stack.contents.action_call_message_content import ActionCallMessageContent
1114
from .prompt_stack.contents.action_result_message_content import ActionResultMessageContent
@@ -30,8 +33,11 @@
3033
"DeltaMessage",
3134
"Message",
3235
"TextDeltaMessageContent",
36+
"AudioDeltaMessageContent",
37+
"AudioTranscriptDeltaMessageContent",
3338
"TextMessageContent",
3439
"ImageMessageContent",
40+
"AudioMessageContent",
3541
"GenericMessageContent",
3642
"ActionCallDeltaMessageContent",
3743
"ActionCallMessageContent",
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from __future__ import annotations
2+
3+
from typing import Optional
4+
5+
from attrs import define, field
6+
7+
from griptape.common import BaseDeltaMessageContent
8+
9+
10+
@define
class AudioDeltaMessageContent(BaseDeltaMessageContent):
    """Delta message content holding one piece of audio data.

    Attributes:
        id: Optional identifier string; defaults to None.
        data: Bytes payload of this delta (required, keyword-only).
        transcript: Optional transcript text; defaults to None.
    """

    id: Optional[str] = field(
        kw_only=True,
        default=None,
        metadata={"serializable": True},
    )
    data: bytes = field(
        kw_only=True,
        metadata={"serializable": True},
    )
    transcript: Optional[str] = field(
        kw_only=True,
        default=None,
        metadata={"serializable": True},
    )
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from attrs import define, field
6+
7+
from griptape.artifacts import AudioArtifact
8+
from griptape.common import (
9+
AudioDeltaMessageContent,
10+
AudioTranscriptDeltaMessageContent,
11+
BaseDeltaMessageContent,
12+
BaseMessageContent,
13+
)
14+
15+
if TYPE_CHECKING:
16+
from collections.abc import Sequence
17+
18+
19+
@define
class AudioMessageContent(BaseMessageContent):
    """Message content that wraps a single audio artifact.

    Attributes:
        artifact: The audio artifact carried by this content.
    """

    artifact: AudioArtifact = field(metadata={"serializable": True})

    @classmethod
    def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> AudioMessageContent:
        """Build one audio content from a sequence of streamed deltas.

        Audio deltas contribute their bytes, concatenated in order, and
        audio-transcript deltas contribute their text, concatenated in order.
        Deltas of any other type are ignored.

        Args:
            deltas: Delta contents to merge.

        Returns:
            A new instance whose artifact holds the joined audio bytes, with
            ``audio_id`` and ``transcript`` stored in the artifact's ``meta``.
        """
        audio_deltas = [delta for delta in deltas if isinstance(delta, AudioDeltaMessageContent)]
        audio_transcript_deltas = [delta for delta in deltas if isinstance(delta, AudioTranscriptDeltaMessageContent)]

        # Use the first delta that actually carries an id. Indexing element 0
        # directly would raise IndexError when no audio delta is present and
        # would miss an id that only arrives on a later delta.
        audio_id = next((delta.id for delta in audio_deltas if delta.id is not None), None)

        audio_transcript = "".join(delta.text for delta in audio_transcript_deltas)

        # NOTE(review): the format is fixed to "wav" here regardless of what
        # encoding the deltas actually contain -- confirm against the producer.
        artifact = AudioArtifact(
            value=b"".join(delta.data for delta in audio_deltas),
            format="wav",
            meta={
                "audio_id": audio_id,
                "transcript": audio_transcript,
            },
        )

        return cls(artifact=artifact)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from __future__ import annotations
2+
3+
from attrs import define, field
4+
5+
from griptape.common import BaseDeltaMessageContent
6+
7+
8+
@define
class AudioTranscriptDeltaMessageContent(BaseDeltaMessageContent):
    """Delta message content holding one piece of audio transcript text.

    Attributes:
        text: The transcript text carried by this delta.
    """

    text: str = field(
        metadata={"serializable": True},
    )

griptape/common/prompt_stack/messages/message.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ def has_any_content_type(self, content_type: type[T]) -> bool:
3737
def get_content_type(self, content_type: type[T]) -> list[T]:
3838
return [content for content in self.content if isinstance(content, content_type)]
3939

40+
def exclude_content_type(self, content_type: type[T] | tuple[type[T], ...]) -> list[BaseMessageContent]:
    """Return the content items that are NOT instances of ``content_type``.

    Args:
        content_type: A single class, or a tuple of classes of any length,
            to filter out. The value is passed straight to ``isinstance``,
            so it accepts the same forms.

    Returns:
        The items of ``self.content``, in their original order, that do not
        match ``content_type``.
    """
    return [content for content in self.content if not isinstance(content, content_type)]
42+
4043
def is_text(self) -> bool:
4144
return all(isinstance(content, TextMessageContent) for content in self.content)
4245

griptape/common/prompt_stack/prompt_stack.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from griptape.artifacts import (
88
ActionArtifact,
9+
AudioArtifact,
910
BaseArtifact,
1011
GenericArtifact,
1112
ImageArtifact,
@@ -15,6 +16,7 @@
1516
from griptape.common import (
1617
ActionCallMessageContent,
1718
ActionResultMessageContent,
19+
AudioMessageContent,
1820
BaseMessageContent,
1921
GenericMessageContent,
2022
ImageMessageContent,
@@ -77,6 +79,8 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage
7779
return [TextMessageContent(artifact)]
7880
elif isinstance(artifact, ImageArtifact):
7981
return [ImageMessageContent(artifact)]
82+
elif isinstance(artifact, AudioArtifact):
83+
return [AudioMessageContent(artifact)]
8084
elif isinstance(artifact, GenericArtifact):
8185
return [GenericMessageContent(artifact)]
8286
elif isinstance(artifact, ActionArtifact):

griptape/drivers/prompt/base_prompt_driver.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from griptape.common import (
1010
ActionCallDeltaMessageContent,
1111
ActionCallMessageContent,
12+
AudioDeltaMessageContent,
13+
AudioMessageContent,
1214
BaseDeltaMessageContent,
1315
DeltaMessage,
1416
Message,
@@ -19,6 +21,7 @@
1921
)
2022
from griptape.events import (
2123
ActionChunkEvent,
24+
AudioChunkEvent,
2225
EventBus,
2326
FinishPromptEvent,
2427
StartPromptEvent,
@@ -177,6 +180,8 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message:
177180
delta_contents[content.index] = [content]
178181
if isinstance(content, TextDeltaMessageContent):
179182
EventBus.publish_event(TextChunkEvent(token=content.text, index=content.index))
183+
elif isinstance(content, AudioDeltaMessageContent):
184+
EventBus.publish_event(AudioChunkEvent(token=content.data))
180185
elif isinstance(content, ActionCallDeltaMessageContent):
181186
EventBus.publish_event(
182187
ActionChunkEvent(
@@ -197,10 +202,13 @@ def __build_message(
197202
content = []
198203
for delta_content in delta_contents:
199204
text_deltas = [delta for delta in delta_content if isinstance(delta, TextDeltaMessageContent)]
205+
audio_deltas = [delta for delta in delta_content if isinstance(delta, AudioDeltaMessageContent)]
200206
action_deltas = [delta for delta in delta_content if isinstance(delta, ActionCallDeltaMessageContent)]
201207

202208
if text_deltas:
203209
content.append(TextMessageContent.from_deltas(text_deltas))
210+
if audio_deltas:
211+
content.append(AudioMessageContent.from_deltas(audio_deltas))
204212
if action_deltas:
205213
content.append(ActionCallMessageContent.from_deltas(action_deltas))
206214

griptape/drivers/prompt/openai_chat_prompt_driver.py

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
from __future__ import annotations
22

3+
import base64
34
import json
45
import logging
5-
from typing import TYPE_CHECKING, Optional
6+
from typing import TYPE_CHECKING, Any, Optional
67

78
import openai
89
from attrs import Factory, define, field
910
from schema import Schema
1011

11-
from griptape.artifacts import ActionArtifact, TextArtifact
12+
from griptape.artifacts import ActionArtifact, AudioArtifact, TextArtifact
1213
from griptape.common import (
1314
ActionCallDeltaMessageContent,
1415
ActionCallMessageContent,
1516
ActionResultMessageContent,
17+
AudioDeltaMessageContent,
18+
AudioMessageContent,
1619
BaseDeltaMessageContent,
1720
BaseMessageContent,
1821
DeltaMessage,
@@ -24,6 +27,9 @@
2427
ToolAction,
2528
observable,
2629
)
30+
from griptape.common.prompt_stack.contents.audio_transcript_delta_message_content import (
31+
AudioTranscriptDeltaMessageContent,
32+
)
2733
from griptape.configs.defaults_config import Defaults
2834
from griptape.drivers.prompt import BasePromptDriver
2935
from griptape.tokenizers import BaseTokenizer, OpenAiTokenizer
@@ -32,7 +38,6 @@
3238
if TYPE_CHECKING:
3339
from collections.abc import Iterator
3440

35-
from openai.types.chat.chat_completion_chunk import ChoiceDelta
3641
from openai.types.chat.chat_completion_message import ChatCompletionMessage
3742

3843
from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
@@ -132,6 +137,8 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
132137
result = self.client.chat.completions.create(**params, stream=True)
133138

134139
for chunk in result:
140+
if chunk.choices is None:
141+
continue
135142
logger.debug(chunk.model_dump())
136143
if chunk.usage is not None:
137144
yield DeltaMessage(
@@ -144,14 +151,18 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
144151
choice = chunk.choices[0]
145152
delta = choice.delta
146153

147-
yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(delta))
154+
content = self.__to_prompt_stack_delta_message_content(delta)
155+
if content is not None:
156+
yield DeltaMessage(content=content)
148157

149158
def _base_params(self, prompt_stack: PromptStack) -> dict:
150159
params = {
151160
"model": self.model,
152161
"temperature": self.temperature,
153162
"user": self.user,
154163
"seed": self.seed,
164+
"modalities": ["text", "audio"],
165+
"audio": {"voice": "alloy", "format": "pcm16"},
155166
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
156167
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
157168
**({"stream_options": {"include_usage": True}} if self.stream else {}),
@@ -196,45 +207,44 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
196207
openai_messages = []
197208

198209
for message in messages:
199-
# If the message only contains textual content we can send it as a single content.
200-
if message.is_text():
201-
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
202210
# Action results must be sent as separate messages.
203-
elif message.has_any_content_type(ActionResultMessageContent):
211+
212+
action_result_contents = message.get_content_type(ActionResultMessageContent)
213+
# Action results must be sent as separate messages.
214+
if action_result_contents:
204215
openai_messages.extend(
205216
{
206-
"role": self.__to_openai_role(message, action_result),
207-
"content": self.__to_openai_message_content(action_result),
208-
"tool_call_id": action_result.action.tag,
217+
"role": self.__to_openai_role(message, action_result_content),
218+
"content": self.__to_openai_message_content(action_result_content),
219+
"tool_call_id": action_result_content.action.tag,
209220
}
210-
for action_result in message.get_content_type(ActionResultMessageContent)
221+
for action_result_content in action_result_contents
211222
)
212223

213224
if message.has_any_content_type(TextMessageContent):
214225
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
215226
else:
216227
openai_message = {
217228
"role": self.__to_openai_role(message),
218-
"content": [
219-
self.__to_openai_message_content(content)
220-
for content in [
221-
content for content in message.content if not isinstance(content, ActionCallMessageContent)
222-
]
223-
],
229+
"content": [],
224230
}
231+
232+
for content in message.content:
233+
if isinstance(content, ActionCallMessageContent):
234+
if "tool_calls" not in openai_message:
235+
openai_message["tool_calls"] = []
236+
openai_message["tool_calls"].append(self.__to_openai_message_content(content))
237+
elif isinstance(content, AudioMessageContent) and message.is_assistant():
238+
openai_message["audio"] = {
239+
"id": content.artifact.meta["audio_id"],
240+
}
241+
else:
242+
openai_message["content"].append(self.__to_openai_message_content(content))
243+
225244
# Some OpenAi-compatible services don't accept an empty array for content
226245
if not openai_message["content"]:
227246
openai_message["content"] = ""
228247

229-
# Action calls must be attached to the message, not sent as content.
230-
action_call_content = [
231-
content for content in message.content if isinstance(content, ActionCallMessageContent)
232-
]
233-
if action_call_content:
234-
openai_message["tool_calls"] = [
235-
self.__to_openai_message_content(action_call) for action_call in action_call_content
236-
]
237-
238248
openai_messages.append(openai_message)
239249

240250
return openai_messages
@@ -272,6 +282,14 @@ def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict
272282
"type": "image_url",
273283
"image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"},
274284
}
285+
elif isinstance(content, AudioMessageContent):
286+
return {
287+
"type": "input_audio",
288+
"input_audio": {
289+
"data": base64.b64encode(content.artifact.value).decode("utf-8"),
290+
"format": content.artifact.format,
291+
},
292+
}
275293
elif isinstance(content, ActionCallMessageContent):
276294
action = content.artifact.value
277295

@@ -290,6 +308,19 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
290308

291309
if response.content is not None:
292310
content.append(TextMessageContent(TextArtifact(response.content)))
311+
if response.audio is not None:
312+
content.append(
313+
AudioMessageContent(
314+
AudioArtifact(
315+
value=base64.b64decode(response.audio.data),
316+
format="wav",
317+
meta={
318+
"audio_id": response.audio.id,
319+
"transcript": response.audio.transcript,
320+
},
321+
)
322+
)
323+
)
293324
if response.tool_calls is not None:
294325
content.extend(
295326
[
@@ -309,7 +340,7 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
309340

310341
return content
311342

312-
def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> BaseDeltaMessageContent:
343+
def __to_prompt_stack_delta_message_content(self, content_delta: Any) -> Optional[BaseDeltaMessageContent]:
313344
if content_delta.content is not None:
314345
return TextDeltaMessageContent(content_delta.content)
315346
elif content_delta.tool_calls is not None:
@@ -334,5 +365,12 @@ def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) ->
334365
raise ValueError(f"Unsupported tool call delta: {tool_call}")
335366
else:
336367
raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}")
337-
else:
338-
return TextDeltaMessageContent("")
368+
elif hasattr(content_delta, "audio") and content_delta.audio is not None:
369+
if "data" in content_delta.audio:
370+
return AudioDeltaMessageContent(
371+
id=content_delta.audio.get("id"),
372+
data=base64.b64decode(content_delta.audio["data"]),
373+
)
374+
elif "transcript" in content_delta.audio:
375+
return AudioTranscriptDeltaMessageContent(text=content_delta.audio["transcript"])
376+
return None

griptape/events/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .finish_structure_run_event import FinishStructureRunEvent
1313
from .base_chunk_event import BaseChunkEvent
1414
from .text_chunk_event import TextChunkEvent
15+
from .audio_chunk_event import AudioChunkEvent
1516
from .action_chunk_event import ActionChunkEvent
1617
from .event_listener import EventListener
1718
from .start_image_generation_event import StartImageGenerationEvent
@@ -41,6 +42,7 @@
4142
"FinishStructureRunEvent",
4243
"BaseChunkEvent",
4344
"TextChunkEvent",
45+
"AudioChunkEvent",
4446
"ActionChunkEvent",
4547
"EventListener",
4648
"StartImageGenerationEvent",

griptape/events/audio_chunk_event.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from attrs import define, field
2+
3+
from griptape.events.base_chunk_event import BaseChunkEvent
4+
5+
6+
@define
class AudioChunkEvent(BaseChunkEvent):
    """Chunk event carrying a piece of audio as bytes.

    Attributes:
        token: The bytes payload of this chunk.
    """

    token: bytes = field(kw_only=True, metadata={"serializable": True})

    def __str__(self) -> str:
        """Return the chunk's bytes as base64 text.

        Audio bytes are generally not valid UTF-8, so decoding them directly
        raises ``UnicodeDecodeError``. Base64 is lossless and always yields a
        valid string.
        """
        # Local import keeps this block self-contained; the module does not
        # otherwise import base64.
        import base64

        return base64.b64encode(self.token).decode("ascii")

0 commit comments

Comments
 (0)