-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathfeeder.go
142 lines (114 loc) · 2.83 KB
/
feeder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
package gptbot
import (
	"context"
	"fmt"
)
// XPreprocessor splits documents into chunks for a Feeder.
//
// Preprocess returns the chunks grouped by document ID: Feed treats every key
// of the returned map as a document whose previously stored chunks must be
// deleted before the new ones are inserted.
// NOTE(review): the X prefix presumably avoids a name clash with a concrete
// preprocessor type in this package (NewPreprocessor is the default) — confirm.
type XPreprocessor interface {
Preprocess(docs ...*Document) (map[string][]*Chunk, error)
}
// Updater is the write side of the vector store used by a Feeder.
//
// Insert stores chunks grouped by document ID (map key = Chunk.DocumentID when
// called from Feeder). Delete removes the chunks belonging to the given
// document IDs; Feed calls it before inserting so re-fed documents are replaced.
type Updater interface {
Insert(ctx context.Context, chunks map[string][]*Chunk) error
Delete(ctx context.Context, documentIDs ...string) error
}
// FeederConfig configures a Feeder. NewFeeder fills in defaults for the
// optional fields by modifying the pointed-to config in place.
type FeederConfig struct {
// Encoder is the embedding encoder; its EncodeBatch result is stored in
// Chunk.Embedding for every chunk fed in.
// This field is required.
Encoder Encoder
// Updater is the vector store for inserting/deleting chunks.
// This field is required.
Updater Updater
// Preprocessor splits documents into chunks keyed by document ID.
// Optional; defaults to NewPreprocessor(&PreprocessorConfig{}) when nil.
Preprocessor XPreprocessor
// BatchSize is the number of chunks to encode and insert per call.
// Defaults to 100 when zero.
BatchSize int
}
// init applies defaults to the unset optional fields of cfg, in place, and
// returns cfg itself so it can be used inline (as NewFeeder does).
//
// A nil Preprocessor becomes NewPreprocessor(&PreprocessorConfig{}), and a
// non-positive BatchSize becomes 100. Negative sizes are normalised too
// because genBatches never closes a batch early for a size below one, which
// would silently send every chunk to the encoder and store in one request.
func (cfg *FeederConfig) init() *FeederConfig {
	if cfg.Preprocessor == nil {
		cfg.Preprocessor = NewPreprocessor(&PreprocessorConfig{})
	}
	if cfg.BatchSize <= 0 {
		cfg.BatchSize = 100
	}
	return cfg
}
// Feeder turns documents into embedded chunks and writes them to a vector
// store. Construct it with NewFeeder; the zero value is not usable because
// its cfg would be nil.
type Feeder struct {
// cfg is the configuration, with defaults already applied by init.
cfg *FeederConfig
}
// NewFeeder creates a Feeder that uses cfg. Defaults for the optional fields
// are applied by cfg.init, which mutates *cfg in place, so the caller's config
// and the Feeder share the same (now completed) settings.
func NewFeeder(cfg *FeederConfig) *Feeder {
	completed := cfg.init()
	feeder := new(Feeder)
	feeder.cfg = completed
	return feeder
}
// Preprocessor reports the XPreprocessor that Feed uses: the one supplied in
// FeederConfig.Preprocessor, or the default installed by init when that field
// was nil at construction time.
func (f *Feeder) Preprocessor() XPreprocessor {
	cfg := f.cfg
	return cfg.Preprocessor
}
// Feed preprocesses docs into chunks and replaces whatever the Updater holds
// for those documents. It first deletes all existing chunks for every document
// ID returned by the preprocessor. It then encodes and inserts the new chunks
// in batches of at most BatchSize.
//
// Feed is not atomic. If a batch fails to encode or insert, or ctx is
// cancelled between batches, the old chunks are already deleted and only the
// batches processed so far have been inserted. The error is returned unwrapped.
func (f *Feeder) Feed(ctx context.Context, docs ...*Document) error {
	chunks, err := f.cfg.Preprocessor.Preprocess(docs...)
	if err != nil {
		return err
	}
	// Nothing to replace: skip the Delete call entirely. An empty ID list is
	// useless at best, and for some stores an empty filter may match every
	// record. NOTE(review): confirm Delete-with-no-IDs semantics per Updater.
	if len(chunks) == 0 {
		return nil
	}

	// Delete old chunks belonging to the given document IDs.
	docIDs := make([]string, 0, len(chunks))
	for docID := range chunks {
		docIDs = append(docIDs, docID)
	}
	if err := f.cfg.Updater.Delete(ctx, docIDs...); err != nil {
		return err
	}

	// Insert new chunks.
	batches := genBatches(chunks, f.cfg.BatchSize)
	// On an early return, drain the remaining batches. If genBatches feeds the
	// channel from a goroutine, that goroutine would otherwise block on a send
	// forever. After a normal loop exit the channel is closed and this is a no-op.
	defer func() {
		for range batches {
		}
	}()
	for batch := range batches {
		// Stop promptly if the caller cancelled between batches.
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := f.encode(ctx, batch); err != nil {
			return err
		}
		if err := f.insert(ctx, batch); err != nil {
			return err
		}
	}
	return nil
}
// encode embeds every chunk in batch with a single Encoder.EncodeBatch call
// and stores the i-th returned embedding in batch[i].Embedding.
//
// It returns the encoder's error unchanged on failure. If the encoder returns
// a different number of embeddings than chunks, it returns an error without
// touching the batch. Indexing a short result would panic, and any mismatch
// means embeddings cannot be reliably paired with their chunks.
func (f *Feeder) encode(ctx context.Context, batch []*Chunk) error {
	texts := make([]string, 0, len(batch))
	for _, chunk := range batch {
		texts = append(texts, chunk.Text)
	}
	embeddings, err := f.cfg.Encoder.EncodeBatch(ctx, texts)
	if err != nil {
		return err
	}
	if len(embeddings) != len(batch) {
		return fmt.Errorf("encoder returned %d embeddings for %d texts", len(embeddings), len(batch))
	}
	for i, chunk := range batch {
		chunk.Embedding = embeddings[i]
	}
	return nil
}
// insert regroups the chunks of batch by their DocumentID and passes the
// resulting map to the configured Updater in a single Insert call, returning
// whatever error Insert reports.
func (f *Feeder) insert(ctx context.Context, batch []*Chunk) error {
	byDoc := map[string][]*Chunk{}
	for i := range batch {
		id := batch[i].DocumentID
		byDoc[id] = append(byDoc[id], batch[i])
	}
	updater := f.cfg.Updater
	return updater.Insert(ctx, byDoc)
}
// genBatches splits all chunks in the map into consecutive batches of at most
// size chunks and returns them on a channel that is closed after the last
// batch. Batches may mix chunks from different documents. Documents are
// visited in Go map iteration order, so which chunks share a batch is not
// deterministic. If size <= 0, every chunk is delivered in a single batch.
//
// The channel is buffered to hold every batch and is filled and closed before
// genBatches returns, so no goroutine is started. A consumer that stops
// receiving early (e.g. on an error) therefore leaks nothing. Each batch is a
// capacity-limited view of one shared slice, so appending to a batch never
// overwrites a sibling batch.
func genBatches(chunks map[string][]*Chunk, size int) <-chan []*Chunk {
	total := 0
	for _, chunkList := range chunks {
		total += len(chunkList)
	}
	all := make([]*Chunk, 0, total)
	for _, chunkList := range chunks {
		all = append(all, chunkList...)
	}

	// A non-positive size means "no batching" (one batch), and a size larger
	// than the input yields a single batch anyway.
	if size <= 0 || size > total {
		size = total
	}
	numBatches := 0
	if total > 0 {
		numBatches = (total + size - 1) / size
	}

	ch := make(chan []*Chunk, numBatches)
	for start := 0; start < total; start += size {
		end := start + size
		if end > total {
			end = total
		}
		ch <- all[start:end:end]
	}
	close(ch)
	return ch
}