Skip to content

Commit 214f4ab

Browse files
authored
.Net: Optimize and clean up SqliteVec provider (#12333)
This change mainly removes the intermediate dictionaries from serialization/deserialization. Part of #11123.
1 parent 2a78664 commit 214f4ab

File tree

7 files changed

+216
-473
lines changed

7 files changed

+216
-473
lines changed

dotnet/src/Connectors/Connectors.Memory.SqliteVec/SqliteCollection.cs

Lines changed: 65 additions & 124 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,6 @@
22

33
using System;
44
using System.Collections.Generic;
5-
using System.Data.Common;
65
using System.Diagnostics;
76
using System.Diagnostics.CodeAnalysis;
87
using System.Linq;
@@ -263,11 +262,9 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
263262
translator.Translate(appendWhere: false);
264263

265264
using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
266-
DbCommand? command = null;
267265

268-
if (options.IncludeVectors)
269-
{
270-
command = SqliteCommandBuilder.BuildSelectInnerJoinCommand(
266+
using var command = options.IncludeVectors
267+
? SqliteCommandBuilder.BuildSelectInnerJoinCommand(
271268
connection,
272269
this._vectorTableName,
273270
this._dataTableName,
@@ -279,11 +276,8 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
279276
translator.Clause.ToString(),
280277
translator.Parameters,
281278
top: top,
282-
skip: options.Skip);
283-
}
284-
else
285-
{
286-
command = SqliteCommandBuilder.BuildSelectDataCommand(
279+
skip: options.Skip)
280+
: SqliteCommandBuilder.BuildSelectDataCommand(
287281
connection,
288282
this._dataTableName,
289283
this._model,
@@ -293,28 +287,21 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
293287
translator.Parameters,
294288
top: top,
295289
skip: options.Skip);
296-
}
297290

298-
using (command)
299-
{
300-
const string OperationName = "Get";
291+
const string OperationName = "Get";
301292

302-
using var reader = await connection.ExecuteWithErrorHandlingAsync(
303-
this._collectionMetadata,
304-
OperationName,
305-
() => command.ExecuteReaderAsync(cancellationToken),
306-
cancellationToken).ConfigureAwait(false);
293+
using var reader = await connection.ExecuteWithErrorHandlingAsync(
294+
this._collectionMetadata,
295+
OperationName,
296+
() => command.ExecuteReaderAsync(cancellationToken),
297+
cancellationToken).ConfigureAwait(false);
307298

308-
while (await reader.ReadWithErrorHandlingAsync(
309-
this._collectionMetadata,
310-
OperationName,
311-
cancellationToken).ConfigureAwait(false))
312-
{
313-
yield return this.GetAndMapRecord(
314-
reader,
315-
this._model.Properties,
316-
options.IncludeVectors);
317-
}
299+
while (await reader.ReadWithErrorHandlingAsync(
300+
this._collectionMetadata,
301+
OperationName,
302+
cancellationToken).ConfigureAwait(false))
303+
{
304+
yield return this._mapper.MapFromStorageToDataModel(reader, options.IncludeVectors);
318305
}
319306
}
320307

@@ -363,7 +350,7 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
363350
{
364351
Verify.NotNull(record);
365352

366-
IReadOnlyList<Embedding>?[]? generatedEmbeddings = null;
353+
Dictionary<VectorPropertyModel, IReadOnlyList<Embedding<float>>>? generatedEmbeddings = null;
367354

368355
var vectorPropertyCount = this._model.VectorProperties.Count;
369356
for (var i = 0; i < vectorPropertyCount; i++)
@@ -382,8 +369,8 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
382369
// and generate embeddings for them in a single batch. That's some more complexity though.
383370
if (vectorProperty.TryGenerateEmbedding<TRecord, Embedding<float>>(record, cancellationToken, out var floatTask))
384371
{
385-
generatedEmbeddings ??= new IReadOnlyList<Embedding>?[vectorPropertyCount];
386-
generatedEmbeddings[i] = [await floatTask.ConfigureAwait(false)];
372+
generatedEmbeddings ??= new Dictionary<VectorPropertyModel, IReadOnlyList<Embedding<float>>>(vectorPropertyCount);
373+
generatedEmbeddings[vectorProperty] = [await floatTask.ConfigureAwait(false)];
387374
}
388375
else
389376
{
@@ -394,16 +381,7 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
394381

395382
using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
396383

397-
var storageModel = this._mapper.MapFromDataToStorageModel(record, recordIndex: 0, generatedEmbeddings);
398-
399-
var key = storageModel[this._keyStorageName];
400-
401-
Verify.NotNull(key);
402-
403-
var condition = new SqliteWhereEqualsCondition(this._keyStorageName, key);
404-
405-
await this.InternalUpsertBatchAsync(connection, [storageModel], condition, cancellationToken)
406-
.ConfigureAwait(false);
384+
await this.InternalUpsertBatchAsync(connection, [record], generatedEmbeddings, cancellationToken).ConfigureAwait(false);
407385
}
408386

409387
/// <inheritdoc />
@@ -414,7 +392,7 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
414392
IReadOnlyList<TRecord>? recordsList = null;
415393

416394
// If an embedding generator is defined, invoke it once per property for all records.
417-
IReadOnlyList<Embedding>?[]? generatedEmbeddings = null;
395+
Dictionary<VectorPropertyModel, IReadOnlyList<Embedding<float>>>? generatedEmbeddings = null;
418396

419397
var vectorPropertyCount = this._model.VectorProperties.Count;
420398
for (var i = 0; i < vectorPropertyCount; i++)
@@ -447,8 +425,8 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
447425
// and generate embeddings for them in a single batch. That's some more complexity though.
448426
if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<float>>(records, cancellationToken, out var floatTask))
449427
{
450-
generatedEmbeddings ??= new IReadOnlyList<Embedding>?[vectorPropertyCount];
451-
generatedEmbeddings[i] = (IReadOnlyList<Embedding<float>>)await floatTask.ConfigureAwait(false);
428+
generatedEmbeddings ??= new Dictionary<VectorPropertyModel, IReadOnlyList<Embedding<float>>>(vectorPropertyCount);
429+
generatedEmbeddings[vectorProperty] = await floatTask.ConfigureAwait(false);
452430
}
453431
else
454432
{
@@ -457,19 +435,9 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
457435
}
458436
}
459437

460-
var storageModels = records.Select((r, i) => this._mapper.MapFromDataToStorageModel(r, i, generatedEmbeddings)).ToList();
461-
462-
if (storageModels.Count == 0)
463-
{
464-
return;
465-
}
466-
467-
var keys = storageModels.Select(model => model[this._keyStorageName]!).ToList();
468-
469438
using var connection = await this.GetConnectionAsync(cancellationToken).ConfigureAwait(false);
470-
var condition = new SqliteWhereInCondition(this._keyStorageName, keys);
471439

472-
await this.InternalUpsertBatchAsync(connection, storageModels, condition, cancellationToken).ConfigureAwait(false);
440+
await this.InternalUpsertBatchAsync(connection, records, generatedEmbeddings, cancellationToken).ConfigureAwait(false);
473441
}
474442

