Skip to content

Commit fe3d276

Browse files
authored
fix: parameterize where clauses (#512)
1 parent b81fcd0 commit fe3d276

File tree

7 files changed

+122
-29
lines changed

7 files changed

+122
-29
lines changed

knowledge/pkg/datastore/ingest.go

-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ func (s *Datastore) Ingest(ctx context.Context, datasetID string, filename strin
187187
} else if len(fs) > 0 {
188188
fileLoop:
189189
for _, f := range fs {
190-
191190
// check if the dataset embeddingsconfig matches - if not, we don't have to fetch the documents for this file
192191
ds, err := s.GetDataset(ctx, f.Dataset, nil)
193192
if err != nil || ds == nil {

knowledge/pkg/index/types/types.go

-1
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,6 @@ func (db *DB) FindFilesByMetadata(ctx context.Context, dataset string, metadata
236236
return nil, err
237237
}
238238
return files, nil
239-
240239
}
241240

242241
func (db *DB) GetDocument(ctx context.Context, documentID string) (*Document, error) {

knowledge/pkg/vectorstore/chromem/chromem.go

-1
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,6 @@ func (s *ChromemStore) GetDocument(ctx context.Context, documentID, collection s
250250
Content: doc.Content,
251251
Embedding: doc.Embedding,
252252
}, nil
253-
254253
}
255254

256255
func (s *ChromemStore) GetDocuments(ctx context.Context, collection string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error) {

knowledge/pkg/vectorstore/helper/sql.go

+61-10
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,87 @@ import (
77
cg "github.com/philippgille/chromem-go"
88
)
99

10-
func BuildWhereDocumentClause(whereDocs []cg.WhereDocument, joinOperator string) (string, error) {
10+
func BuildWhereDocumentClauseIndexed(whereDocs []cg.WhereDocument, joinOperator string, argIndex int) (string, []any, error) {
1111
if len(whereDocs) == 0 {
12-
return "TRUE", nil
12+
return "TRUE", nil, nil
1313
}
1414
if joinOperator == "" {
1515
joinOperator = "AND"
1616
}
1717
joinOperator = fmt.Sprintf(" %s ", strings.TrimSpace(joinOperator)) // ensure space around operator
1818
var whereClauses []string
19+
var args []any
1920
for _, wd := range whereDocs {
2021
switch wd.Operator {
2122
case cg.WhereDocumentOperatorAnd:
22-
wc, err := BuildWhereDocumentClause(wd.WhereDocuments, "AND")
23+
wc, a, err := BuildWhereDocumentClauseIndexed(wd.WhereDocuments, "AND", argIndex)
2324
if err != nil {
24-
return "", err
25+
return "", nil, err
2526
}
2627
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", wc))
28+
args = append(args, a...)
29+
argIndex += len(a)
2730
case cg.WhereDocumentOperatorOr:
28-
wc, err := BuildWhereDocumentClause(wd.WhereDocuments, "OR")
31+
wc, a, err := BuildWhereDocumentClauseIndexed(wd.WhereDocuments, "OR", argIndex)
2932
if err != nil {
30-
return "", err
33+
return "", nil, err
3134
}
3235
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", wc))
36+
args = append(args, a...)
37+
argIndex += len(a)
3338
case cg.WhereDocumentOperatorEquals:
34-
whereClauses = append(whereClauses, fmt.Sprintf("document = '%s'", wd.Value))
39+
whereClauses = append(whereClauses, fmt.Sprintf("document = $%d", argIndex))
40+
args = append(args, wd.Value)
41+
argIndex += 1
3542
case cg.WhereDocumentOperatorContains:
36-
whereClauses = append(whereClauses, fmt.Sprintf("document LIKE '%%%s%%'", wd.Value))
43+
whereClauses = append(whereClauses, fmt.Sprintf("document LIKE $%d", argIndex))
44+
args = append(args, "%"+wd.Value+"%")
45+
argIndex += 1
3746
case cg.WhereDocumentOperatorNotContains:
38-
whereClauses = append(whereClauses, fmt.Sprintf("document NOT LIKE '%%%s%%'", wd.Value))
47+
whereClauses = append(whereClauses, fmt.Sprintf("document NOT LIKE $%d", argIndex))
48+
args = append(args, "%"+wd.Value+"%")
49+
argIndex += 1
3950
}
4051
}
41-
return strings.Join(whereClauses, joinOperator), nil
52+
return strings.Join(whereClauses, joinOperator), args, nil
53+
}
54+
55+
func BuildWhereDocumentClause(whereDocs []cg.WhereDocument, joinOperator string) (string, []any, error) {
56+
if len(whereDocs) == 0 {
57+
return "TRUE", nil, nil
58+
}
59+
if joinOperator == "" {
60+
joinOperator = "AND"
61+
}
62+
joinOperator = fmt.Sprintf(" %s ", strings.TrimSpace(joinOperator)) // ensure space around operator
63+
var whereClauses []string
64+
var args []any
65+
for _, wd := range whereDocs {
66+
switch wd.Operator {
67+
case cg.WhereDocumentOperatorAnd:
68+
wc, a, err := BuildWhereDocumentClause(wd.WhereDocuments, "AND")
69+
if err != nil {
70+
return "", nil, err
71+
}
72+
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", wc))
73+
args = append(args, a...)
74+
case cg.WhereDocumentOperatorOr:
75+
wc, a, err := BuildWhereDocumentClause(wd.WhereDocuments, "OR")
76+
if err != nil {
77+
return "", nil, err
78+
}
79+
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", wc))
80+
args = append(args, a...)
81+
case cg.WhereDocumentOperatorEquals:
82+
whereClauses = append(whereClauses, fmt.Sprintf("document = ?"))
83+
args = append(args, wd.Value)
84+
case cg.WhereDocumentOperatorContains:
85+
whereClauses = append(whereClauses, fmt.Sprintf("document LIKE ?"))
86+
args = append(args, "%"+wd.Value+"%")
87+
case cg.WhereDocumentOperatorNotContains:
88+
whereClauses = append(whereClauses, fmt.Sprintf("document NOT LIKE ?"))
89+
args = append(args, "%"+wd.Value+"%")
90+
}
91+
}
92+
return strings.Join(whereClauses, joinOperator), args, nil
4293
}

knowledge/pkg/vectorstore/helper/sql_test.go

+56-13
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,40 @@ import (
99

1010
func TestBuildWhereDocumentClause_EmptyInput_TRUEClause(t *testing.T) {
1111
var whereDocs []cg.WhereDocument
12-
whereClause, err := BuildWhereDocumentClause(whereDocs, "AND")
12+
whereClause, a, err := BuildWhereDocumentClause(whereDocs, "AND")
1313
assert.NoError(t, err)
1414
assert.Equal(t, "TRUE", whereClause)
15+
assert.Empty(t, a)
1516
}
1617

1718
func TestBuildWhereDocumentClause_SingleEqualsCondition_ReturnsCorrectClause(t *testing.T) {
1819
whereDocs := []cg.WhereDocument{
1920
{Operator: cg.WhereDocumentOperatorEquals, Value: "test"},
2021
}
21-
whereClause, err := BuildWhereDocumentClause(whereDocs, "AND")
22+
whereClause, a, err := BuildWhereDocumentClause(whereDocs, "AND")
2223
assert.NoError(t, err)
23-
assert.Equal(t, "document = 'test'", whereClause)
24+
assert.Equal(t, "document = ?", whereClause)
25+
assert.Equal(t, []any{"test"}, a)
2426
}
2527

2628
func TestBuildWhereDocumentClause_SingleContainsCondition_ReturnsCorrectClause(t *testing.T) {
2729
whereDocs := []cg.WhereDocument{
2830
{Operator: cg.WhereDocumentOperatorContains, Value: "test"},
2931
}
30-
whereClause, err := BuildWhereDocumentClause(whereDocs, "AND")
32+
whereClause, a, err := BuildWhereDocumentClause(whereDocs, "AND")
3133
assert.NoError(t, err)
32-
assert.Equal(t, "document LIKE '%test%'", whereClause)
34+
assert.Equal(t, "document LIKE ?", whereClause)
35+
assert.Equal(t, []any{"%test%"}, a)
3336
}
3437

3538
func TestBuildWhereDocumentClause_SingleNotContainsCondition_ReturnsCorrectClause(t *testing.T) {
3639
whereDocs := []cg.WhereDocument{
3740
{Operator: cg.WhereDocumentOperatorNotContains, Value: "test"},
3841
}
39-
whereClause, err := BuildWhereDocumentClause(whereDocs, "AND")
42+
whereClause, a, err := BuildWhereDocumentClause(whereDocs, "AND")
4043
assert.NoError(t, err)
41-
assert.Equal(t, "document NOT LIKE '%test%'", whereClause)
44+
assert.Equal(t, "document NOT LIKE ?", whereClause)
45+
assert.Equal(t, []any{"%test%"}, a)
4246
}
4347

4448
func TestBuildWhereDocumentClause_AndCondition_ReturnsCorrectClauses(t *testing.T) {
@@ -51,9 +55,10 @@ func TestBuildWhereDocumentClause_AndCondition_ReturnsCorrectClauses(t *testing.
5155
},
5256
},
5357
}
54-
whereClause, err := BuildWhereDocumentClause(whereDocs, "AND")
58+
whereClause, a, err := BuildWhereDocumentClause(whereDocs, "AND")
5559
assert.NoError(t, err)
56-
assert.Equal(t, "(document = 'test1' AND document = 'test2')", whereClause)
60+
assert.Equal(t, "(document = ? AND document = ?)", whereClause)
61+
assert.Equal(t, []any{"test1", "test2"}, a)
5762
}
5863

5964
func TestBuildWhereDocumentClause_OrCondition_ReturnsCorrectClauses(t *testing.T) {
@@ -66,9 +71,10 @@ func TestBuildWhereDocumentClause_OrCondition_ReturnsCorrectClauses(t *testing.T
6671
},
6772
},
6873
}
69-
whereClause, err := BuildWhereDocumentClause(whereDocs, "OR")
74+
whereClause, a, err := BuildWhereDocumentClause(whereDocs, "OR")
7075
assert.NoError(t, err)
71-
assert.Equal(t, "(document = 'test1' OR document = 'test2')", whereClause)
76+
assert.Equal(t, "(document = ? OR document = ?)", whereClause)
77+
assert.Equal(t, []any{"test1", "test2"}, a)
7278
}
7379

7480
func TestBuildWhereDocumentClause_Nested_ReturnsCorrectClauses(t *testing.T) {
@@ -101,7 +107,44 @@ func TestBuildWhereDocumentClause_Nested_ReturnsCorrectClauses(t *testing.T) {
101107
},
102108
},
103109
}
104-
whereClause, err := BuildWhereDocumentClause(whereDocs, "AND")
110+
whereClause, a, err := BuildWhereDocumentClause(whereDocs, "AND")
105111
assert.NoError(t, err)
106-
assert.Equal(t, "(document = 'test1' OR document = 'test2') AND (document = 'test3' AND document = 'test4') AND ((document = 'test5' AND document = 'test6') AND document = 'test7')", whereClause)
112+
assert.Equal(t, "(document = ? OR document = ?) AND (document = ? AND document = ?) AND ((document = ? AND document = ?) AND document = ?)", whereClause)
113+
assert.Equal(t, []any{"test1", "test2", "test3", "test4", "test5", "test6", "test7"}, a)
114+
}
115+
116+
func TestBuildWhereDocumentClauseIndexed_Nested_ReturnsCorrectClauses(t *testing.T) {
117+
whereDocs := []cg.WhereDocument{
118+
{
119+
Operator: cg.WhereDocumentOperatorOr,
120+
WhereDocuments: []cg.WhereDocument{
121+
{Operator: cg.WhereDocumentOperatorEquals, Value: "test1"},
122+
{Operator: cg.WhereDocumentOperatorEquals, Value: "test2"},
123+
},
124+
},
125+
{
126+
Operator: cg.WhereDocumentOperatorAnd,
127+
WhereDocuments: []cg.WhereDocument{
128+
{Operator: cg.WhereDocumentOperatorEquals, Value: "test3"},
129+
{Operator: cg.WhereDocumentOperatorEquals, Value: "test4"},
130+
},
131+
},
132+
{
133+
Operator: cg.WhereDocumentOperatorAnd,
134+
WhereDocuments: []cg.WhereDocument{
135+
{
136+
Operator: cg.WhereDocumentOperatorAnd,
137+
WhereDocuments: []cg.WhereDocument{
138+
{Operator: cg.WhereDocumentOperatorEquals, Value: "test5"},
139+
{Operator: cg.WhereDocumentOperatorEquals, Value: "test6"},
140+
},
141+
},
142+
{Operator: cg.WhereDocumentOperatorEquals, Value: "test7"},
143+
},
144+
},
145+
}
146+
whereClause, a, err := BuildWhereDocumentClauseIndexed(whereDocs, "AND", 3)
147+
assert.NoError(t, err)
148+
assert.Equal(t, "(document = $3 OR document = $4) AND (document = $5 AND document = $6) AND ((document = $7 AND document = $8) AND document = $9)", whereClause)
149+
assert.Equal(t, []any{"test1", "test2", "test3", "test4", "test5", "test6", "test7"}, a)
107150
}

