diff --git a/extensions/SQLServer/SQLServer/SqlServerMemory.cs b/extensions/SQLServer/SQLServer/SqlServerMemory.cs
index eb5388e89..3ea8b82dc 100644
--- a/extensions/SQLServer/SQLServer/SqlServerMemory.cs
+++ b/extensions/SQLServer/SQLServer/SqlServerMemory.cs
@@ -17,7 +17,7 @@ namespace Microsoft.KernelMemory.MemoryDb.SQLServer;
/// Represents a memory store implementation that uses a SQL Server database as its backing store.
///
#pragma warning disable CA2100 // SQL reviewed for user input validation
-public class SqlServerMemory : IMemoryDb, IMemoryDbBatchUpsert
+public sealed class SqlServerMemory : IMemoryDb, IMemoryDbBatchUpsert, IDisposable
{
///
/// The SQL Server configuration.
@@ -34,6 +34,21 @@ public class SqlServerMemory : IMemoryDb, IMemoryDbBatchUpsert
///
private readonly ILogger _log;
+ ///
+ /// Flag used to initialize the client on the first call
+ ///
+ private bool _isReady = false;
+
+ ///
+ /// Lock used to initialize the class instance
+ ///
+ private readonly SemaphoreSlim _initSemaphore = new(1, 1);
+
+ ///
+ /// SQL Server version, retrieved on the first connection
+ ///
+ private int _cachedServerVersion = int.MinValue;
+
///
/// Initializes a new instance of the class.
///
@@ -45,22 +60,16 @@ public SqlServerMemory(
ITextEmbeddingGenerator embeddingGenerator,
ILogger? log = null)
{
- this._embeddingGenerator = embeddingGenerator;
- this._log = log ?? DefaultLogger.Instance;
-
this._config = config;
-
- if (this._embeddingGenerator == null)
- {
- throw new SqlServerMemoryException("Embedding generator not configured");
- }
-
- this.CreateTablesIfNotExists();
+ this._embeddingGenerator = embeddingGenerator ?? throw new ConfigurationException("Embedding generator not configured");
+ this._log = log ?? DefaultLogger.Instance;
}
///
public async Task CreateIndexAsync(string index, int vectorSize, CancellationToken cancellationToken = default)
{
+ if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); }
+
index = NormalizeIndexName(index);
if (await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false))
@@ -69,52 +78,54 @@ public async Task CreateIndexAsync(string index, int vectorSize, CancellationTok
return;
}
- using var connection = new SqlConnection(this._config.ConnectionString);
+ var connection = new SqlConnection(this._config.ConnectionString);
+ await using var disposableConnection = connection.ConfigureAwait(false);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
- using (SqlCommand command = connection.CreateCommand())
- {
- command.CommandText = $@"
- BEGIN TRANSACTION;
-
- INSERT INTO {this.GetFullTableName(this._config.MemoryCollectionTableName)}([id])
- VALUES (@index);
-
- IF OBJECT_ID(N'{this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}', N'U') IS NULL
- CREATE TABLE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}
- (
- [memory_id] UNIQUEIDENTIFIER NOT NULL,
- [vector_value_id] [int] NOT NULL,
- [vector_value] [float] NOT NULL
- FOREIGN KEY ([memory_id]) REFERENCES {this.GetFullTableName(this._config.MemoryTableName)}([id])
- );
+ SqlCommand command = connection.CreateCommand();
+ await using var disposableCommand = command.ConfigureAwait(false);
- IF OBJECT_ID(N'[{this._config.Schema}.IXC_{$"{this._config.EmbeddingsTableName}_{index}"}]', N'U') IS NULL
- CREATE CLUSTERED COLUMNSTORE INDEX [IXC_{$"{this._config.EmbeddingsTableName}_{index}"}]
- ON {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}
- {(this.GetSqlServerMajorVersionNumber() >= 16 ? "ORDER ([memory_id])" : "")};
-
- IF OBJECT_ID(N'{this.GetFullTableName($"{this._config.TagsTableName}_{index}")}', N'U') IS NULL
- CREATE TABLE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")}
- (
- [memory_id] UNIQUEIDENTIFIER NOT NULL,
- [name] NVARCHAR(256) NOT NULL,
- [value] NVARCHAR(256) NOT NULL,
- FOREIGN KEY ([memory_id]) REFERENCES {this.GetFullTableName(this._config.MemoryTableName)}([id])
- );
+ command.CommandText = $@"
+ BEGIN TRANSACTION;
- COMMIT;
- ";
+ INSERT INTO {this.GetFullTableName(this._config.MemoryCollectionTableName)}([id])
+ VALUES (@index);
- command.Parameters.AddWithValue("@index", index);
+ IF OBJECT_ID(N'{this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}', N'U') IS NULL
+ CREATE TABLE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}
+ (
+ [memory_id] UNIQUEIDENTIFIER NOT NULL,
+ [vector_value_id] [int] NOT NULL,
+ [vector_value] [float] NOT NULL
+ FOREIGN KEY ([memory_id]) REFERENCES {this.GetFullTableName(this._config.MemoryTableName)}([id])
+ );
+
+ IF OBJECT_ID(N'[{this._config.Schema}.IXC_{$"{this._config.EmbeddingsTableName}_{index}"}]', N'U') IS NULL
+ CREATE CLUSTERED COLUMNSTORE INDEX [IXC_{$"{this._config.EmbeddingsTableName}_{index}"}]
+ ON {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}
+ {(this._cachedServerVersion >= 16 ? "ORDER ([memory_id])" : "")};
+
+ IF OBJECT_ID(N'{this.GetFullTableName($"{this._config.TagsTableName}_{index}")}', N'U') IS NULL
+ CREATE TABLE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")}
+ (
+ [memory_id] UNIQUEIDENTIFIER NOT NULL,
+ [name] NVARCHAR(256) NOT NULL,
+ [value] NVARCHAR(256) NOT NULL,
+ FOREIGN KEY ([memory_id]) REFERENCES {this.GetFullTableName(this._config.MemoryTableName)}([id])
+ );
- await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
- }
+ COMMIT;";
+
+ command.Parameters.AddWithValue("@index", index);
+
+ await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
///
public async Task DeleteAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default)
{
+ if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); }
+
index = NormalizeIndexName(index);
if (!(await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false)))
@@ -123,12 +134,14 @@ public async Task DeleteAsync(string index, MemoryRecord record, CancellationTok
return;
}
- using var connection = new SqlConnection(this._config.ConnectionString);
+ var connection = new SqlConnection(this._config.ConnectionString);
+ await using var disposableConnection = connection.ConfigureAwait(false);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
- using (SqlCommand command = connection.CreateCommand())
- {
- command.CommandText = $@"
+ SqlCommand command = connection.CreateCommand();
+ await using var disposableCommand = command.ConfigureAwait(false);
+
+ command.CommandText = $@"
BEGIN TRANSACTION;
DELETE [embeddings]
@@ -149,16 +162,17 @@ DELETE [tags]
COMMIT;";
- command.Parameters.AddWithValue("@index", index);
- command.Parameters.AddWithValue("@key", record.Id);
+ command.Parameters.AddWithValue("@index", index);
+ command.Parameters.AddWithValue("@key", record.Id);
- await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
- }
+ await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
///
public async Task DeleteIndexAsync(string index, CancellationToken cancellationToken = default)
{
+ if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); }
+
index = NormalizeIndexName(index);
if (!(await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false)))
@@ -167,46 +181,51 @@ public async Task DeleteIndexAsync(string index, CancellationToken cancellationT
return;
}
- using var connection = new SqlConnection(this._config.ConnectionString);
+ var connection = new SqlConnection(this._config.ConnectionString);
+ await using var disposableConnection = connection.ConfigureAwait(false);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
- using (SqlCommand command = connection.CreateCommand())
- {
- command.CommandText = $@"
- BEGIN TRANSACTION;
+ SqlCommand command = connection.CreateCommand();
+ await using var disposableCommand = command.ConfigureAwait(false);
- DROP TABLE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")};
- DROP TABLE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")};
+ command.CommandText = $@"
+ BEGIN TRANSACTION;
- DELETE FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)}
- WHERE [id] = @index;
+ DROP TABLE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")};
+ DROP TABLE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")};
- COMMIT;";
+ DELETE FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)}
+ WHERE [id] = @index;
- command.Parameters.AddWithValue("@index", index);
+ COMMIT;";
- await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
- }
+ command.Parameters.AddWithValue("@index", index);
+
+ await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
///
public async Task> GetIndexesAsync(CancellationToken cancellationToken = default)
{
+ if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); }
+
List indexes = new();
- using var connection = new SqlConnection(this._config.ConnectionString);
+ var connection = new SqlConnection(this._config.ConnectionString);
+ await using var disposableConnection = connection.ConfigureAwait(false);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
- using (SqlCommand command = connection.CreateCommand())
- {
- command.CommandText = $"SELECT [id] FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)}";
+ SqlCommand command = connection.CreateCommand();
+ await using var disposableCommand = command.ConfigureAwait(false);
- using var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+ command.CommandText = $"SELECT [id] FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)}";
- while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
- {
- indexes.Add(dataReader.GetString(dataReader.GetOrdinal("id")));
- }
+ var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+ await using var disposableReader = dataReader.ConfigureAwait(false);
+
+ while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
+ {
+ indexes.Add(dataReader.GetString(dataReader.GetOrdinal("id")));
}
return indexes;
@@ -215,6 +234,8 @@ public async Task> GetIndexesAsync(CancellationToken cancell
///
public async IAsyncEnumerable GetListAsync(string index, ICollection? filters = null, int limit = 1, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
+ if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); }
+
index = NormalizeIndexName(index);
if (!(await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false)))
@@ -235,14 +256,16 @@ public async IAsyncEnumerable GetListAsync(string index, ICollecti
limit = int.MaxValue;
}
- using var connection = new SqlConnection(this._config.ConnectionString);
+ var connection = new SqlConnection(this._config.ConnectionString);
+ await using var disposableConnection = connection.ConfigureAwait(false);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
- using (SqlCommand command = connection.CreateCommand())
- {
- var tagFilters = new TagCollection();
+ SqlCommand command = connection.CreateCommand();
+ await using var disposableCommand = command.ConfigureAwait(false);
- command.CommandText = $@"
+ var tagFilters = new TagCollection();
+
+ command.CommandText = $@"
WITH [filters] AS
(
SELECT
@@ -258,22 +281,24 @@ SELECT TOP (@limit)
AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
{this.GenerateFilters(index, command.Parameters, filters)};";
- command.Parameters.AddWithValue("@index", index);
- command.Parameters.AddWithValue("@limit", limit);
- command.Parameters.AddWithValue("@filters", JsonSerializer.Serialize(tagFilters));
+ command.Parameters.AddWithValue("@index", index);
+ command.Parameters.AddWithValue("@limit", limit);
+ command.Parameters.AddWithValue("@filters", JsonSerializer.Serialize(tagFilters));
- using var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+ var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+ await using var disposableReader = dataReader.ConfigureAwait(false);
- while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
- {
- yield return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false);
- }
+ while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
+ {
+ yield return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false);
}
}
///
public async IAsyncEnumerable<(MemoryRecord, double)> GetSimilarListAsync(string index, string text, ICollection? filters = null, double minRelevance = 0, int limit = 1, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
+ if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); }
+
index = NormalizeIndexName(index);
if (!(await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false)))
@@ -295,14 +320,16 @@ SELECT TOP (@limit)
$"{this.GetFullTableName(this._config.MemoryTableName)}.[embedding]";
}
- using var connection = new SqlConnection(this._config.ConnectionString);
+ var connection = new SqlConnection(this._config.ConnectionString);
+ await using var disposableConnection = connection.ConfigureAwait(false);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
- using (SqlCommand command = connection.CreateCommand())
- {
- var generatedFilters = this.GenerateFilters(index, command.Parameters, filters);
+ SqlCommand command = connection.CreateCommand();
+ await using var disposableCommand = command.ConfigureAwait(false);
- command.CommandText = $@"
+ var generatedFilters = this.GenerateFilters(index, command.Parameters, filters);
+
+ command.CommandText = $@"
WITH
[embedding] as
(
@@ -348,24 +375,26 @@ INNER JOIN
{generatedFilters}
ORDER BY [cosine_similarity] desc";
- command.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(embedding.Data.ToArray()));
- command.Parameters.AddWithValue("@index", index);
- command.Parameters.AddWithValue("@min_relevance_score", minRelevance);
- command.Parameters.AddWithValue("@limit", limit);
+ command.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(embedding.Data.ToArray()));
+ command.Parameters.AddWithValue("@index", index);
+ command.Parameters.AddWithValue("@min_relevance_score", minRelevance);
+ command.Parameters.AddWithValue("@limit", limit);
- using var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+ var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
+ await using var disposableReader = dataReader.ConfigureAwait(false);
- while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
- {
- double cosineSimilarity = dataReader.GetDouble(dataReader.GetOrdinal("cosine_similarity"));
- yield return (await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false), cosineSimilarity);
- }
+ while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false))
+ {
+ double cosineSimilarity = dataReader.GetDouble(dataReader.GetOrdinal("cosine_similarity"));
+ yield return (await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false), cosineSimilarity);
}
}
///
public async Task UpsertAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default)
{
+ if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); }
+
await foreach (var item in this.BatchUpsertAsync(index, new[] { record }, cancellationToken).ConfigureAwait(false))
{
return item;
@@ -377,6 +406,8 @@ public async Task UpsertAsync(string index, MemoryRecord record, Cancell
///
public async IAsyncEnumerable BatchUpsertAsync(string index, IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
+ if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); }
+
index = NormalizeIndexName(index);
if (!(await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false)))
@@ -384,95 +415,147 @@ public async IAsyncEnumerable BatchUpsertAsync(string index, IEnumerable
throw new IndexNotFoundException($"The index '{index}' does not exist.");
}
- using var connection = new SqlConnection(this._config.ConnectionString);
+ var connection = new SqlConnection(this._config.ConnectionString);
+ await using var disposableConnection = connection.ConfigureAwait(false);
await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
foreach (var record in records)
{
- using (SqlCommand command = connection.CreateCommand())
- {
- command.CommandText = $@"
- BEGIN TRANSACTION;
-
- MERGE INTO {this.GetFullTableName(this._config.MemoryTableName)}
- USING (SELECT @key) as [src]([key])
- ON {this.GetFullTableName(this._config.MemoryTableName)}.[key] = [src].[key]
- WHEN MATCHED THEN
- UPDATE SET payload=@payload, embedding=@embedding, tags=@tags
- WHEN NOT MATCHED THEN
- INSERT ([id], [key], [collection], [payload], [tags], [embedding])
- VALUES (NEWID(), @key, @index, @payload, @tags, @embedding);
-
- MERGE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} AS [tgt]
- USING (
- SELECT
- {this.GetFullTableName(this._config.MemoryTableName)}.[id],
- cast([vector].[key] AS INT) AS [vector_value_id],
- cast([vector].[value] AS FLOAT) AS [vector_value]
- FROM {this.GetFullTableName(this._config.MemoryTableName)}
- CROSS APPLY
- openjson(@embedding) [vector]
- WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
- AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
- ) AS [src]
- ON [tgt].[memory_id] = [src].[id] AND [tgt].[vector_value_id] = [src].[vector_value_id]
- WHEN MATCHED THEN
- UPDATE SET [tgt].[vector_value] = [src].[vector_value]
- WHEN NOT MATCHED THEN
- INSERT ([memory_id], [vector_value_id], [vector_value])
- VALUES ([src].[id],
- [src].[vector_value_id],
- [src].[vector_value] );
-
- DELETE FROM [tgt]
- FROM {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt]
- INNER JOIN {this.GetFullTableName(this._config.MemoryTableName)} ON [tgt].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
+ SqlCommand command = connection.CreateCommand();
+ await using var disposableCommand = command.ConfigureAwait(false);
+
+ command.CommandText = $@"
+ BEGIN TRANSACTION;
+
+ MERGE INTO {this.GetFullTableName(this._config.MemoryTableName)}
+ USING (SELECT @key) as [src]([key])
+ ON {this.GetFullTableName(this._config.MemoryTableName)}.[key] = [src].[key]
+ WHEN MATCHED THEN
+ UPDATE SET payload=@payload, embedding=@embedding, tags=@tags
+ WHEN NOT MATCHED THEN
+ INSERT ([id], [key], [collection], [payload], [tags], [embedding])
+ VALUES (NEWID(), @key, @index, @payload, @tags, @embedding);
+
+ MERGE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} AS [tgt]
+ USING (
+ SELECT
+ {this.GetFullTableName(this._config.MemoryTableName)}.[id],
+ cast([vector].[key] AS INT) AS [vector_value_id],
+ cast([vector].[value] AS FLOAT) AS [vector_value]
+ FROM {this.GetFullTableName(this._config.MemoryTableName)}
+ CROSS APPLY
+ openjson(@embedding) [vector]
WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
- AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index;
-
- MERGE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt]
- USING (
- SELECT
- {this.GetFullTableName(this._config.MemoryTableName)}.[id],
- cast([tags].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_name],
- [tag_value].[value] AS [value]
- FROM {this.GetFullTableName(this._config.MemoryTableName)}
- CROSS APPLY openjson(@tags) [tags]
- CROSS APPLY openjson(cast([tags].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value]
- WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
- AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
- ) AS [src]
- ON [tgt].[memory_id] = [src].[id] AND [tgt].[name] = [src].[tag_name]
- WHEN MATCHED THEN
- UPDATE SET [tgt].[value] = [src].[value]
- WHEN NOT MATCHED THEN
- INSERT ([memory_id], [name], [value])
- VALUES ([src].[id],
- [src].[tag_name],
- [src].[value]);
-
- COMMIT;";
-
- command.Parameters.AddWithValue("@index", index);
- command.Parameters.AddWithValue("@key", record.Id);
- command.Parameters.AddWithValue("@payload", JsonSerializer.Serialize(record.Payload) ?? (object)DBNull.Value);
- command.Parameters.AddWithValue("@tags", JsonSerializer.Serialize(record.Tags) ?? (object)DBNull.Value);
- command.Parameters.AddWithValue("@embedding", JsonSerializer.Serialize(record.Vector.Data.ToArray()));
-
- await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
-
- yield return record.Id;
- }
+ AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
+ ) AS [src]
+ ON [tgt].[memory_id] = [src].[id] AND [tgt].[vector_value_id] = [src].[vector_value_id]
+ WHEN MATCHED THEN
+ UPDATE SET [tgt].[vector_value] = [src].[vector_value]
+ WHEN NOT MATCHED THEN
+ INSERT ([memory_id], [vector_value_id], [vector_value])
+ VALUES ([src].[id],
+ [src].[vector_value_id],
+ [src].[vector_value] );
+
+ DELETE FROM [tgt]
+ FROM {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt]
+ INNER JOIN {this.GetFullTableName(this._config.MemoryTableName)} ON [tgt].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id]
+ WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
+ AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index;
+
+ MERGE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt]
+ USING (
+ SELECT
+ {this.GetFullTableName(this._config.MemoryTableName)}.[id],
+ cast([tags].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_name],
+ [tag_value].[value] AS [value]
+ FROM {this.GetFullTableName(this._config.MemoryTableName)}
+ CROSS APPLY openjson(@tags) [tags]
+ CROSS APPLY openjson(cast([tags].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value]
+ WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key
+ AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index
+ ) AS [src]
+ ON [tgt].[memory_id] = [src].[id] AND [tgt].[name] = [src].[tag_name]
+ WHEN MATCHED THEN
+ UPDATE SET [tgt].[value] = [src].[value]
+ WHEN NOT MATCHED THEN
+ INSERT ([memory_id], [name], [value])
+ VALUES ([src].[id],
+ [src].[tag_name],
+ [src].[value]);
+
+ COMMIT;";
+
+ command.Parameters.AddWithValue("@index", index);
+ command.Parameters.AddWithValue("@key", record.Id);
+ command.Parameters.AddWithValue("@payload", JsonSerializer.Serialize(record.Payload) ?? (object)DBNull.Value);
+ command.Parameters.AddWithValue("@tags", JsonSerializer.Serialize(record.Tags) ?? (object)DBNull.Value);
+ command.Parameters.AddWithValue("@embedding", JsonSerializer.Serialize(record.Vector.Data.ToArray()));
+
+ await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
+
+ yield return record.Id;
}
}
+ ///
+ public void Dispose()
+ {
+ this._initSemaphore.Dispose();
+ }
+
#region private ================================================================================
+ // Note: "_" is allowed in Postgres, but we normalize it to "-" for consistency with other DBs
+ private static readonly Regex s_replaceIndexNameCharsRegex = new(@"[\s|\\|/|.|_|:]");
+ private const string ValidSeparator = "-";
+
+ ///
+ /// Prepare instance, ensuring tables exist and reusable info is cached.
+ ///
+ private async Task InitAsync(CancellationToken cancellationToken)
+ {
+ if (this._isReady) { return; }
+
+ await this._initSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
+ if (this._isReady) { return; }
+
+ try
+ {
+ await this.CacheSqlServerMajorVersionNumberAsync(cancellationToken).ConfigureAwait(false);
+ await this.CreateTablesIfNotExistsAsync(cancellationToken).ConfigureAwait(false);
+ this._isReady = true;
+ }
+ finally
+ {
+ this._initSemaphore.Release();
+ }
+ }
+
+ ///
+ /// Cache SQL Server version
+ ///
+ private async Task CacheSqlServerMajorVersionNumberAsync(CancellationToken cancellationToken)
+ {
+ var connection = new SqlConnection(this._config.ConnectionString);
+ await using var disposableConnection = connection.ConfigureAwait(false);
+ await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
+
+ SqlCommand command = connection.CreateCommand();
+ await using var disposableCommand = command.ConfigureAwait(false);
+
+ command.CommandText = "SELECT SERVERPROPERTY('ProductMajorVersion')";
+
+ var result = await command.ExecuteScalarAsync(cancellationToken).ConfigureAwait(false);
+
+ this._cachedServerVersion = Convert.ToInt32(result, CultureInfo.InvariantCulture);
+ }
+
///
/// Creates the SQL Server tables if they do not exist.
///
///
- private void CreateTablesIfNotExists()
+ private async Task CreateTablesIfNotExistsAsync(CancellationToken cancellationToken)
{
var sql = $@"IF NOT EXISTS (SELECT *
FROM sys.schemas
@@ -495,17 +578,17 @@ [embedding] NVARCHAR(MAX),
PRIMARY KEY ([id]),
FOREIGN KEY ([collection]) REFERENCES {this.GetFullTableName(this._config.MemoryCollectionTableName)}([id]) ON DELETE CASCADE,
CONSTRAINT UK_{this._config.MemoryTableName} UNIQUE([collection], [key])
- );
- ";
+ );";
- using var connection = new SqlConnection(this._config.ConnectionString);
- connection.Open();
+ var connection = new SqlConnection(this._config.ConnectionString);
+ await using var disposableConnection = connection.ConfigureAwait(false);
+ await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
- using (SqlCommand command = connection.CreateCommand())
- {
- command.CommandText = sql;
- command.ExecuteNonQuery();
- }
+ SqlCommand command = connection.CreateCommand();
+ await using var disposableCommand = command.ConfigureAwait(false);
+
+ command.CommandText = sql;
+ await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false);
}
///
@@ -513,22 +596,12 @@ FOREIGN KEY ([collection]) REFERENCES {this.GetFullTableName(this._config.Memory
///
/// The index name.
/// The cancellation token.
- ///
+ /// True is the index exists
private async Task DoesIndexExistsAsync(string indexName,
CancellationToken cancellationToken = default)
{
- var collections = await this.GetIndexesAsync(cancellationToken)
- .ConfigureAwait(false);
-
- foreach (var item in collections)
- {
- if (item.Equals(indexName, StringComparison.OrdinalIgnoreCase))
- {
- return true;
- }
- }
-
- return false;
+ var collections = await this.GetIndexesAsync(cancellationToken).ConfigureAwait(false);
+ return collections.Any(x => x.Equals(indexName, StringComparison.OrdinalIgnoreCase));
}
///
@@ -604,9 +677,10 @@ private string GenerateFilters(string index, SqlParameterCollection parameters,
private async Task ReadEntryAsync(SqlDataReader dataReader, bool withEmbedding, CancellationToken cancellationToken = default)
{
- var entry = new MemoryRecord();
-
- entry.Id = dataReader.GetString(dataReader.GetOrdinal("key"));
+ var entry = new MemoryRecord
+ {
+ Id = dataReader.GetString(dataReader.GetOrdinal("key"))
+ };
if (!(await dataReader.IsDBNullAsync(dataReader.GetOrdinal("payload"), cancellationToken).ConfigureAwait(false)))
{
@@ -626,10 +700,6 @@ private async Task ReadEntryAsync(SqlDataReader dataReader, bool w
return entry;
}
- // Note: "_" is allowed in Postgres, but we normalize it to "-" for consistency with other DBs
- private static readonly Regex s_replaceIndexNameCharsRegex = new(@"[\s|\\|/|.|_|:]");
- private const string ValidSeparator = "-";
-
private static string NormalizeIndexName(string index)
{
ArgumentNullExceptionEx.ThrowIfNullOrWhiteSpace(index, nameof(index), "The index name is empty");
@@ -639,20 +709,5 @@ private static string NormalizeIndexName(string index)
return index;
}
- private int GetSqlServerMajorVersionNumber()
- {
- using var connection = new SqlConnection(this._config.ConnectionString);
- connection.Open();
-
- using (SqlCommand command = connection.CreateCommand())
- {
- command.CommandText = "SELECT SERVERPROPERTY('ProductMajorVersion')";
-
- var result = command.ExecuteScalar();
-
- return Convert.ToInt32(result, CultureInfo.InvariantCulture);
- }
- }
-
#endregion
}