475443
/// <inheritdoc />
@@ -557,11 +525,7 @@ private async IAsyncEnumerable<VectorSearchResult<TRecord>> EnumerateAndMapSearc
557525
if (recordCounter >= searchOptions.Skip)
558526
{
559527
var score = SqlitePropertyMapping.GetPropertyValue<double>(reader, SqliteCommandBuilder.DistancePropertyName);
560-
561-
var record = this.GetAndMapRecord(
562-
reader,
563-
this._model.Properties,
564-
searchOptions.IncludeVectors);
528+
var record = this._mapper.MapFromStorageToDataModel(reader, searchOptions.IncludeVectors);
565529

566530
yield return new VectorSearchResult<TRecord>(record, score);
567531
}
@@ -632,69 +596,67 @@ private async IAsyncEnumerable<TRecord> InternalGetBatchAsync(
632596
const string OperationName = "Select";
633597

634598
bool includeVectors = options?.IncludeVectors is true && this._vectorPropertiesExist;
635-
636-
DbCommand command;
637-
638-
if (includeVectors)
599+
if (includeVectors && this._model.EmbeddingGenerationRequired)
639600
{
640-
if (this._model.EmbeddingGenerationRequired)
641-
{
642-
throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration);
643-
}
601+
throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration);
602+
}
644603

645-
command = SqliteCommandBuilder.BuildSelectInnerJoinCommand<TRecord>(
604+
var command = includeVectors
605+
? SqliteCommandBuilder.BuildSelectInnerJoinCommand<TRecord>(
646606
connection,
647607
this._vectorTableName,
648608
this._dataTableName,
649609
this._keyStorageName,
650610
this._model,
651611
[condition],
652-
includeDistance: false);
653-
}
654-
else
655-
{
656-
command = SqliteCommandBuilder.BuildSelectDataCommand<TRecord>(
612+
includeDistance: false)
613+
: SqliteCommandBuilder.BuildSelectDataCommand<TRecord>(
657614
connection,
658615
this._dataTableName,
659616
this._model,
660617
[condition]);
661-
}
662618

663-
using (command)
664-
{
665-
using var reader = await connection.ExecuteWithErrorHandlingAsync(
666-
this._collectionMetadata,
667-
OperationName,
668-
() => command.ExecuteReaderAsync(cancellationToken),
669-
cancellationToken).ConfigureAwait(false);
619+
using var reader = await connection.ExecuteWithErrorHandlingAsync(
620+
this._collectionMetadata,
621+
OperationName,
622+
() => command.ExecuteReaderAsync(cancellationToken),
623+
cancellationToken).ConfigureAwait(false);
670624

671-
while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
672-
{
673-
yield return this.GetAndMapRecord(
674-
reader,
675-
this._model.Properties,
676-
includeVectors);
677-
}
625+
while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
626+
{
627+
yield return this._mapper.MapFromStorageToDataModel(reader, includeVectors);
678628
}
679629
}
680630

