Skip to content

Commit 8e4b796

Browse files
authored
Chore Support base64 embedding format (#485)
* chore: support base64 embedding format * fix: add sizeOfFloat32 * chore: refactor base64 decoding * chore: add tests * fix linting * fix test * fix return error * fix: use smaller slice for tests * fix [skip ci] * chore: refactor test to consider CreateEmbeddings response * trigger build * chore: remove named returns * chore: refactor code to simplify the understanding * chore: tests have been refactored to match the encoding format passed by request * chore: fix tests * fix * fix
1 parent 3589837 commit 8e4b796

File tree

2 files changed

+229
-18
lines changed

2 files changed

+229
-18
lines changed

embeddings.go

Lines changed: 105 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package openai
22

33
import (
44
"context"
5+
"encoding/base64"
6+
"encoding/binary"
7+
"math"
58
"net/http"
69
)
710

@@ -129,15 +132,83 @@ type EmbeddingResponse struct {
129132
Usage Usage `json:"usage"`
130133
}
131134

135+
type base64String string
136+
137+
func (b base64String) Decode() ([]float32, error) {
138+
decodedData, err := base64.StdEncoding.DecodeString(string(b))
139+
if err != nil {
140+
return nil, err
141+
}
142+
143+
const sizeOfFloat32 = 4
144+
floats := make([]float32, len(decodedData)/sizeOfFloat32)
145+
for i := 0; i < len(floats); i++ {
146+
floats[i] = math.Float32frombits(binary.LittleEndian.Uint32(decodedData[i*4 : (i+1)*4]))
147+
}
148+
149+
return floats, nil
150+
}
151+
152+
// Base64Embedding is a container for base64 encoded embeddings.
153+
type Base64Embedding struct {
154+
Object string `json:"object"`
155+
Embedding base64String `json:"embedding"`
156+
Index int `json:"index"`
157+
}
158+
159+
// EmbeddingResponseBase64 is the response from a Create embeddings request with base64 encoding format.
160+
type EmbeddingResponseBase64 struct {
161+
Object string `json:"object"`
162+
Data []Base64Embedding `json:"data"`
163+
Model EmbeddingModel `json:"model"`
164+
Usage Usage `json:"usage"`
165+
}
166+
167+
// ToEmbeddingResponse converts an embeddingResponseBase64 to an EmbeddingResponse.
168+
func (r *EmbeddingResponseBase64) ToEmbeddingResponse() (EmbeddingResponse, error) {
169+
data := make([]Embedding, len(r.Data))
170+
171+
for i, base64Embedding := range r.Data {
172+
embedding, err := base64Embedding.Embedding.Decode()
173+
if err != nil {
174+
return EmbeddingResponse{}, err
175+
}
176+
177+
data[i] = Embedding{
178+
Object: base64Embedding.Object,
179+
Embedding: embedding,
180+
Index: base64Embedding.Index,
181+
}
182+
}
183+
184+
return EmbeddingResponse{
185+
Object: r.Object,
186+
Model: r.Model,
187+
Data: data,
188+
Usage: r.Usage,
189+
}, nil
190+
}
191+
132192
type EmbeddingRequestConverter interface {
133193
// Needs to be of type EmbeddingRequestStrings or EmbeddingRequestTokens
134194
Convert() EmbeddingRequest
135195
}
136196

197+
// EmbeddingEncodingFormat is the format of the embeddings data.
198+
// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented.
199+
// If not specified OpenAI will use "float".
200+
type EmbeddingEncodingFormat string
201+
202+
const (
203+
EmbeddingEncodingFormatFloat EmbeddingEncodingFormat = "float"
204+
EmbeddingEncodingFormatBase64 EmbeddingEncodingFormat = "base64"
205+
)
206+
137207
type EmbeddingRequest struct {
138-
Input any `json:"input"`
139-
Model EmbeddingModel `json:"model"`
140-
User string `json:"user"`
208+
Input any `json:"input"`
209+
Model EmbeddingModel `json:"model"`
210+
User string `json:"user"`
211+
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
141212
}
142213

143214
func (r EmbeddingRequest) Convert() EmbeddingRequest {
@@ -158,13 +229,18 @@ type EmbeddingRequestStrings struct {
158229
Model EmbeddingModel `json:"model"`
159230
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
160231
User string `json:"user"`
232+
// EmbeddingEncodingFormat is the format of the embeddings data.
233+
// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented.
234+
// If not specified OpenAI will use "float".
235+
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
161236
}
162237

163238
func (r EmbeddingRequestStrings) Convert() EmbeddingRequest {
164239
return EmbeddingRequest{
165-
Input: r.Input,
166-
Model: r.Model,
167-
User: r.User,
240+
Input: r.Input,
241+
Model: r.Model,
242+
User: r.User,
243+
EncodingFormat: r.EncodingFormat,
168244
}
169245
}
170246

@@ -181,13 +257,18 @@ type EmbeddingRequestTokens struct {
181257
Model EmbeddingModel `json:"model"`
182258
// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
183259
User string `json:"user"`
260+
// EmbeddingEncodingFormat is the format of the embeddings data.
261+
// Currently, only "float" and "base64" are supported, however, "base64" is not officially documented.
262+
// If not specified OpenAI will use "float".
263+
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
184264
}
185265

186266
func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
187267
return EmbeddingRequest{
188-
Input: r.Input,
189-
Model: r.Model,
190-
User: r.User,
268+
Input: r.Input,
269+
Model: r.Model,
270+
User: r.User,
271+
EncodingFormat: r.EncodingFormat,
191272
}
192273
}
193274

@@ -196,14 +277,27 @@ func (r EmbeddingRequestTokens) Convert() EmbeddingRequest {
196277
//
197278
// Body should be of type EmbeddingRequestStrings for embedding strings or EmbeddingRequestTokens
198279
// for embedding groups of text already converted to tokens.
199-
func (c *Client) CreateEmbeddings(ctx context.Context, conv EmbeddingRequestConverter) (res EmbeddingResponse, err error) { //nolint:lll
280+
func (c *Client) CreateEmbeddings(
281+
ctx context.Context,
282+
conv EmbeddingRequestConverter,
283+
) (res EmbeddingResponse, err error) {
200284
baseReq := conv.Convert()
201285
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq))
202286
if err != nil {
203287
return
204288
}
205289

206-
err = c.sendRequest(req, &res)
290+
if baseReq.EncodingFormat != EmbeddingEncodingFormatBase64 {
291+
err = c.sendRequest(req, &res)
292+
return
293+
}
294+
295+
base64Response := &EmbeddingResponseBase64{}
296+
err = c.sendRequest(req, base64Response)
297+
if err != nil {
298+
return
299+
}
207300

301+
res, err = base64Response.ToEmbeddingResponse()
208302
return
209303
}

embeddings_test.go

Lines changed: 124 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
package openai_test
22

33
import (
4-
. "github.com/sashabaranov/go-openai"
5-
"github.com/sashabaranov/go-openai/internal/test/checks"
6-
74
"bytes"
85
"context"
96
"encoding/json"
107
"fmt"
118
"net/http"
9+
"reflect"
1210
"testing"
11+
12+
. "github.com/sashabaranov/go-openai"
13+
"github.com/sashabaranov/go-openai/internal/test/checks"
1314
)
1415

1516
func TestEmbedding(t *testing.T) {
@@ -97,22 +98,138 @@ func TestEmbeddingModel(t *testing.T) {
9798
func TestEmbeddingEndpoint(t *testing.T) {
9899
client, server, teardown := setupOpenAITestServer()
99100
defer teardown()
101+
102+
sampleEmbeddings := []Embedding{
103+
{Embedding: []float32{1.23, 4.56, 7.89}},
104+
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
105+
}
106+
107+
sampleBase64Embeddings := []Base64Embedding{
108+
{Embedding: "pHCdP4XrkUDhevxA"},
109+
{Embedding: "/1jku0G/rLvA/EI8"},
110+
}
111+
100112
server.RegisterHandler(
101113
"/v1/embeddings",
102114
func(w http.ResponseWriter, r *http.Request) {
103-
resBytes, _ := json.Marshal(EmbeddingResponse{})
115+
var req struct {
116+
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"`
117+
User string `json:"user"`
118+
}
119+
_ = json.NewDecoder(r.Body).Decode(&req)
120+
121+
var resBytes []byte
122+
switch {
123+
case req.User == "invalid":
124+
w.WriteHeader(http.StatusBadRequest)
125+
return
126+
case req.EncodingFormat == EmbeddingEncodingFormatBase64:
127+
resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings})
128+
default:
129+
resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings})
130+
}
104131
fmt.Fprintln(w, string(resBytes))
105132
},
106133
)
107134
// test create embeddings with strings (simple embedding request)
108-
_, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
135+
res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
136+
checks.NoError(t, err, "CreateEmbeddings error")
137+
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
138+
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
139+
}
140+
141+
// test create embeddings with strings (simple embedding request)
142+
res, err = client.CreateEmbeddings(
143+
context.Background(),
144+
EmbeddingRequest{
145+
EncodingFormat: EmbeddingEncodingFormatBase64,
146+
},
147+
)
109148
checks.NoError(t, err, "CreateEmbeddings error")
149+
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
150+
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
151+
}
110152

