1
1
from __future__ import annotations
2
2
3
+ import base64
3
4
import json
4
5
import logging
6
+ import time
5
7
from typing import TYPE_CHECKING , Optional
6
8
7
9
import openai
8
10
from attrs import Factory , define , field
9
11
10
- from griptape .artifacts import ActionArtifact , TextArtifact
12
+ from griptape .artifacts import ActionArtifact , AudioArtifact , TextArtifact
11
13
from griptape .common import (
12
14
ActionCallDeltaMessageContent ,
13
15
ActionCallMessageContent ,
14
16
ActionResultMessageContent ,
17
+ AudioDeltaMessageContent ,
18
+ AudioMessageContent ,
15
19
BaseDeltaMessageContent ,
16
20
BaseMessageContent ,
17
21
DeltaMessage ,
@@ -93,6 +97,10 @@ class OpenAiChatPromptDriver(BasePromptDriver):
93
97
),
94
98
kw_only = True ,
95
99
)
100
+ modalities : list [str ] = field (default = Factory (lambda : ["text" ]), kw_only = True , metadata = {"serializable" : True })
101
+ audio : dict = field (
102
+ default = Factory (lambda : {"voice" : "alloy" , "format" : "pcm16" }), kw_only = True , metadata = {"serializable" : True }
103
+ )
96
104
_client : openai .OpenAI = field (default = None , kw_only = True , alias = "client" , metadata = {"serializable" : False })
97
105
98
106
@lazy_property ()
@@ -143,14 +151,19 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
143
151
choice = chunk .choices [0 ]
144
152
delta = choice .delta
145
153
146
- yield DeltaMessage (content = self .__to_prompt_stack_delta_message_content (delta ))
154
+ content = self .__to_prompt_stack_delta_message_content (delta )
155
+
156
+ if content is not None :
157
+ yield DeltaMessage (content = content )
147
158
148
159
def _base_params (self , prompt_stack : PromptStack ) -> dict :
149
160
params = {
150
161
"model" : self .model ,
151
162
"temperature" : self .temperature ,
152
163
"user" : self .user ,
153
164
"seed" : self .seed ,
165
+ "modalities" : self .modalities ,
166
+ "audio" : self .audio ,
154
167
** ({"stop" : self .tokenizer .stop_sequences } if self .tokenizer .stop_sequences else {}),
155
168
** ({"max_tokens" : self .max_tokens } if self .max_tokens is not None else {}),
156
169
** ({"stream_options" : {"include_usage" : True }} if self .stream else {}),
@@ -196,44 +209,47 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
196
209
197
210
for message in messages :
198
211
# If the message only contains textual content we can send it as a single content.
199
- if message .is_text ( ):
212
+ if message .has_all_content_type ( TextMessageContent ):
200
213
openai_messages .append ({"role" : self .__to_openai_role (message ), "content" : message .to_text ()})
201
214
# Action results must be sent as separate messages.
202
- elif message .has_any_content_type (ActionResultMessageContent ):
215
+ elif action_result_contents := message .get_content_type (ActionResultMessageContent ):
203
216
openai_messages .extend (
204
217
{
205
- "role" : self .__to_openai_role (message , action_result ),
206
- "content" : self .__to_openai_message_content (action_result ),
207
- "tool_call_id" : action_result .action .tag ,
218
+ "role" : self .__to_openai_role (message , action_result_content ),
219
+ "content" : self .__to_openai_message_content (action_result_content ),
220
+ "tool_call_id" : action_result_content .action .tag ,
208
221
}
209
- for action_result in message . get_content_type ( ActionResultMessageContent )
222
+ for action_result_content in action_result_contents
210
223
)
211
224
212
225
if message .has_any_content_type (TextMessageContent ):
213
226
openai_messages .append ({"role" : self .__to_openai_role (message ), "content" : message .to_text ()})
214
227
else :
215
228
openai_message = {
216
229
"role" : self .__to_openai_role (message ),
217
- "content" : [
218
- self .__to_openai_message_content (content )
219
- for content in [
220
- content for content in message .content if not isinstance (content , ActionCallMessageContent )
221
- ]
222
- ],
230
+ "content" : [],
223
231
}
232
+
233
+ for content in message .content :
234
+ if isinstance (content , ActionCallMessageContent ):
235
+ if "tool_calls" not in openai_message :
236
+ openai_message ["tool_calls" ] = []
237
+ openai_message ["tool_calls" ].append (self .__to_openai_message_content (content ))
238
+ elif (
239
+ isinstance (content , AudioMessageContent )
240
+ and message .is_assistant ()
241
+ and time .time () < content .artifact .meta .get ("expires_at" , float ("inf" ))
242
+ ):
243
+ openai_message ["audio" ] = {
244
+ "id" : content .artifact .meta ["audio_id" ],
245
+ }
246
+ else :
247
+ openai_message ["content" ].append (self .__to_openai_message_content (content ))
248
+
224
249
# Some OpenAi-compatible services don't accept an empty array for content
225
250
if not openai_message ["content" ]:
226
251
openai_message ["content" ] = ""
227
252
228
- # Action calls must be attached to the message, not sent as content.
229
- action_call_content = [
230
- content for content in message .content if isinstance (content , ActionCallMessageContent )
231
- ]
232
- if action_call_content :
233
- openai_message ["tool_calls" ] = [
234
- self .__to_openai_message_content (action_call ) for action_call in action_call_content
235
- ]
236
-
237
253
openai_messages .append (openai_message )
238
254
239
255
return openai_messages
@@ -271,6 +287,23 @@ def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict
271
287
"type" : "image_url" ,
272
288
"image_url" : {"url" : f"data:{ content .artifact .mime_type } ;base64,{ content .artifact .base64 } " },
273
289
}
290
+ elif isinstance (content , AudioMessageContent ):
291
+ artifact = content .artifact
292
+
293
+ # We can't send the audio if it's expired.
294
+ if int (time .time ()) > artifact .meta .get ("expires_at" , float ("inf" )):
295
+ return {
296
+ "type" : "text" ,
297
+ "text" : artifact .meta .get ("transcript" ),
298
+ }
299
+ else :
300
+ return {
301
+ "type" : "input_audio" ,
302
+ "input_audio" : {
303
+ "data" : base64 .b64encode (artifact .value ).decode ("utf-8" ),
304
+ "format" : artifact .format ,
305
+ },
306
+ }
274
307
elif isinstance (content , ActionCallMessageContent ):
275
308
action = content .artifact .value
276
309
@@ -289,6 +322,20 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
289
322
290
323
if response .content is not None :
291
324
content .append (TextMessageContent (TextArtifact (response .content )))
325
+ if response .audio is not None :
326
+ content .append (
327
+ AudioMessageContent (
328
+ AudioArtifact (
329
+ value = base64 .b64decode (response .audio .data ),
330
+ format = "wav" ,
331
+ meta = {
332
+ "audio_id" : response .audio .id ,
333
+ "transcript" : response .audio .transcript ,
334
+ "expires_at" : response .audio .expires_at ,
335
+ },
336
+ )
337
+ )
338
+ )
292
339
if response .tool_calls is not None :
293
340
content .extend (
294
341
[
@@ -308,7 +355,7 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
308
355
309
356
return content
310
357
311
- def __to_prompt_stack_delta_message_content (self , content_delta : ChoiceDelta ) -> BaseDeltaMessageContent :
358
+ def __to_prompt_stack_delta_message_content (self , content_delta : ChoiceDelta ) -> Optional [ BaseDeltaMessageContent ] :
312
359
if content_delta .content is not None :
313
360
return TextDeltaMessageContent (content_delta .content )
314
361
elif content_delta .tool_calls is not None :
@@ -333,5 +380,13 @@ def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) ->
333
380
raise ValueError (f"Unsupported tool call delta: { tool_call } " )
334
381
else :
335
382
raise ValueError (f"Unsupported tool call delta length: { len (tool_calls )} " )
336
- else :
337
- return TextDeltaMessageContent ("" )
383
+ # OpenAi doesn't have types for audio deltas so we need to use hasattr and getattr.
384
+ elif hasattr (content_delta , "audio" ) and getattr (content_delta , "audio" ) is not None :
385
+ audio_chunk : dict = getattr (content_delta , "audio" )
386
+ return AudioDeltaMessageContent (
387
+ id = audio_chunk .get ("id" ),
388
+ data = audio_chunk .get ("data" ),
389
+ expires_at = audio_chunk .get ("expires_at" ),
390
+ transcript = audio_chunk .get ("transcript" ),
391
+ )
392
+ return None
0 commit comments