1
1
from __future__ import annotations
2
2
3
+ import base64
3
4
import json
4
5
import logging
5
- from typing import TYPE_CHECKING , Optional
6
+ from typing import TYPE_CHECKING , Any , Optional
6
7
7
8
import openai
8
9
from attrs import Factory , define , field
9
10
from schema import Schema
10
11
11
- from griptape .artifacts import ActionArtifact , TextArtifact
12
+ from griptape .artifacts import ActionArtifact , AudioArtifact , TextArtifact
12
13
from griptape .common import (
13
14
ActionCallDeltaMessageContent ,
14
15
ActionCallMessageContent ,
15
16
ActionResultMessageContent ,
17
+ AudioDeltaMessageContent ,
18
+ AudioMessageContent ,
16
19
BaseDeltaMessageContent ,
17
20
BaseMessageContent ,
18
21
DeltaMessage ,
24
27
ToolAction ,
25
28
observable ,
26
29
)
30
+ from griptape .common .prompt_stack .contents .audio_transcript_delta_message_content import (
31
+ AudioTranscriptDeltaMessageContent ,
32
+ )
27
33
from griptape .configs .defaults_config import Defaults
28
34
from griptape .drivers .prompt import BasePromptDriver
29
35
from griptape .tokenizers import BaseTokenizer , OpenAiTokenizer
32
38
if TYPE_CHECKING :
33
39
from collections .abc import Iterator
34
40
35
- from openai .types .chat .chat_completion_chunk import ChoiceDelta
36
41
from openai .types .chat .chat_completion_message import ChatCompletionMessage
37
42
38
43
from griptape .drivers .prompt .base_prompt_driver import StructuredOutputStrategy
@@ -132,6 +137,8 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
132
137
result = self .client .chat .completions .create (** params , stream = True )
133
138
134
139
for chunk in result :
140
+ if chunk .choices is None :
141
+ continue
135
142
logger .debug (chunk .model_dump ())
136
143
if chunk .usage is not None :
137
144
yield DeltaMessage (
@@ -144,14 +151,18 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
144
151
choice = chunk .choices [0 ]
145
152
delta = choice .delta
146
153
147
- yield DeltaMessage (content = self .__to_prompt_stack_delta_message_content (delta ))
154
+ content = self .__to_prompt_stack_delta_message_content (delta )
155
+ if content is not None :
156
+ yield DeltaMessage (content = content )
148
157
149
158
def _base_params (self , prompt_stack : PromptStack ) -> dict :
150
159
params = {
151
160
"model" : self .model ,
152
161
"temperature" : self .temperature ,
153
162
"user" : self .user ,
154
163
"seed" : self .seed ,
164
+ "modalities" : ["text" , "audio" ],
165
+ "audio" : {"voice" : "alloy" , "format" : "pcm16" },
155
166
** ({"stop" : self .tokenizer .stop_sequences } if self .tokenizer .stop_sequences else {}),
156
167
** ({"max_tokens" : self .max_tokens } if self .max_tokens is not None else {}),
157
168
** ({"stream_options" : {"include_usage" : True }} if self .stream else {}),
@@ -196,45 +207,44 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
196
207
openai_messages = []
197
208
198
209
for message in messages :
199
- # If the message only contains textual content we can send it as a single content.
200
- if message .is_text ():
201
- openai_messages .append ({"role" : self .__to_openai_role (message ), "content" : message .to_text ()})
202
210
# Action results must be sent as separate messages.
203
- elif message .has_any_content_type (ActionResultMessageContent ):
211
+
212
+ action_result_contents = message .get_content_type (ActionResultMessageContent )
213
+ # Action results must be sent as separate messages.
214
+ if action_result_contents :
204
215
openai_messages .extend (
205
216
{
206
- "role" : self .__to_openai_role (message , action_result ),
207
- "content" : self .__to_openai_message_content (action_result ),
208
- "tool_call_id" : action_result .action .tag ,
217
+ "role" : self .__to_openai_role (message , action_result_content ),
218
+ "content" : self .__to_openai_message_content (action_result_content ),
219
+ "tool_call_id" : action_result_content .action .tag ,
209
220
}
210
- for action_result in message . get_content_type ( ActionResultMessageContent )
221
+ for action_result_content in action_result_contents
211
222
)
212
223
213
224
if message .has_any_content_type (TextMessageContent ):
214
225
openai_messages .append ({"role" : self .__to_openai_role (message ), "content" : message .to_text ()})
215
226
else :
216
227
openai_message = {
217
228
"role" : self .__to_openai_role (message ),
218
- "content" : [
219
- self .__to_openai_message_content (content )
220
- for content in [
221
- content for content in message .content if not isinstance (content , ActionCallMessageContent )
222
- ]
223
- ],
229
+ "content" : [],
224
230
}
231
+
232
+ for content in message .content :
233
+ if isinstance (content , ActionCallMessageContent ):
234
+ if "tool_calls" not in openai_message :
235
+ openai_message ["tool_calls" ] = []
236
+ openai_message ["tool_calls" ].append (self .__to_openai_message_content (content ))
237
+ elif isinstance (content , AudioMessageContent ) and message .is_assistant ():
238
+ openai_message ["audio" ] = {
239
+ "id" : content .artifact .meta ["audio_id" ],
240
+ }
241
+ else :
242
+ openai_message ["content" ].append (self .__to_openai_message_content (content ))
243
+
225
244
# Some OpenAi-compatible services don't accept an empty array for content
226
245
if not openai_message ["content" ]:
227
246
openai_message ["content" ] = ""
228
247
229
- # Action calls must be attached to the message, not sent as content.
230
- action_call_content = [
231
- content for content in message .content if isinstance (content , ActionCallMessageContent )
232
- ]
233
- if action_call_content :
234
- openai_message ["tool_calls" ] = [
235
- self .__to_openai_message_content (action_call ) for action_call in action_call_content
236
- ]
237
-
238
248
openai_messages .append (openai_message )
239
249
240
250
return openai_messages
@@ -272,6 +282,14 @@ def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict
272
282
"type" : "image_url" ,
273
283
"image_url" : {"url" : f"data:{ content .artifact .mime_type } ;base64,{ content .artifact .base64 } " },
274
284
}
285
+ elif isinstance (content , AudioMessageContent ):
286
+ return {
287
+ "type" : "input_audio" ,
288
+ "input_audio" : {
289
+ "data" : base64 .b64encode (content .artifact .value ).decode ("utf-8" ),
290
+ "format" : content .artifact .format ,
291
+ },
292
+ }
275
293
elif isinstance (content , ActionCallMessageContent ):
276
294
action = content .artifact .value
277
295
@@ -290,6 +308,19 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
290
308
291
309
if response .content is not None :
292
310
content .append (TextMessageContent (TextArtifact (response .content )))
311
+ if response .audio is not None :
312
+ content .append (
313
+ AudioMessageContent (
314
+ AudioArtifact (
315
+ value = base64 .b64decode (response .audio .data ),
316
+ format = "wav" ,
317
+ meta = {
318
+ "audio_id" : response .audio .id ,
319
+ "transcript" : response .audio .transcript ,
320
+ },
321
+ )
322
+ )
323
+ )
293
324
if response .tool_calls is not None :
294
325
content .extend (
295
326
[
@@ -309,7 +340,7 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
309
340
310
341
return content
311
342
312
- def __to_prompt_stack_delta_message_content (self , content_delta : ChoiceDelta ) -> BaseDeltaMessageContent :
343
+ def __to_prompt_stack_delta_message_content (self , content_delta : Any ) -> Optional [ BaseDeltaMessageContent ] :
313
344
if content_delta .content is not None :
314
345
return TextDeltaMessageContent (content_delta .content )
315
346
elif content_delta .tool_calls is not None :
@@ -334,5 +365,12 @@ def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) ->
334
365
raise ValueError (f"Unsupported tool call delta: { tool_call } " )
335
366
else :
336
367
raise ValueError (f"Unsupported tool call delta length: { len (tool_calls )} " )
337
- else :
338
- return TextDeltaMessageContent ("" )
368
+ elif hasattr (content_delta , "audio" ) and content_delta .audio is not None :
369
+ if "data" in content_delta .audio :
370
+ return AudioDeltaMessageContent (
371
+ id = content_delta .audio .get ("id" ),
372
+ data = base64 .b64decode (content_delta .audio ["data" ]),
373
+ )
374
+ elif "transcript" in content_delta .audio :
375
+ return AudioTranscriptDeltaMessageContent (text = content_delta .audio ["transcript" ])
376
+ return None
0 commit comments