111153
// test create embeddings with strings
112-
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{})
154+
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{})
113155
checks.NoError(t, err, "CreateEmbeddings strings error")
156+
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
157+
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
158+
}
114159

115160
// test create embeddings with tokens
116-
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
161+
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
117162
checks.NoError(t, err, "CreateEmbeddings tokens error")
163+
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
164+
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
165+
}
166+
167+
// test failed sendRequest
168+
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{
169+
User: "invalid",
170+
EncodingFormat: EmbeddingEncodingFormatBase64,
171+
})
172+
checks.HasError(t, err, "CreateEmbeddings error")
173+
}
174+
175+
func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
176+
type fields struct {
177+
Object string
178+
Data []Base64Embedding
179+
Model EmbeddingModel
180+
Usage Usage
181+
}
182+
tests := []struct {
183+
name string
184+
fields fields
185+
want EmbeddingResponse
186+
wantErr bool
187+
}{
188+
{
189+
name: "test embedding response base64 to embedding response",
190+
fields: fields{
191+
Data: []Base64Embedding{
192+
{Embedding: "pHCdP4XrkUDhevxA"},
193+
{Embedding: "/1jku0G/rLvA/EI8"},
194+
},
195+
},
196+
want: EmbeddingResponse{
197+
Data: []Embedding{
198+
{Embedding: []float32{1.23, 4.56, 7.89}},
199+
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
200+
},
201+
},
202+
wantErr: false,
203+
},
204+
{
205+
name: "Invalid embedding",
206+
fields: fields{
207+
Data: []Base64Embedding{
208+
{
209+
Embedding: "----",
210+
},
211+
},
212+
},
213+
want: EmbeddingResponse{},
214+
wantErr: true,
215+
},
216+
}
217+
for _, tt := range tests {
218+
t.Run(tt.name, func(t *testing.T) {
219+
r := &EmbeddingResponseBase64{
220+
Object: tt.fields.Object,
221+
Data: tt.fields.Data,
222+
Model: tt.fields.Model,
223+
Usage: tt.fields.Usage,
224+
}
225+
got, err := r.ToEmbeddingResponse()
226+
if (err != nil) != tt.wantErr {
227+
t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() error = %v, wantErr %v", err, tt.wantErr)
228+
return
229+
}
230+
if !reflect.DeepEqual(got, tt.want) {
231+
t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() = %v, want %v", got, tt.want)
232+
}
233+
})
234+
}
118235
}

0 commit comments

Comments
 (0)