Skip to content

Commit 4bc217e

Browse files
chore: allow image data to be in prompt input
1 parent ed71575 commit 4bc217e

File tree

3 files changed

+82
-5
lines changed

3 files changed

+82
-5
lines changed

pkg/openai/client.go

+30-4
Original file line numberDiff line numberDiff line change
@@ -281,10 +281,7 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
281281
chatMessage.ToolCalls = append(chatMessage.ToolCalls, toToolCall(*content.ToolCall))
282282
}
283283
if content.Text != "" {
284-
chatMessage.MultiContent = append(chatMessage.MultiContent, openai.ChatMessagePart{
285-
Type: openai.ChatMessagePartTypeText,
286-
Text: content.Text,
287-
})
284+
chatMessage.MultiContent = append(chatMessage.MultiContent, textToMultiContent(content.Text)...)
288285
}
289286
}
290287

@@ -306,6 +303,35 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
306303
return
307304
}
308305

306+
const imagePrefix = "data:image/png;base64,"
307+
308+
func textToMultiContent(text string) []openai.ChatMessagePart {
309+
var chatParts []openai.ChatMessagePart
310+
parts := strings.Split(text, "\n")
311+
for i := len(parts) - 1; i >= 0; i-- {
312+
if strings.HasPrefix(parts[i], imagePrefix) {
313+
chatParts = append(chatParts, openai.ChatMessagePart{
314+
Type: openai.ChatMessagePartTypeImageURL,
315+
ImageURL: &openai.ChatMessageImageURL{
316+
URL: parts[i],
317+
},
318+
})
319+
parts = parts[:i]
320+
} else {
321+
break
322+
}
323+
}
324+
if len(parts) > 0 {
325+
chatParts = append(chatParts, openai.ChatMessagePart{
326+
Type: openai.ChatMessagePartTypeText,
327+
Text: strings.Join(parts, "\n"),
328+
})
329+
}
330+
331+
slices.Reverse(chatParts)
332+
return chatParts
333+
}
334+
309335
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
310336
if err := c.ValidAuth(); err != nil {
311337
if err := c.RetrieveAPIKey(ctx, env); err != nil {

pkg/openai/client_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,44 @@ import (
99
"github.com/hexops/valast"
1010
)
1111

12+
func TestTextToMultiContent(t *testing.T) {
13+
autogold.Expect([]openai.ChatMessagePart{{
14+
Type: "text",
15+
Text: "hi\ndata:image/png;base64,xxxxx\n",
16+
}}).Equal(t, textToMultiContent("hi\ndata:image/png;base64,xxxxx\n"))
17+
18+
autogold.Expect([]openai.ChatMessagePart{
19+
{
20+
Type: "text",
21+
Text: "hi",
22+
},
23+
{
24+
Type: "image_url",
25+
ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,xxxxx"},
26+
},
27+
}).Equal(t, textToMultiContent("hi\ndata:image/png;base64,xxxxx"))
28+
29+
autogold.Expect([]openai.ChatMessagePart{{
30+
Type: "image_url",
31+
ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,xxxxx"},
32+
}}).Equal(t, textToMultiContent("data:image/png;base64,xxxxx"))
33+
34+
autogold.Expect([]openai.ChatMessagePart{
35+
{
36+
Type: "text",
37+
Text: "\none\ntwo",
38+
},
39+
{
40+
Type: "image_url",
41+
ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,xxxxx"},
42+
},
43+
{
44+
Type: "image_url",
45+
ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,yyyyy"},
46+
},
47+
}).Equal(t, textToMultiContent("\none\ntwo\ndata:image/png;base64,xxxxx\ndata:image/png;base64,yyyyy"))
48+
}
49+
1250
func Test_appendMessage(t *testing.T) {
1351
autogold.Expect(types.CompletionMessage{Content: []types.ContentPart{
1452
{ToolCall: &types.CompletionToolCall{

pkg/runner/runner.go

+14-1
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,17 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher {
651651
return newParallelDispatcher(ctx)
652652
}
653653

654+
func idForToolCall(id string, state *engine.Return) string {
655+
if state == nil || state.State == nil {
656+
return id
657+
}
658+
tc, ok := state.State.Pending[id]
659+
if !ok || tc.Index == nil {
660+
return id
661+
}
662+
return fmt.Sprintf("%03d", *tc.Index)
663+
}
664+
654665
func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (_ *State, callResults []SubCallResult, _ error) {
655666
var resultLock sync.Mutex
656667

@@ -693,7 +704,9 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
693704

694705
// Sort the ids so that, when calls run sequentially, the results are predictable
695706
ids := maps.Keys(state.Continuation.Calls)
696-
sort.Strings(ids)
707+
sort.Slice(ids, func(i, j int) bool {
708+
return idForToolCall(ids[i], state.Continuation) < idForToolCall(ids[j], state.Continuation)
709+
})
697710

698711
for _, id := range ids {
699712
call := state.Continuation.Calls[id]

0 commit comments

Comments
 (0)