2
2
3
3
using System ;
4
4
using System . Collections . Generic ;
5
- using System . Data . Common ;
6
5
using System . Diagnostics ;
7
6
using System . Diagnostics . CodeAnalysis ;
8
7
using System . Linq ;
@@ -263,11 +262,9 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
263
262
translator . Translate ( appendWhere : false ) ;
264
263
265
264
using var connection = await this . GetConnectionAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
266
- DbCommand ? command = null ;
267
265
268
- if ( options . IncludeVectors )
269
- {
270
- command = SqliteCommandBuilder . BuildSelectInnerJoinCommand (
266
+ using var command = options . IncludeVectors
267
+ ? SqliteCommandBuilder . BuildSelectInnerJoinCommand (
271
268
connection ,
272
269
this . _vectorTableName ,
273
270
this . _dataTableName ,
@@ -279,11 +276,8 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
279
276
translator . Clause . ToString ( ) ,
280
277
translator . Parameters ,
281
278
top : top ,
282
- skip : options . Skip ) ;
283
- }
284
- else
285
- {
286
- command = SqliteCommandBuilder . BuildSelectDataCommand (
279
+ skip : options . Skip )
280
+ : SqliteCommandBuilder . BuildSelectDataCommand (
287
281
connection ,
288
282
this . _dataTableName ,
289
283
this . _model ,
@@ -293,28 +287,21 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
293
287
translator . Parameters ,
294
288
top : top ,
295
289
skip : options . Skip ) ;
296
- }
297
290
298
- using ( command )
299
- {
300
- const string OperationName = "Get" ;
291
+ const string OperationName = "Get" ;
301
292
302
- using var reader = await connection . ExecuteWithErrorHandlingAsync (
303
- this . _collectionMetadata ,
304
- OperationName ,
305
- ( ) => command . ExecuteReaderAsync ( cancellationToken ) ,
306
- cancellationToken ) . ConfigureAwait ( false ) ;
293
+ using var reader = await connection . ExecuteWithErrorHandlingAsync (
294
+ this . _collectionMetadata ,
295
+ OperationName ,
296
+ ( ) => command . ExecuteReaderAsync ( cancellationToken ) ,
297
+ cancellationToken ) . ConfigureAwait ( false ) ;
307
298
308
- while ( await reader . ReadWithErrorHandlingAsync (
309
- this . _collectionMetadata ,
310
- OperationName ,
311
- cancellationToken ) . ConfigureAwait ( false ) )
312
- {
313
- yield return this . GetAndMapRecord (
314
- reader ,
315
- this . _model . Properties ,
316
- options . IncludeVectors ) ;
317
- }
299
+ while ( await reader . ReadWithErrorHandlingAsync (
300
+ this . _collectionMetadata ,
301
+ OperationName ,
302
+ cancellationToken ) . ConfigureAwait ( false ) )
303
+ {
304
+ yield return this . _mapper . MapFromStorageToDataModel ( reader , options . IncludeVectors ) ;
318
305
}
319
306
}
320
307
@@ -363,7 +350,7 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
363
350
{
364
351
Verify . NotNull ( record ) ;
365
352
366
- IReadOnlyList < Embedding > ? [ ] ? generatedEmbeddings = null ;
353
+ Dictionary < VectorPropertyModel , IReadOnlyList < Embedding < float > > > ? generatedEmbeddings = null ;
367
354
368
355
var vectorPropertyCount = this . _model . VectorProperties . Count ;
369
356
for ( var i = 0 ; i < vectorPropertyCount ; i ++ )
@@ -382,8 +369,8 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
382
369
// and generate embeddings for them in a single batch. That's some more complexity though.
383
370
if ( vectorProperty . TryGenerateEmbedding < TRecord , Embedding < float > > ( record , cancellationToken , out var floatTask ) )
384
371
{
385
- generatedEmbeddings ??= new IReadOnlyList < Embedding > ? [ vectorPropertyCount ] ;
386
- generatedEmbeddings [ i ] = [ await floatTask . ConfigureAwait ( false ) ] ;
372
+ generatedEmbeddings ??= new Dictionary < VectorPropertyModel , IReadOnlyList < Embedding < float > > > ( vectorPropertyCount ) ;
373
+ generatedEmbeddings [ vectorProperty ] = [ await floatTask . ConfigureAwait ( false ) ] ;
387
374
}
388
375
else
389
376
{
@@ -394,16 +381,7 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
394
381
395
382
using var connection = await this . GetConnectionAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
396
383
397
- var storageModel = this . _mapper . MapFromDataToStorageModel ( record , recordIndex : 0 , generatedEmbeddings ) ;
398
-
399
- var key = storageModel [ this . _keyStorageName ] ;
400
-
401
- Verify . NotNull ( key ) ;
402
-
403
- var condition = new SqliteWhereEqualsCondition ( this . _keyStorageName , key ) ;
404
-
405
- await this . InternalUpsertBatchAsync ( connection , [ storageModel ] , condition , cancellationToken )
406
- . ConfigureAwait ( false ) ;
384
+ await this . InternalUpsertBatchAsync ( connection , [ record ] , generatedEmbeddings , cancellationToken ) . ConfigureAwait ( false ) ;
407
385
}
408
386
409
387
/// <inheritdoc />
@@ -414,7 +392,7 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
414
392
IReadOnlyList < TRecord > ? recordsList = null ;
415
393
416
394
// If an embedding generator is defined, invoke it once per property for all records.
417
- IReadOnlyList < Embedding > ? [ ] ? generatedEmbeddings = null ;
395
+ Dictionary < VectorPropertyModel , IReadOnlyList < Embedding < float > > > ? generatedEmbeddings = null ;
418
396
419
397
var vectorPropertyCount = this . _model . VectorProperties . Count ;
420
398
for ( var i = 0 ; i < vectorPropertyCount ; i ++ )
@@ -447,8 +425,8 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
447
425
// and generate embeddings for them in a single batch. That's some more complexity though.
448
426
if ( vectorProperty . TryGenerateEmbeddings < TRecord , Embedding < float > > ( records , cancellationToken , out var floatTask ) )
449
427
{
450
- generatedEmbeddings ??= new IReadOnlyList < Embedding > ? [ vectorPropertyCount ] ;
451
- generatedEmbeddings [ i ] = ( IReadOnlyList < Embedding < float > > ) await floatTask . ConfigureAwait ( false ) ;
428
+ generatedEmbeddings ??= new Dictionary < VectorPropertyModel , IReadOnlyList < Embedding < float > > > ( vectorPropertyCount ) ;
429
+ generatedEmbeddings [ vectorProperty ] = await floatTask . ConfigureAwait ( false ) ;
452
430
}
453
431
else
454
432
{
@@ -457,19 +435,9 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
457
435
}
458
436
}
459
437
460
- var storageModels = records . Select ( ( r , i ) => this . _mapper . MapFromDataToStorageModel ( r , i , generatedEmbeddings ) ) . ToList ( ) ;
461
-
462
- if ( storageModels . Count == 0 )
463
- {
464
- return ;
465
- }
466
-
467
- var keys = storageModels . Select ( model => model [ this . _keyStorageName ] ! ) . ToList ( ) ;
468
-
469
438
using var connection = await this . GetConnectionAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
470
- var condition = new SqliteWhereInCondition ( this . _keyStorageName , keys ) ;
471
439
472
- await this . InternalUpsertBatchAsync ( connection , storageModels , condition , cancellationToken ) . ConfigureAwait ( false ) ;
440
+ await this . InternalUpsertBatchAsync ( connection , records , generatedEmbeddings , cancellationToken ) . ConfigureAwait ( false ) ;
473
441
}
474
442
475
443
/// <inheritdoc />
@@ -557,11 +525,7 @@ private async IAsyncEnumerable<VectorSearchResult<TRecord>> EnumerateAndMapSearc
557
525
if ( recordCounter >= searchOptions . Skip )
558
526
{
559
527
var score = SqlitePropertyMapping . GetPropertyValue < double > ( reader , SqliteCommandBuilder . DistancePropertyName ) ;
560
-
561
- var record = this . GetAndMapRecord (
562
- reader ,
563
- this . _model . Properties ,
564
- searchOptions . IncludeVectors ) ;
528
+ var record = this . _mapper . MapFromStorageToDataModel ( reader , searchOptions . IncludeVectors ) ;
565
529
566
530
yield return new VectorSearchResult < TRecord > ( record , score ) ;
567
531
}
@@ -632,69 +596,67 @@ private async IAsyncEnumerable<TRecord> InternalGetBatchAsync(
632
596
const string OperationName = "Select" ;
633
597
634
598
bool includeVectors = options ? . IncludeVectors is true && this . _vectorPropertiesExist ;
635
-
636
- DbCommand command ;
637
-
638
- if ( includeVectors )
599
+ if ( includeVectors && this . _model . EmbeddingGenerationRequired )
639
600
{
640
- if ( this . _model . EmbeddingGenerationRequired )
641
- {
642
- throw new NotSupportedException ( VectorDataStrings . IncludeVectorsNotSupportedWithEmbeddingGeneration ) ;
643
- }
601
+ throw new NotSupportedException ( VectorDataStrings . IncludeVectorsNotSupportedWithEmbeddingGeneration ) ;
602
+ }
644
603
645
- command = SqliteCommandBuilder . BuildSelectInnerJoinCommand < TRecord > (
604
+ var command = includeVectors
605
+ ? SqliteCommandBuilder . BuildSelectInnerJoinCommand < TRecord > (
646
606
connection ,
647
607
this . _vectorTableName ,
648
608
this . _dataTableName ,
649
609
this . _keyStorageName ,
650
610
this . _model ,
651
611
[ condition ] ,
652
- includeDistance : false ) ;
653
- }
654
- else
655
- {
656
- command = SqliteCommandBuilder . BuildSelectDataCommand < TRecord > (
612
+ includeDistance : false )
613
+ : SqliteCommandBuilder . BuildSelectDataCommand < TRecord > (
657
614
connection ,
658
615
this . _dataTableName ,
659
616
this . _model ,
660
617
[ condition ] ) ;
661
- }
662
618
663
- using ( command )
664
- {
665
- using var reader = await connection . ExecuteWithErrorHandlingAsync (
666
- this . _collectionMetadata ,
667
- OperationName ,
668
- ( ) => command . ExecuteReaderAsync ( cancellationToken ) ,
669
- cancellationToken ) . ConfigureAwait ( false ) ;
619
+ using var reader = await connection . ExecuteWithErrorHandlingAsync (
620
+ this . _collectionMetadata ,
621
+ OperationName ,
622
+ ( ) => command . ExecuteReaderAsync ( cancellationToken ) ,
623
+ cancellationToken ) . ConfigureAwait ( false ) ;
670
624
671
- while ( await reader . ReadAsync ( cancellationToken ) . ConfigureAwait ( false ) )
672
- {
673
- yield return this . GetAndMapRecord (
674
- reader ,
675
- this . _model . Properties ,
676
- includeVectors ) ;
677
- }
625
+ while ( await reader . ReadAsync ( cancellationToken ) . ConfigureAwait ( false ) )
626
+ {
627
+ yield return this . _mapper . MapFromStorageToDataModel ( reader , includeVectors ) ;
678
628
}
679
629
}
680
630
681
- private async Task < IReadOnlyList < TKey > > InternalUpsertBatchAsync (
631
+ private async Task InternalUpsertBatchAsync (
682
632
SqliteConnection connection ,
683
- List < Dictionary < string , object ? > > storageModels ,
684
- SqliteWhereCondition condition ,
633
+ IEnumerable < TRecord > records ,
634
+ Dictionary < VectorPropertyModel , IReadOnlyList < Embedding < float > > > ? generatedEmbeddings ,
685
635
CancellationToken cancellationToken )
686
636
{
687
- Verify . NotNull ( storageModels ) ;
688
- Verify . True ( storageModels . Count > 0 , "Number of provided records should be greater than zero." ) ;
637
+ Verify . NotNull ( records ) ;
689
638
690
639
if ( this . _vectorPropertiesExist )
691
640
{
641
+ // We're going to have to traverse the records multiple times, so materialize the enumerable if needed.
642
+ var recordsList = records is IReadOnlyList < TRecord > r ? r : records . ToList ( ) ;
643
+
644
+ if ( recordsList . Count == 0 )
645
+ {
646
+ return ;
647
+ }
648
+
649
+ records = recordsList ;
650
+
651
+ var keyProperty = this . _model . KeyProperty ;
652
+ var keys = recordsList . Select ( r => keyProperty . GetValueAsObject ( r ) ! ) . ToList ( ) ;
653
+
692
654
// Deleting vector records first since current version of vector search extension
693
655
// doesn't support Upsert operation, only Delete/Insert.
694
656
using var vectorDeleteCommand = SqliteCommandBuilder . BuildDeleteCommand (
695
657
connection ,
696
658
this . _vectorTableName ,
697
- [ condition ] ) ;
659
+ [ new SqliteWhereInCondition ( this . _keyStorageName , keys ) ] ) ;
698
660
699
661
await connection . ExecuteWithErrorHandlingAsync (
700
662
this . _collectionMetadata ,
@@ -706,8 +668,9 @@ await connection.ExecuteWithErrorHandlingAsync(
706
668
connection ,
707
669
this . _vectorTableName ,
708
670
this . _keyStorageName ,
709
- this . _model . Properties ,
710
- storageModels ,
671
+ this . _model ,
672
+ records ,
673
+ generatedEmbeddings ,
711
674
data : false ) ;
712
675
713
676
await connection . ExecuteWithErrorHandlingAsync (
@@ -721,8 +684,9 @@ await connection.ExecuteWithErrorHandlingAsync(
721
684
connection ,
722
685
this . _dataTableName ,
723
686
this . _keyStorageName ,
724
- this . _model . Properties ,
725
- storageModels ,
687
+ this . _model ,
688
+ records ,
689
+ generatedEmbeddings ,
726
690
data : true ,
727
691
replaceIfExists : true ) ;
728
692
@@ -732,18 +696,14 @@ await connection.ExecuteWithErrorHandlingAsync(
732
696
( ) => dataCommand . ExecuteReaderAsync ( cancellationToken ) ,
733
697
cancellationToken ) . ConfigureAwait ( false ) ;
734
698
735
- var keys = new List < TKey > ( ) ;
736
-
737
699
while ( await reader . ReadAsync ( cancellationToken ) . ConfigureAwait ( false ) )
738
700
{
739
701
var key = reader . GetFieldValue < TKey > ( 0 ) ;
740
702
741
- keys . Add ( key ) ;
703
+ // TODO: Inject the generated keys into the record for autogenerated keys.
742
704
743
705
await reader . NextResultAsync ( cancellationToken ) . ConfigureAwait ( false ) ;
744
706
}
745
-
746
- return keys ;
747
707
}
748
708
749
709
private Task InternalDeleteBatchAsync ( SqliteConnection connection , SqliteWhereCondition condition , CancellationToken cancellationToken )
@@ -778,25 +738,6 @@ private Task InternalDeleteBatchAsync(SqliteConnection connection, SqliteWhereCo
778
738
return Task . WhenAll ( tasks ) ;
779
739
}
780
740
781
- private TRecord GetAndMapRecord (
782
- DbDataReader reader ,
783
- IReadOnlyList < PropertyModel > properties ,
784
- bool includeVectors )
785
- {
786
- var storageModel = new Dictionary < string , object ? > ( ) ;
787
-
788
- foreach ( var property in properties )
789
- {
790
- if ( includeVectors || property is not VectorPropertyModel )
791
- {
792
- var propertyValue = SqlitePropertyMapping . GetPropertyValue ( reader , property . StorageName , property . Type ) ;
793
- storageModel . Add ( property . StorageName , propertyValue ) ;
794
- }
795
- }
796
-
797
- return this . _mapper . MapFromStorageToDataModel ( storageModel , includeVectors ) ;
798
- }
799
-
800
741
#pragma warning disable CS0618 // VectorSearchFilter is obsolete
801
742
private List < SqliteWhereCondition > ? GetFilterConditions ( VectorSearchFilter ? filter , string ? tableName = null )
802
743
{
0 commit comments