Skip to content

Commit 84f77a0

Browse files
authored
Add DotProduct Method and README Example for Embedding Similarity Search (#492)
* Add DotProduct Method and README Example for Embedding Similarity Search - Implement a DotProduct() method for the Embedding struct to calculate the dot product between two embeddings. - Add a custom error type for vector length mismatch. - Update README.md with a complete example demonstrating how to perform an embedding similarity search for user queries. - Add unit tests to validate the new DotProduct() method and error handling. * Update README to focus on Embedding Semantic Similarity
1 parent 0d5256f commit 84f77a0

File tree

3 files changed

+114
-0
lines changed

3 files changed

+114
-0
lines changed

README.md

+56
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,62 @@ func main() {
483483
```
484484
</details>
485485

486+
<details>
487+
<summary>Embedding Semantic Similarity</summary>
488+
489+
```go
490+
package main
491+
492+
import (
493+
"context"
494+
"log"
495+
openai "github.com/sashabaranov/go-openai"
496+
497+
)
498+
499+
func main() {
500+
client := openai.NewClient("your-token")
501+
502+
// Create an EmbeddingRequest for the user query
503+
queryReq := openai.EmbeddingRequest{
504+
Input: []string{"How many chucks would a woodchuck chuck"},
505+
Model: openai.AdaEmbeddingv2,
506+
}
507+
508+
// Create an embedding for the user query
509+
queryResponse, err := client.CreateEmbeddings(context.Background(), queryReq)
510+
if err != nil {
511+
log.Fatal("Error creating query embedding:", err)
512+
}
513+
514+
// Create an EmbeddingRequest for the target text
515+
targetReq := openai.EmbeddingRequest{
516+
Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"},
517+
Model: openai.AdaEmbeddingv2,
518+
}
519+
520+
// Create an embedding for the target text
521+
targetResponse, err := client.CreateEmbeddings(context.Background(), targetReq)
522+
if err != nil {
523+
log.Fatal("Error creating target embedding:", err)
524+
}
525+
526+
// Now that we have the embeddings for the user query and the target text, we
527+
// can calculate their similarity.
528+
queryEmbedding := queryResponse.Data[0]
529+
targetEmbedding := targetResponse.Data[0]
530+
531+
similarity, err := queryEmbedding.DotProduct(&targetEmbedding)
532+
if err != nil {
533+
log.Fatal("Error calculating dot product:", err)
534+
}
535+
536+
log.Printf("The similarity score between the query and the target is %f", similarity)
537+
}
538+
539+
```
540+
</details>
541+
486542
<details>
487543
<summary>Azure OpenAI Embeddings</summary>
488544

embeddings.go

+20
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@ import (
44
"context"
55
"encoding/base64"
66
"encoding/binary"
7+
"errors"
78
"math"
89
"net/http"
910
)
1011

12+
// ErrVectorLengthMismatch is returned by Embedding.DotProduct when the two
// embedding vectors do not have the same length.
var ErrVectorLengthMismatch = errors.New("vector length mismatch")
13+
1114
// EmbeddingModel enumerates the models which can be used
1215
// to generate Embedding vectors.
1316
type EmbeddingModel int
@@ -124,6 +127,23 @@ type Embedding struct {
124127
Index int `json:"index"`
125128
}
126129

130+
// DotProduct calculates the dot product of the embedding vector with another
131+
// embedding vector. Both vectors must have the same length; otherwise, an
132+
// ErrVectorLengthMismatch is returned. The method returns the calculated dot
133+
// product as a float32 value.
134+
func (e *Embedding) DotProduct(other *Embedding) (float32, error) {
135+
if len(e.Embedding) != len(other.Embedding) {
136+
return 0, ErrVectorLengthMismatch
137+
}
138+
139+
var dotProduct float32
140+
for i := range e.Embedding {
141+
dotProduct += e.Embedding[i] * other.Embedding[i]
142+
}
143+
144+
return dotProduct, nil
145+
}
146+
127147
// EmbeddingResponse is the response from a Create embeddings request.
128148
type EmbeddingResponse struct {
129149
Object string `json:"object"`

embeddings_test.go

+38
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7+
"errors"
78
"fmt"
9+
"math"
810
"net/http"
911
"reflect"
1012
"testing"
@@ -233,3 +235,39 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
233235
})
234236
}
235237
}
238+
239+
func TestDotProduct(t *testing.T) {
240+
v1 := &Embedding{Embedding: []float32{1, 2, 3}}
241+
v2 := &Embedding{Embedding: []float32{2, 4, 6}}
242+
expected := float32(28.0)
243+
244+
result, err := v1.DotProduct(v2)
245+
if err != nil {
246+
t.Errorf("Unexpected error: %v", err)
247+
}
248+
249+
if math.Abs(float64(result-expected)) > 1e-12 {
250+
t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
251+
}
252+
253+
v1 = &Embedding{Embedding: []float32{1, 0, 0}}
254+
v2 = &Embedding{Embedding: []float32{0, 1, 0}}
255+
expected = float32(0.0)
256+
257+
result, err = v1.DotProduct(v2)
258+
if err != nil {
259+
t.Errorf("Unexpected error: %v", err)
260+
}
261+
262+
if math.Abs(float64(result-expected)) > 1e-12 {
263+
t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
264+
}
265+
266+
// Test for VectorLengthMismatchError
267+
v1 = &Embedding{Embedding: []float32{1, 0, 0}}
268+
v2 = &Embedding{Embedding: []float32{0, 1}}
269+
_, err = v1.DotProduct(v2)
270+
if !errors.Is(err, ErrVectorLengthMismatch) {
271+
t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err)
272+
}
273+
}

0 commit comments

Comments
 (0)