1
1
from __future__ import annotations
2
2
3
+ import base64
3
4
import json
4
5
import logging
6
+ import time
5
7
from typing import TYPE_CHECKING , Optional
6
8
7
9
import openai
8
10
from attrs import Factory , define , field
9
11
from schema import Schema
10
12
11
- from griptape .artifacts import ActionArtifact , TextArtifact
13
+ from griptape .artifacts import ActionArtifact , AudioArtifact , TextArtifact
12
14
from griptape .common import (
13
15
ActionCallDeltaMessageContent ,
14
16
ActionCallMessageContent ,
15
17
ActionResultMessageContent ,
18
+ AudioDeltaMessageContent ,
19
+ AudioMessageContent ,
16
20
BaseDeltaMessageContent ,
17
21
BaseMessageContent ,
18
22
DeltaMessage ,
@@ -94,6 +98,10 @@ class OpenAiChatPromptDriver(BasePromptDriver):
94
98
),
95
99
kw_only = True ,
96
100
)
101
+ modalities : list [str ] = field (default = Factory (lambda : ["text" ]), kw_only = True , metadata = {"serializable" : True })
102
+ audio : dict = field (
103
+ default = Factory (lambda : {"voice" : "alloy" , "format" : "pcm16" }), kw_only = True , metadata = {"serializable" : True }
104
+ )
97
105
_client : openai .OpenAI = field (default = None , kw_only = True , alias = "client" , metadata = {"serializable" : False })
98
106
99
107
@lazy_property ()
@@ -144,14 +152,18 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
144
152
choice = chunk .choices [0 ]
145
153
delta = choice .delta
146
154
147
- yield DeltaMessage (content = self .__to_prompt_stack_delta_message_content (delta ))
155
+ content = self .__to_prompt_stack_delta_message_content (delta )
156
+ if content is not None :
157
+ yield DeltaMessage (content = content )
148
158
149
159
def _base_params (self , prompt_stack : PromptStack ) -> dict :
150
160
params = {
151
161
"model" : self .model ,
152
162
"temperature" : self .temperature ,
153
163
"user" : self .user ,
154
164
"seed" : self .seed ,
165
+ "modalities" : self .modalities ,
166
+ "audio" : self .audio ,
155
167
** ({"stop" : self .tokenizer .stop_sequences } if self .tokenizer .stop_sequences else {}),
156
168
** ({"max_tokens" : self .max_tokens } if self .max_tokens is not None else {}),
157
169
** ({"stream_options" : {"include_usage" : True }} if self .stream else {}),
def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
    """Convert prompt-stack messages into the OpenAi chat ``messages`` payload.

    Args:
        messages: Prompt Stack messages to convert.

    Returns:
        A list of message dicts shaped for the OpenAi chat completions API.
    """
    openai_messages = []

    for message in messages:
        action_result_contents = message.get_content_type(ActionResultMessageContent)

        # Action results must be sent as separate messages, one per action tag.
        if action_result_contents:
            openai_messages.extend(
                {
                    "role": self.__to_openai_role(message, action_result_content),
                    "content": self.__to_openai_message_content(action_result_content),
                    "tool_call_id": action_result_content.action.tag,
                }
                for action_result_content in action_result_contents
            )

            # Any textual content accompanying the action results goes out as its own message.
            if message.has_any_content_type(TextMessageContent):
                openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
        else:
            openai_message = {
                "role": self.__to_openai_role(message),
                "content": [],
            }

            for content in message.content:
                if isinstance(content, ActionCallMessageContent):
                    # Action calls must be attached to the message via "tool_calls", not sent as content.
                    openai_message.setdefault("tool_calls", []).append(self.__to_openai_message_content(content))
                elif isinstance(content, AudioMessageContent) and message.is_assistant():
                    # Assistant audio is referenced by its server-side id rather than re-sent.
                    openai_message["audio"] = {
                        "id": content.artifact.meta["audio_id"],
                    }
                else:
                    openai_message["content"].append(self.__to_openai_message_content(content))

            # Some OpenAi-compatible services don't accept an empty array for content
            if not openai_message["content"]:
                openai_message["content"] = ""

            openai_messages.append(openai_message)

    return openai_messages
@@ -272,6 +283,23 @@ def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict
272
283
"type" : "image_url" ,
273
284
"image_url" : {"url" : f"data:{ content .artifact .mime_type } ;base64,{ content .artifact .base64 } " },
274
285
}
286
+ elif isinstance (content , AudioMessageContent ):
287
+ artifact = content .artifact
288
+
289
+ # We can't send the audio if it's expired.
290
+ if int (time .time ()) > artifact .meta .get ("expires_at" , float ("inf" )):
291
+ return {
292
+ "type" : "text" ,
293
+ "text" : artifact .meta .get ("transcript" ),
294
+ }
295
+ else :
296
+ return {
297
+ "type" : "input_audio" ,
298
+ "input_audio" : {
299
+ "data" : base64 .b64encode (artifact .value ).decode ("utf-8" ),
300
+ "format" : artifact .format ,
301
+ },
302
+ }
275
303
elif isinstance (content , ActionCallMessageContent ):
276
304
action = content .artifact .value
277
305
@@ -290,6 +318,20 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
290
318
291
319
if response .content is not None :
292
320
content .append (TextMessageContent (TextArtifact (response .content )))
321
+ if response .audio is not None :
322
+ content .append (
323
+ AudioMessageContent (
324
+ AudioArtifact (
325
+ value = base64 .b64decode (response .audio .data ),
326
+ format = "wav" ,
327
+ meta = {
328
+ "audio_id" : response .audio .id ,
329
+ "transcript" : response .audio .transcript ,
330
+ "expires_at" : response .audio .expires_at ,
331
+ },
332
+ )
333
+ )
334
+ )
293
335
if response .tool_calls is not None :
294
336
content .extend (
295
337
[
@@ -309,7 +351,7 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->
309
351
310
352
return content
311
353
312
- def __to_prompt_stack_delta_message_content (self , content_delta : ChoiceDelta ) -> BaseDeltaMessageContent :
354
+ def __to_prompt_stack_delta_message_content (self , content_delta : ChoiceDelta ) -> Optional [ BaseDeltaMessageContent ] :
313
355
if content_delta .content is not None :
314
356
return TextDeltaMessageContent (content_delta .content )
315
357
elif content_delta .tool_calls is not None :
@@ -334,5 +376,13 @@ def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) ->
334
376
raise ValueError (f"Unsupported tool call delta: { tool_call } " )
335
377
else :
336
378
raise ValueError (f"Unsupported tool call delta length: { len (tool_calls )} " )
337
- else :
338
- return TextDeltaMessageContent ("" )
379
+ # OpenAi doesn't have types for audio deltas so we need to use hasattr and getattr.
380
+ elif hasattr (content_delta , "audio" ) and getattr (content_delta , "audio" ) is not None :
381
+ audio_chunk : dict = getattr (content_delta , "audio" )
382
+ return AudioDeltaMessageContent (
383
+ id = audio_chunk .get ("id" ),
384
+ data = audio_chunk .get ("data" ),
385
+ expires_at = audio_chunk .get ("expires_at" ),
386
+ transcript = audio_chunk .get ("transcript" ),
387
+ )
388
+ return None
0 commit comments