knowledge/pkg/vectorstore/pgvector/pgvector.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -568,19 +568,20 @@ func buildWhereClause(args []any, where map[string]string, whereDocument []cg.Wh
568568
args = make([]any, 0)
569569
}
570570

571-
argIndex := len(args) + 1 // Usually w start with index 2 because $1 is for cid
571+
argIndex := len(args) + 1 // Usually we start with index 2 because $1 is for cid
572572
for k, v := range where {
573573
whereClauses = append(whereClauses, fmt.Sprintf("(cmetadata ->> $%d) = $%d", argIndex, argIndex+1))
574574
args = append(args, k, v)
575575
argIndex += 2
576576
}
577577

578578
if len(whereDocument) > 0 {
579-
wc, err := helper.BuildWhereDocumentClause(whereDocument, "AND")
579+
wc, a, err := helper.BuildWhereDocumentClauseIndexed(whereDocument, "AND", argIndex)
580580
if err != nil {
581581
return "", nil, err
582582
}
583583
whereClauses = append(whereClauses, wc)
584+
args = append(args, a...)
584585
}
585586

586587
whereClause := strings.Join(whereClauses, " AND ")

knowledge/pkg/vectorstore/sqlite-vec/sqlite-vec.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -372,11 +372,12 @@ func (v *VectorStore) GetDocuments(_ context.Context, collection string, where m
372372
}
373373

374374
if len(whereDocument) > 0 {
375-
wc, err := helper.BuildWhereDocumentClause(whereDocument, "AND")
375+
wc, a, err := helper.BuildWhereDocumentClause(whereDocument, "AND")
376376
if err != nil {
377377
return nil, fmt.Errorf("failed to build whereDocument clause: %w", err)
378378
}
379379
whereQueries = append(whereQueries, wc)
380+
args = append(args, a...)
380381
}
381382

382383
whereQuery := strings.Join(whereQueries, " AND ")

0 commit comments

Comments
 (0)