diff --git a/extensions/SQLServer/SQLServer/SqlServerMemory.cs b/extensions/SQLServer/SQLServer/SqlServerMemory.cs index eb5388e89..3ea8b82dc 100644 --- a/extensions/SQLServer/SQLServer/SqlServerMemory.cs +++ b/extensions/SQLServer/SQLServer/SqlServerMemory.cs @@ -17,7 +17,7 @@ namespace Microsoft.KernelMemory.MemoryDb.SQLServer; /// Represents a memory store implementation that uses a SQL Server database as its backing store. /// #pragma warning disable CA2100 // SQL reviewed for user input validation -public class SqlServerMemory : IMemoryDb, IMemoryDbBatchUpsert +public sealed class SqlServerMemory : IMemoryDb, IMemoryDbBatchUpsert, IDisposable { /// /// The SQL Server configuration. @@ -34,6 +34,21 @@ public class SqlServerMemory : IMemoryDb, IMemoryDbBatchUpsert /// private readonly ILogger _log; + /// + /// Flag used to initialize the client on the first call + /// + private bool _isReady = false; + + /// + /// Lock used to initialize the class instance + /// + private readonly SemaphoreSlim _initSemaphore = new(1, 1); + + /// + /// SQL Server version, retrieved on the first connection + /// + private int _cachedServerVersion = int.MinValue; + /// /// Initializes a new instance of the class. /// @@ -45,22 +60,16 @@ public SqlServerMemory( ITextEmbeddingGenerator embeddingGenerator, ILogger? log = null) { - this._embeddingGenerator = embeddingGenerator; - this._log = log ?? DefaultLogger.Instance; - this._config = config; - - if (this._embeddingGenerator == null) - { - throw new SqlServerMemoryException("Embedding generator not configured"); - } - - this.CreateTablesIfNotExists(); + // Keep the configuration: the connection string, schema and table names are read from it + this._config = config; + this._embeddingGenerator = embeddingGenerator ?? throw new ConfigurationException("Embedding generator not configured"); + this._log = log ??
DefaultLogger.Instance; } /// public async Task CreateIndexAsync(string index, int vectorSize, CancellationToken cancellationToken = default) { + if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); } + index = NormalizeIndexName(index); if (await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false)) @@ -69,52 +78,54 @@ public async Task CreateIndexAsync(string index, int vectorSize, CancellationTok return; } - using var connection = new SqlConnection(this._config.ConnectionString); + var connection = new SqlConnection(this._config.ConnectionString); + await using var disposableConnection = connection.ConfigureAwait(false); await connection.OpenAsync(cancellationToken).ConfigureAwait(false); - using (SqlCommand command = connection.CreateCommand()) - { - command.CommandText = $@" - BEGIN TRANSACTION; - - INSERT INTO {this.GetFullTableName(this._config.MemoryCollectionTableName)}([id]) - VALUES (@index); - - IF OBJECT_ID(N'{this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}', N'U') IS NULL - CREATE TABLE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} - ( - [memory_id] UNIQUEIDENTIFIER NOT NULL, - [vector_value_id] [int] NOT NULL, - [vector_value] [float] NOT NULL - FOREIGN KEY ([memory_id]) REFERENCES {this.GetFullTableName(this._config.MemoryTableName)}([id]) - ); + SqlCommand command = connection.CreateCommand(); + await using var disposableCommand = command.ConfigureAwait(false); - IF OBJECT_ID(N'[{this._config.Schema}.IXC_{$"{this._config.EmbeddingsTableName}_{index}"}]', N'U') IS NULL - CREATE CLUSTERED COLUMNSTORE INDEX [IXC_{$"{this._config.EmbeddingsTableName}_{index}"}] - ON {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} - {(this.GetSqlServerMajorVersionNumber() >= 16 ? 
"ORDER ([memory_id])" : "")}; - - IF OBJECT_ID(N'{this.GetFullTableName($"{this._config.TagsTableName}_{index}")}', N'U') IS NULL - CREATE TABLE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} - ( - [memory_id] UNIQUEIDENTIFIER NOT NULL, - [name] NVARCHAR(256) NOT NULL, - [value] NVARCHAR(256) NOT NULL, - FOREIGN KEY ([memory_id]) REFERENCES {this.GetFullTableName(this._config.MemoryTableName)}([id]) - ); + command.CommandText = $@" + BEGIN TRANSACTION; - COMMIT; - "; + INSERT INTO {this.GetFullTableName(this._config.MemoryCollectionTableName)}([id]) + VALUES (@index); - command.Parameters.AddWithValue("@index", index); + IF OBJECT_ID(N'{this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}', N'U') IS NULL + CREATE TABLE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} + ( + [memory_id] UNIQUEIDENTIFIER NOT NULL, + [vector_value_id] [int] NOT NULL, + [vector_value] [float] NOT NULL + FOREIGN KEY ([memory_id]) REFERENCES {this.GetFullTableName(this._config.MemoryTableName)}([id]) + ); + + IF OBJECT_ID(N'[{this._config.Schema}.IXC_{$"{this._config.EmbeddingsTableName}_{index}"}]', N'U') IS NULL + CREATE CLUSTERED COLUMNSTORE INDEX [IXC_{$"{this._config.EmbeddingsTableName}_{index}"}] + ON {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} + {(this._cachedServerVersion >= 16 ? 
"ORDER ([memory_id])" : "")}; + + IF OBJECT_ID(N'{this.GetFullTableName($"{this._config.TagsTableName}_{index}")}', N'U') IS NULL + CREATE TABLE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} + ( + [memory_id] UNIQUEIDENTIFIER NOT NULL, + [name] NVARCHAR(256) NOT NULL, + [value] NVARCHAR(256) NOT NULL, + FOREIGN KEY ([memory_id]) REFERENCES {this.GetFullTableName(this._config.MemoryTableName)}([id]) + ); - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } + COMMIT;"; + + command.Parameters.AddWithValue("@index", index); + + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } /// public async Task DeleteAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default) { + if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); } + index = NormalizeIndexName(index); if (!(await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false))) @@ -123,12 +134,14 @@ public async Task DeleteAsync(string index, MemoryRecord record, CancellationTok return; } - using var connection = new SqlConnection(this._config.ConnectionString); + var connection = new SqlConnection(this._config.ConnectionString); + await using var disposableConnection = connection.ConfigureAwait(false); await connection.OpenAsync(cancellationToken).ConfigureAwait(false); - using (SqlCommand command = connection.CreateCommand()) - { - command.CommandText = $@" + SqlCommand command = connection.CreateCommand(); + await using var disposableCommand = command.ConfigureAwait(false); + + command.CommandText = $@" BEGIN TRANSACTION; DELETE [embeddings] @@ -149,16 +162,17 @@ DELETE [tags] COMMIT;"; - command.Parameters.AddWithValue("@index", index); - command.Parameters.AddWithValue("@key", record.Id); + command.Parameters.AddWithValue("@index", index); + command.Parameters.AddWithValue("@key", record.Id); - await 
command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } /// public async Task DeleteIndexAsync(string index, CancellationToken cancellationToken = default) { + if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); } + index = NormalizeIndexName(index); if (!(await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false))) @@ -167,46 +181,51 @@ public async Task DeleteIndexAsync(string index, CancellationToken cancellationT return; } - using var connection = new SqlConnection(this._config.ConnectionString); + var connection = new SqlConnection(this._config.ConnectionString); + await using var disposableConnection = connection.ConfigureAwait(false); await connection.OpenAsync(cancellationToken).ConfigureAwait(false); - using (SqlCommand command = connection.CreateCommand()) - { - command.CommandText = $@" - BEGIN TRANSACTION; + SqlCommand command = connection.CreateCommand(); + await using var disposableCommand = command.ConfigureAwait(false); - DROP TABLE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}; - DROP TABLE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")}; + command.CommandText = $@" + BEGIN TRANSACTION; - DELETE FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)} - WHERE [id] = @index; + DROP TABLE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")}; + DROP TABLE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")}; - COMMIT;"; + DELETE FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)} + WHERE [id] = @index; - command.Parameters.AddWithValue("@index", index); + COMMIT;"; - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - } + command.Parameters.AddWithValue("@index", index); + + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } /// public async Task> 
GetIndexesAsync(CancellationToken cancellationToken = default) { + if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); } + List indexes = new(); - using var connection = new SqlConnection(this._config.ConnectionString); + var connection = new SqlConnection(this._config.ConnectionString); + await using var disposableConnection = connection.ConfigureAwait(false); await connection.OpenAsync(cancellationToken).ConfigureAwait(false); - using (SqlCommand command = connection.CreateCommand()) - { - command.CommandText = $"SELECT [id] FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)}"; + SqlCommand command = connection.CreateCommand(); + await using var disposableCommand = command.ConfigureAwait(false); - using var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + command.CommandText = $"SELECT [id] FROM {this.GetFullTableName(this._config.MemoryCollectionTableName)}"; - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - indexes.Add(dataReader.GetString(dataReader.GetOrdinal("id"))); - } + var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using var disposableReader = dataReader.ConfigureAwait(false); + + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + indexes.Add(dataReader.GetString(dataReader.GetOrdinal("id"))); } return indexes; @@ -215,6 +234,8 @@ public async Task> GetIndexesAsync(CancellationToken cancell /// public async IAsyncEnumerable GetListAsync(string index, ICollection? 
filters = null, int limit = 1, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); } + index = NormalizeIndexName(index); if (!(await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false))) @@ -235,14 +256,16 @@ public async IAsyncEnumerable GetListAsync(string index, ICollecti limit = int.MaxValue; } - using var connection = new SqlConnection(this._config.ConnectionString); + var connection = new SqlConnection(this._config.ConnectionString); + await using var disposableConnection = connection.ConfigureAwait(false); await connection.OpenAsync(cancellationToken).ConfigureAwait(false); - using (SqlCommand command = connection.CreateCommand()) - { - var tagFilters = new TagCollection(); + SqlCommand command = connection.CreateCommand(); + await using var disposableCommand = command.ConfigureAwait(false); - command.CommandText = $@" + var tagFilters = new TagCollection(); + + command.CommandText = $@" WITH [filters] AS ( SELECT @@ -258,22 +281,24 @@ SELECT TOP (@limit) AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index {this.GenerateFilters(index, command.Parameters, filters)};"; - command.Parameters.AddWithValue("@index", index); - command.Parameters.AddWithValue("@limit", limit); - command.Parameters.AddWithValue("@filters", JsonSerializer.Serialize(tagFilters)); + command.Parameters.AddWithValue("@index", index); + command.Parameters.AddWithValue("@limit", limit); + command.Parameters.AddWithValue("@filters", JsonSerializer.Serialize(tagFilters)); - using var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using var disposableReader = dataReader.ConfigureAwait(false); - while (await 
dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - yield return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false); - } + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + yield return await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false); } } /// public async IAsyncEnumerable<(MemoryRecord, double)> GetSimilarListAsync(string index, string text, ICollection? filters = null, double minRelevance = 0, int limit = 1, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); } + index = NormalizeIndexName(index); if (!(await this.DoesIndexExistsAsync(index, cancellationToken).ConfigureAwait(false))) @@ -295,14 +320,16 @@ SELECT TOP (@limit) $"{this.GetFullTableName(this._config.MemoryTableName)}.[embedding]"; } - using var connection = new SqlConnection(this._config.ConnectionString); + var connection = new SqlConnection(this._config.ConnectionString); + await using var disposableConnection = connection.ConfigureAwait(false); await connection.OpenAsync(cancellationToken).ConfigureAwait(false); - using (SqlCommand command = connection.CreateCommand()) - { - var generatedFilters = this.GenerateFilters(index, command.Parameters, filters); + SqlCommand command = connection.CreateCommand(); + await using var disposableCommand = command.ConfigureAwait(false); - command.CommandText = $@" + var generatedFilters = this.GenerateFilters(index, command.Parameters, filters); + + command.CommandText = $@" WITH [embedding] as ( @@ -348,24 +375,26 @@ INNER JOIN {generatedFilters} ORDER BY [cosine_similarity] desc"; - command.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(embedding.Data.ToArray())); - command.Parameters.AddWithValue("@index", index); - command.Parameters.AddWithValue("@min_relevance_score", 
minRelevance); - command.Parameters.AddWithValue("@limit", limit); + command.Parameters.AddWithValue("@vector", JsonSerializer.Serialize(embedding.Data.ToArray())); + command.Parameters.AddWithValue("@index", index); + command.Parameters.AddWithValue("@min_relevance_score", minRelevance); + command.Parameters.AddWithValue("@limit", limit); - using var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + var dataReader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await using var disposableReader = dataReader.ConfigureAwait(false); - while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) - { - double cosineSimilarity = dataReader.GetDouble(dataReader.GetOrdinal("cosine_similarity")); - yield return (await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false), cosineSimilarity); - } + while (await dataReader.ReadAsync(cancellationToken).ConfigureAwait(false)) + { + double cosineSimilarity = dataReader.GetDouble(dataReader.GetOrdinal("cosine_similarity")); + yield return (await this.ReadEntryAsync(dataReader, withEmbeddings, cancellationToken).ConfigureAwait(false), cosineSimilarity); } } /// public async Task UpsertAsync(string index, MemoryRecord record, CancellationToken cancellationToken = default) { + if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); } + await foreach (var item in this.BatchUpsertAsync(index, new[] { record }, cancellationToken).ConfigureAwait(false)) { return item; @@ -377,6 +406,8 @@ public async Task UpsertAsync(string index, MemoryRecord record, Cancell /// public async IAsyncEnumerable BatchUpsertAsync(string index, IEnumerable records, [EnumeratorCancellation] CancellationToken cancellationToken = default) { + if (!this._isReady) { await this.InitAsync(cancellationToken).ConfigureAwait(false); } + index = NormalizeIndexName(index); if (!(await this.DoesIndexExistsAsync(index, 
cancellationToken).ConfigureAwait(false))) @@ -384,95 +415,147 @@ public async IAsyncEnumerable BatchUpsertAsync(string index, IEnumerable throw new IndexNotFoundException($"The index '{index}' does not exist."); } - using var connection = new SqlConnection(this._config.ConnectionString); + var connection = new SqlConnection(this._config.ConnectionString); + await using var disposableConnection = connection.ConfigureAwait(false); await connection.OpenAsync(cancellationToken).ConfigureAwait(false); foreach (var record in records) { - using (SqlCommand command = connection.CreateCommand()) - { - command.CommandText = $@" - BEGIN TRANSACTION; - - MERGE INTO {this.GetFullTableName(this._config.MemoryTableName)} - USING (SELECT @key) as [src]([key]) - ON {this.GetFullTableName(this._config.MemoryTableName)}.[key] = [src].[key] - WHEN MATCHED THEN - UPDATE SET payload=@payload, embedding=@embedding, tags=@tags - WHEN NOT MATCHED THEN - INSERT ([id], [key], [collection], [payload], [tags], [embedding]) - VALUES (NEWID(), @key, @index, @payload, @tags, @embedding); - - MERGE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} AS [tgt] - USING ( - SELECT - {this.GetFullTableName(this._config.MemoryTableName)}.[id], - cast([vector].[key] AS INT) AS [vector_value_id], - cast([vector].[value] AS FLOAT) AS [vector_value] - FROM {this.GetFullTableName(this._config.MemoryTableName)} - CROSS APPLY - openjson(@embedding) [vector] - WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key - AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index - ) AS [src] - ON [tgt].[memory_id] = [src].[id] AND [tgt].[vector_value_id] = [src].[vector_value_id] - WHEN MATCHED THEN - UPDATE SET [tgt].[vector_value] = [src].[vector_value] - WHEN NOT MATCHED THEN - INSERT ([memory_id], [vector_value_id], [vector_value]) - VALUES ([src].[id], - [src].[vector_value_id], - [src].[vector_value] ); - - DELETE FROM [tgt] - FROM 
{this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt] - INNER JOIN {this.GetFullTableName(this._config.MemoryTableName)} ON [tgt].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id] + SqlCommand command = connection.CreateCommand(); + await using var disposableCommand = command.ConfigureAwait(false); + + command.CommandText = $@" + BEGIN TRANSACTION; + + MERGE INTO {this.GetFullTableName(this._config.MemoryTableName)} + USING (SELECT @key) as [src]([key]) + ON {this.GetFullTableName(this._config.MemoryTableName)}.[key] = [src].[key] + WHEN MATCHED THEN + UPDATE SET payload=@payload, embedding=@embedding, tags=@tags + WHEN NOT MATCHED THEN + INSERT ([id], [key], [collection], [payload], [tags], [embedding]) + VALUES (NEWID(), @key, @index, @payload, @tags, @embedding); + + MERGE {this.GetFullTableName($"{this._config.EmbeddingsTableName}_{index}")} AS [tgt] + USING ( + SELECT + {this.GetFullTableName(this._config.MemoryTableName)}.[id], + cast([vector].[key] AS INT) AS [vector_value_id], + cast([vector].[value] AS FLOAT) AS [vector_value] + FROM {this.GetFullTableName(this._config.MemoryTableName)} + CROSS APPLY + openjson(@embedding) [vector] WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key - AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index; - - MERGE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt] - USING ( - SELECT - {this.GetFullTableName(this._config.MemoryTableName)}.[id], - cast([tags].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_name], - [tag_value].[value] AS [value] - FROM {this.GetFullTableName(this._config.MemoryTableName)} - CROSS APPLY openjson(@tags) [tags] - CROSS APPLY openjson(cast([tags].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value] - WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key - AND 
{this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index - ) AS [src] - ON [tgt].[memory_id] = [src].[id] AND [tgt].[name] = [src].[tag_name] - WHEN MATCHED THEN - UPDATE SET [tgt].[value] = [src].[value] - WHEN NOT MATCHED THEN - INSERT ([memory_id], [name], [value]) - VALUES ([src].[id], - [src].[tag_name], - [src].[value]); - - COMMIT;"; - - command.Parameters.AddWithValue("@index", index); - command.Parameters.AddWithValue("@key", record.Id); - command.Parameters.AddWithValue("@payload", JsonSerializer.Serialize(record.Payload) ?? (object)DBNull.Value); - command.Parameters.AddWithValue("@tags", JsonSerializer.Serialize(record.Tags) ?? (object)DBNull.Value); - command.Parameters.AddWithValue("@embedding", JsonSerializer.Serialize(record.Vector.Data.ToArray())); - - await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); - - yield return record.Id; - } + AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index + ) AS [src] + ON [tgt].[memory_id] = [src].[id] AND [tgt].[vector_value_id] = [src].[vector_value_id] + WHEN MATCHED THEN + UPDATE SET [tgt].[vector_value] = [src].[vector_value] + WHEN NOT MATCHED THEN + INSERT ([memory_id], [vector_value_id], [vector_value]) + VALUES ([src].[id], + [src].[vector_value_id], + [src].[vector_value] ); + + DELETE FROM [tgt] + FROM {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt] + INNER JOIN {this.GetFullTableName(this._config.MemoryTableName)} ON [tgt].[memory_id] = {this.GetFullTableName(this._config.MemoryTableName)}.[id] + WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key + AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index; + + MERGE {this.GetFullTableName($"{this._config.TagsTableName}_{index}")} AS [tgt] + USING ( + SELECT + {this.GetFullTableName(this._config.MemoryTableName)}.[id], + cast([tags].[key] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS AS [tag_name], + 
[tag_value].[value] AS [value] + FROM {this.GetFullTableName(this._config.MemoryTableName)} + CROSS APPLY openjson(@tags) [tags] + CROSS APPLY openjson(cast([tags].[value] AS NVARCHAR(256)) COLLATE SQL_Latin1_General_CP1_CI_AS) [tag_value] + WHERE {this.GetFullTableName(this._config.MemoryTableName)}.[key] = @key + AND {this.GetFullTableName(this._config.MemoryTableName)}.[collection] = @index + ) AS [src] + ON [tgt].[memory_id] = [src].[id] AND [tgt].[name] = [src].[tag_name] + WHEN MATCHED THEN + UPDATE SET [tgt].[value] = [src].[value] + WHEN NOT MATCHED THEN + INSERT ([memory_id], [name], [value]) + VALUES ([src].[id], + [src].[tag_name], + [src].[value]); + + COMMIT;"; + + command.Parameters.AddWithValue("@index", index); + command.Parameters.AddWithValue("@key", record.Id); + command.Parameters.AddWithValue("@payload", JsonSerializer.Serialize(record.Payload) ?? (object)DBNull.Value); + command.Parameters.AddWithValue("@tags", JsonSerializer.Serialize(record.Tags) ?? (object)DBNull.Value); + command.Parameters.AddWithValue("@embedding", JsonSerializer.Serialize(record.Vector.Data.ToArray())); + + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + + yield return record.Id; } } + /// + public void Dispose() + { + this._initSemaphore.Dispose(); + } + #region private ================================================================================ + // Note: "_" is allowed in Postgres, but we normalize it to "-" for consistency with other DBs + private static readonly Regex s_replaceIndexNameCharsRegex = new(@"[\s|\\|/|.|_|:]"); + private const string ValidSeparator = "-"; + + /// + /// Prepare instance, ensuring tables exist and reusable info is cached. 
+ /// + private async Task InitAsync(CancellationToken cancellationToken) + { + // Fast path: skip the semaphore once initialization has completed + if (this._isReady) { return; } + + await this._initSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false); + try + { + // Re-check inside the try block: another caller may have completed the initialization while this one was waiting, and returning from here still releases the semaphore in the finally block + if (this._isReady) { return; } + + await this.CacheSqlServerMajorVersionNumberAsync(cancellationToken).ConfigureAwait(false); + await this.CreateTablesIfNotExistsAsync(cancellationToken).ConfigureAwait(false); + this._isReady = true; + } + finally + { + this._initSemaphore.Release(); + } + } + + /// + /// Cache SQL Server version + /// + private async Task CacheSqlServerMajorVersionNumberAsync(CancellationToken cancellationToken) + { + var connection = new SqlConnection(this._config.ConnectionString); + await using var disposableConnection = connection.ConfigureAwait(false); + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + + SqlCommand command = connection.CreateCommand(); + await using var disposableCommand = command.ConfigureAwait(false); + + command.CommandText = "SELECT SERVERPROPERTY('ProductMajorVersion')"; + + var result = await command.ExecuteScalarAsync(cancellationToken).ConfigureAwait(false); + + this._cachedServerVersion = Convert.ToInt32(result, CultureInfo.InvariantCulture); + } + /// /// Creates the SQL Server tables if they do not exist.
/// /// - private void CreateTablesIfNotExists() + private async Task CreateTablesIfNotExistsAsync(CancellationToken cancellationToken) { var sql = $@"IF NOT EXISTS (SELECT * FROM sys.schemas @@ -495,17 +578,17 @@ [embedding] NVARCHAR(MAX), PRIMARY KEY ([id]), FOREIGN KEY ([collection]) REFERENCES {this.GetFullTableName(this._config.MemoryCollectionTableName)}([id]) ON DELETE CASCADE, CONSTRAINT UK_{this._config.MemoryTableName} UNIQUE([collection], [key]) - ); - "; + );"; - using var connection = new SqlConnection(this._config.ConnectionString); - connection.Open(); + var connection = new SqlConnection(this._config.ConnectionString); + await using var disposableConnection = connection.ConfigureAwait(false); + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); - using (SqlCommand command = connection.CreateCommand()) - { - command.CommandText = sql; - command.ExecuteNonQuery(); - } + SqlCommand command = connection.CreateCommand(); + await using var disposableCommand = command.ConfigureAwait(false); + + command.CommandText = sql; + await command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); } /// @@ -513,22 +596,12 @@ FOREIGN KEY ([collection]) REFERENCES {this.GetFullTableName(this._config.Memory /// /// The index name. /// The cancellation token. 
- /// + /// True if the index exists private async Task DoesIndexExistsAsync(string indexName, CancellationToken cancellationToken = default) { - var collections = await this.GetIndexesAsync(cancellationToken) - .ConfigureAwait(false); - - foreach (var item in collections) - { - if (item.Equals(indexName, StringComparison.OrdinalIgnoreCase)) - { - return true; - } - } - - return false; + var collections = await this.GetIndexesAsync(cancellationToken).ConfigureAwait(false); + return collections.Any(x => x.Equals(indexName, StringComparison.OrdinalIgnoreCase)); } /// @@ -604,9 +677,10 @@ private string GenerateFilters(string index, SqlParameterCollection parameters, private async Task ReadEntryAsync(SqlDataReader dataReader, bool withEmbedding, CancellationToken cancellationToken = default) { - var entry = new MemoryRecord(); - - entry.Id = dataReader.GetString(dataReader.GetOrdinal("key")); + var entry = new MemoryRecord + { + Id = dataReader.GetString(dataReader.GetOrdinal("key")) + }; if (!(await dataReader.IsDBNullAsync(dataReader.GetOrdinal("payload"), cancellationToken).ConfigureAwait(false))) { @@ -626,10 +700,6 @@ private async Task ReadEntryAsync(SqlDataReader dataReader, bool w return entry; } - // Note: "_" is allowed in Postgres, but we normalize it to "-" for consistency with other DBs - private static readonly Regex s_replaceIndexNameCharsRegex = new(@"[\s|\\|/|.|_|:]"); - private const string ValidSeparator = "-"; - private static string NormalizeIndexName(string index) { ArgumentNullExceptionEx.ThrowIfNullOrWhiteSpace(index, nameof(index), "The index name is empty"); @@ -639,20 +709,5 @@ private static string NormalizeIndexName(string index) return index; } - private int GetSqlServerMajorVersionNumber() - { - using var connection = new SqlConnection(this._config.ConnectionString); - connection.Open(); - - using (SqlCommand command = connection.CreateCommand()) - { - command.CommandText = "SELECT SERVERPROPERTY('ProductMajorVersion')"; - - var result =
command.ExecuteScalar(); - - return Convert.ToInt32(result, CultureInfo.InvariantCulture); - } - } - #endregion }