Skip to content

Commit 654158c

Browse files
committed
Add priorityqueue
1 parent 88a217a commit 654158c

File tree

1 file changed

+90
-18
lines changed

1 file changed

+90
-18
lines changed

vectorstore/in_memory.go

Lines changed: 90 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package vectorstore
22

33
import (
4+
"container/heap"
45
"context"
5-
"sort"
66

77
"github.com/hupe1980/golc/internal/util"
88
"github.com/hupe1980/golc/metric"
@@ -19,13 +19,64 @@ type InMemoryItem struct {
1919
Metadata map[string]any `json:"metadata"`
2020
}
2121

22+
// priorityQueueItem represents an item in the priority queue.
23+
type priorityQueueItem struct {
24+
Data InMemoryItem // Data associated with the item
25+
distance float32 // Distance from the query vector
26+
27+
index int // Index of the item in the priority queue
28+
}
29+
30+
// priorityQueue is a priority queue for InMemoryItem items.
31+
type priorityQueue []*priorityQueueItem
32+
33+
// Len returns the length of the priority queue.
34+
func (pq priorityQueue) Len() int { return len(pq) }
35+
36+
// Less reports whether the element with index i should sort before the element with index j.
37+
func (pq priorityQueue) Less(i, j int) bool { return pq[i].distance < pq[j].distance }
38+
39+
// Swap swaps the elements with indexes i and j.
40+
func (pq priorityQueue) Swap(i, j int) {
41+
pq[i], pq[j] = pq[j], pq[i]
42+
pq[i].index = i
43+
pq[j].index = j
44+
}
45+
46+
// Push adds an item to the priority queue.
47+
func (pq *priorityQueue) Push(x any) {
48+
n := len(*pq)
49+
item, _ := x.(*priorityQueueItem)
50+
item.index = n
51+
*pq = append(*pq, item)
52+
}
53+
54+
// Pop removes and returns the item with the highest priority (distance).
55+
func (pq *priorityQueue) Pop() any {
56+
old := *pq
57+
n := len(old)
58+
item := old[n-1]
59+
item.index = -1
60+
*pq = old[0 : n-1]
61+
62+
return item
63+
}
64+
65+
// Top returns the top element of the priority queue.
66+
func (pq *priorityQueue) Top() any {
67+
return (*pq)[0]
68+
}
69+
70+
// DistanceFunc represents a function for calculating the distance between two vectors
71+
type DistanceFunc func(v1, v2 []float32) (float32, error)
72+
2273
// InMemoryOptions represents options for the in-memory vector store.
2374
type InMemoryOptions struct {
24-
TopK int
75+
TopK int
76+
DistanceFunc DistanceFunc
2577
}
2678

2779
// InMemory represents an in-memory vector store.
28-
// Note: This implementation is intended for testing and demonstration purposes, not for production use.
2980
type InMemory struct {
3081
embedder schema.Embedder
3182
data []InMemoryItem
@@ -35,7 +86,8 @@ type InMemory struct {
3586
// NewInMemory creates a new instance of the in-memory vector store.
3687
func NewInMemory(embedder schema.Embedder, optFns ...func(*InMemoryOptions)) *InMemory {
3788
opts := InMemoryOptions{
38-
TopK: 3,
89+
TopK: 3,
90+
DistanceFunc: metric.SquaredL2,
3991
}
4092

4193
for _, fn := range optFns {
@@ -89,36 +141,56 @@ func (vs *InMemory) SimilaritySearch(ctx context.Context, query string) ([]schem
89141
return nil, err
90142
}
91143

144+
topCandidates := &priorityQueue{}
145+
146+
heap.Init(topCandidates)
147+
92148
type searchResult struct {
93149
Item InMemoryItem
94150
Similarity float32
95151
}
96152

97153
results := make([]searchResult, len(vs.data))
98154

99-
for i, item := range vs.data {
100-
similarity, err := metric.CosineSimilarity(queryVector, item.Vector)
155+
for _, item := range vs.data {
156+
similarity, err := vs.opts.DistanceFunc(queryVector, item.Vector)
101157
if err != nil {
102158
return nil, err
103159
}
104160

105-
results[i] = searchResult{Item: item, Similarity: similarity}
106-
}
161+
if topCandidates.Len() < vs.opts.TopK {
162+
heap.Push(topCandidates, &priorityQueueItem{
163+
Data: item,
164+
distance: similarity,
165+
})
166+
167+
continue
168+
}
169+
170+
largestDist, _ := topCandidates.Top().(*priorityQueueItem)
107171

108-
// Sort results by similarity in descending order
109-
sort.Slice(results, func(i, j int) bool {
110-
return results[i].Similarity > results[j].Similarity
111-
})
172+
if similarity < largestDist.distance {
173+
_ = heap.Pop(topCandidates)
174+
175+
heap.Push(topCandidates, &priorityQueueItem{
176+
Data: item,
177+
distance: similarity,
178+
})
179+
}
180+
}
112181

113182
docLen := util.Min(len(results), vs.opts.TopK)
114183

115184
// Extract documents from sorted results
116-
documents := make([]schema.Document, docLen)
117-
for i := 0; i < docLen; i++ {
118-
documents[i] = schema.Document{
119-
PageContent: results[i].Item.Content,
120-
Metadata: results[i].Item.Metadata,
121-
}
185+
documents := make([]schema.Document, 0, docLen)
186+
187+
for topCandidates.Len() > 0 {
188+
item, _ := heap.Pop(topCandidates).(*priorityQueueItem)
189+
190+
documents = append(documents, schema.Document{
191+
PageContent: item.Data.Content,
192+
Metadata: item.Data.Metadata,
193+
})
122194
}
123195

124196
return documents, nil

0 commit comments

Comments
 (0)