681-
private async Task<IReadOnlyList<TKey>> InternalUpsertBatchAsync(
631+
private async Task InternalUpsertBatchAsync(
682632
SqliteConnection connection,
683-
List<Dictionary<string, object?>> storageModels,
684-
SqliteWhereCondition condition,
633+
IEnumerable<TRecord> records,
634+
Dictionary<VectorPropertyModel, IReadOnlyList<Embedding<float>>>? generatedEmbeddings,
685635
CancellationToken cancellationToken)
686636
{
687-
Verify.NotNull(storageModels);
688-
Verify.True(storageModels.Count > 0, "Number of provided records should be greater than zero.");
637+
Verify.NotNull(records);
689638

690639
if (this._vectorPropertiesExist)
691640
{
641+
// We're going to have to traverse the records multiple times, so materialize the enumerable if needed.
642+
var recordsList = records is IReadOnlyList<TRecord> r ? r : records.ToList();
643+
644+
if (recordsList.Count == 0)
645+
{
646+
return;
647+
}
648+
649+
records = recordsList;
650+
651+
var keyProperty = this._model.KeyProperty;
652+
var keys = recordsList.Select(r => keyProperty.GetValueAsObject(r)!).ToList();
653+
692654
// Deleting vector records first since current version of vector search extension
693655
// doesn't support Upsert operation, only Delete/Insert.
694656
using var vectorDeleteCommand = SqliteCommandBuilder.BuildDeleteCommand(
695657
connection,
696658
this._vectorTableName,
697-
[condition]);
659+
[new SqliteWhereInCondition(this._keyStorageName, keys)]);
698660

699661
await connection.ExecuteWithErrorHandlingAsync(
700662
this._collectionMetadata,
@@ -706,8 +668,9 @@ await connection.ExecuteWithErrorHandlingAsync(
706668
connection,
707669
this._vectorTableName,
708670
this._keyStorageName,
709-
this._model.Properties,
710-
storageModels,
671+
this._model,
672+
records,
673+
generatedEmbeddings,
711674
data: false);
712675

713676
await connection.ExecuteWithErrorHandlingAsync(
@@ -721,8 +684,9 @@ await connection.ExecuteWithErrorHandlingAsync(
721684
connection,
722685
this._dataTableName,
723686
this._keyStorageName,
724-
this._model.Properties,
725-
storageModels,
687+
this._model,
688+
records,
689+
generatedEmbeddings,
726690
data: true,
727691
replaceIfExists: true);
728692

@@ -732,18 +696,14 @@ await connection.ExecuteWithErrorHandlingAsync(
732696
() => dataCommand.ExecuteReaderAsync(cancellationToken),
733697
cancellationToken).ConfigureAwait(false);
734698

735-
var keys = new List<TKey>();
736-
737699
while (await reader.ReadAsync(cancellationToken).ConfigureAwait(false))
738700
{
739701
var key = reader.GetFieldValue<TKey>(0);
740702

741-
keys.Add(key);
703+
// TODO: Inject the generated keys into the record for autogenerated keys.
742704

743705
await reader.NextResultAsync(cancellationToken).ConfigureAwait(false);
744706
}
745-
746-
return keys;
747707
}
748708

749709
private Task InternalDeleteBatchAsync(SqliteConnection connection, SqliteWhereCondition condition, CancellationToken cancellationToken)
@@ -778,25 +738,6 @@ private Task InternalDeleteBatchAsync(SqliteConnection connection, SqliteWhereCo
778738
return Task.WhenAll(tasks);
779739
}
780740

781-
private TRecord GetAndMapRecord(
782-
DbDataReader reader,
783-
IReadOnlyList<PropertyModel> properties,
784-
bool includeVectors)
785-
{
786-
var storageModel = new Dictionary<string, object?>();
787-
788-
foreach (var property in properties)
789-
{
790-
if (includeVectors || property is not VectorPropertyModel)
791-
{
792-
var propertyValue = SqlitePropertyMapping.GetPropertyValue(reader, property.StorageName, property.Type);
793-
storageModel.Add(property.StorageName, propertyValue);
794-
}
795-
}
796-
797-
return this._mapper.MapFromStorageToDataModel(storageModel, includeVectors);
798-
}
799-
800741
#pragma warning disable CS0618 // VectorSearchFilter is obsolete
801742
private List<SqliteWhereCondition>? GetFilterConditions(VectorSearchFilter? filter, string? tableName = null)
802743
{

0 commit comments

Comments
 (0)