From 5adda31ffbf3efe7e211be3cdf53e6d7c5e938f5 Mon Sep 17 00:00:00 2001 From: Odonno Date: Thu, 13 Feb 2025 12:37:09 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20implement=20Import=20method?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../NativeMethods.g.cs | 8 ++ .../SurrealDbEmbeddedEngine.cs | 77 ++++++++++++++++ SurrealDb.Embedded.RocksDb/NativeMethods.g.cs | 8 ++ .../NativeMethods.g.cs | 8 ++ SurrealDb.Net.Tests/ImportTests.cs | 50 +++++++++++ .../Internals/Http/CommonHttpWrapper.cs | 16 ++++ .../Internals/SurrealDbEngine.Interface.cs | 10 +++ SurrealDb.Net/SurrealDbClient.Base.cs | 57 ++++++++++++ SurrealDb.Net/SurrealDbClient.Interface.cs | 10 +++ SurrealDb.Net/SurrealDbClient.Methods.cs | 89 ++++++++----------- rust-embedded/shared/src/app/mod.rs | 17 +++- rust-embedded/shared/src/lib.rs | 25 ++++++ 12 files changed, 321 insertions(+), 54 deletions(-) create mode 100644 SurrealDb.Net.Tests/ImportTests.cs create mode 100644 SurrealDb.Net/Internals/Http/CommonHttpWrapper.cs diff --git a/SurrealDb.Embedded.InMemory/NativeMethods.g.cs b/SurrealDb.Embedded.InMemory/NativeMethods.g.cs index 1a5a385b..cca22e93 100644 --- a/SurrealDb.Embedded.InMemory/NativeMethods.g.cs +++ b/SurrealDb.Embedded.InMemory/NativeMethods.g.cs @@ -34,6 +34,14 @@ internal static unsafe partial class NativeMethods [DllImport(__DllName, EntryPoint = "execute", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] internal static extern void execute(int id, Method method, byte* bytes, int len, SuccessAction success, FailureAction failure); + /// + /// # Safety + /// + /// Executes the "import" method of a SurrealDB engine (given its id). + /// + [DllImport(__DllName, EntryPoint = "import", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + internal static extern void import(int id, ushort* utf16_str, int utf16_len, SuccessAction success, FailureAction failure); + /// /// # Safety /// diff --git a/SurrealDb.Embedded.Internals/SurrealDbEmbeddedEngine.cs b/SurrealDb.Embedded.Internals/SurrealDbEmbeddedEngine.cs index fa9da23d..5ac07a32 100644 --- a/SurrealDb.Embedded.Internals/SurrealDbEmbeddedEngine.cs +++ b/SurrealDb.Embedded.Internals/SurrealDbEmbeddedEngine.cs @@ -337,6 +337,83 @@ await SendRequestAsync(Method.Ping, null, cancellationToken) } } + public async Task Import(string input, CancellationToken cancellationToken) + { + using var timeoutCts = new CancellationTokenSource(TimeSpan.FromSeconds(30)); + cancellationToken.Register(timeoutCts.Cancel); + + var taskCompletionSource = new TaskCompletionSource(); + timeoutCts.Token.Register(() => + { + taskCompletionSource.TrySetCanceled(); + }); + + Action success = (_) => + { + try + { + taskCompletionSource.SetResult(default); + } + catch (Exception e) + { + taskCompletionSource.SetException(e); + } + }; + Action fail = (byteBuffer) => + { + string error = CborSerializer.Deserialize( + byteBuffer.AsReadOnly(), + GetCborOptions() + ); + taskCompletionSource.SetException(new SurrealDbException(error)); + }; + + var successHandle = GCHandle.Alloc(success); + var failureHandle = GCHandle.Alloc(fail); + + unsafe + { + var successAction = new SuccessAction() + { + handle = new RustGCHandle() + { + ptr = GCHandle.ToIntPtr(successHandle), + drop_callback = &NativeBindings.DropGcHandle, + }, + callback = &NativeBindings.SuccessCallback, + }; + + var failureAction = new FailureAction() + { + handle = new RustGCHandle() + { + ptr = GCHandle.ToIntPtr(failureHandle), + drop_callback = &NativeBindings.DropGcHandle, + }, + callback = &NativeBindings.FailureCallback, + }; + + fixed (char* p = input.AsSpan()) + { + NativeMethods.import(_id, (ushort*)p, input.Length, successAction, failureAction); + } + } + + try + { + await taskCompletionSource.Task.ConfigureAwait(false); + } + catch (OperationCanceledException) + { + if (!cancellationToken.IsCancellationRequested) + { + throw new TimeoutException(); + } + + throw; + } + } + public Task Info(CancellationToken cancellationToken) { throw new NotSupportedException("Authentication is not enabled in embedded mode."); diff --git a/SurrealDb.Embedded.RocksDb/NativeMethods.g.cs b/SurrealDb.Embedded.RocksDb/NativeMethods.g.cs index b9122507..b82fba97 100644 --- a/SurrealDb.Embedded.RocksDb/NativeMethods.g.cs +++ b/SurrealDb.Embedded.RocksDb/NativeMethods.g.cs @@ -34,6 +34,14 @@ internal static unsafe partial class NativeMethods [DllImport(__DllName, EntryPoint = "execute", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] internal static extern void execute(int id, Method method, byte* bytes, int len, SuccessAction success, FailureAction failure); + /// + /// # Safety + /// + /// Executes the "import" method of a SurrealDB engine (given its id). + /// + [DllImport(__DllName, EntryPoint = "import", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + internal static extern void import(int id, ushort* utf16_str, int utf16_len, SuccessAction success, FailureAction failure); + /// /// # Safety /// diff --git a/SurrealDb.Embedded.SurrealKv/NativeMethods.g.cs b/SurrealDb.Embedded.SurrealKv/NativeMethods.g.cs index 2710527c..d6375f55 100644 --- a/SurrealDb.Embedded.SurrealKv/NativeMethods.g.cs +++ b/SurrealDb.Embedded.SurrealKv/NativeMethods.g.cs @@ -34,6 +34,14 @@ internal static unsafe partial class NativeMethods [DllImport(__DllName, EntryPoint = "execute", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] internal static extern void execute(int id, Method method, byte* bytes, int len, SuccessAction success, FailureAction failure); + /// + /// # Safety + /// + /// Executes the "import" method of a SurrealDB engine (given its id). + /// + [DllImport(__DllName, EntryPoint = "import", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)] + internal static extern void import(int id, ushort* utf16_str, int utf16_len, SuccessAction success, FailureAction failure); + /// /// # Safety /// diff --git a/SurrealDb.Net.Tests/ImportTests.cs b/SurrealDb.Net.Tests/ImportTests.cs new file mode 100644 index 00000000..732a527f --- /dev/null +++ b/SurrealDb.Net.Tests/ImportTests.cs @@ -0,0 +1,50 @@ +namespace SurrealDb.Net.Tests; + +public class ImportTests +{ + private const string IMPORT_QUERY = """ + DEFINE TABLE foo SCHEMALESS; + DEFINE TABLE bar SCHEMALESS; + CREATE foo:1 CONTENT { hello: "world" }; + CREATE bar:1 CONTENT { hello: "world" }; + DEFINE FUNCTION fn::foo() { + RETURN "bar"; + }; + """; + + [Test] + [ConnectionStringFixtureGenerator] + public async Task ShouldImportDataSuccessfully(string connectionString) + { + var version = await SurrealDbClientGenerator.GetSurrealTestVersion(connectionString); + if (version?.Major < 2) + { + return; + } + + await using var surrealDbClientGenerator = new SurrealDbClientGenerator(); + var dbInfo = surrealDbClientGenerator.GenerateDatabaseInfo(); + + var client = surrealDbClientGenerator.Create(connectionString); + await client.Use(dbInfo.Namespace, dbInfo.Database); + + Func func = async () => + { + await client.Import(IMPORT_QUERY); + }; + + await func.Should().NotThrowAsync(); + + // Check imported query by querying the db + var fooRecords = await client.Select("foo"); + fooRecords.Should().NotBeNull().And.HaveCount(1); + + var barRecords = await client.Select("bar"); + barRecords.Should().NotBeNull().And.HaveCount(1); + + var fnResult = await client.Run("fn::foo"); + fnResult.Should().Be("bar"); + + await client.DisposeAsync(); + } +} diff --git a/SurrealDb.Net/Internals/Http/CommonHttpWrapper.cs b/SurrealDb.Net/Internals/Http/CommonHttpWrapper.cs new file mode 100644 index 00000000..752d5011 --- /dev/null +++ b/SurrealDb.Net/Internals/Http/CommonHttpWrapper.cs @@ -0,0 +1,16 @@ +using Dahomey.Cbor; +using Semver; + +namespace SurrealDb.Net.Internals.Http; + +internal sealed record CommonHttpWrapper( + HttpClient HttpClient, + SemVersion? Version, + Action? ConfigureCborOptions +) : IDisposable +{ + public void Dispose() + { + HttpClient.Dispose(); + } +} diff --git a/SurrealDb.Net/Internals/SurrealDbEngine.Interface.cs b/SurrealDb.Net/Internals/SurrealDbEngine.Interface.cs index a85f5b2f..78175616 100644 --- a/SurrealDb.Net/Internals/SurrealDbEngine.Interface.cs +++ b/SurrealDb.Net/Internals/SurrealDbEngine.Interface.cs @@ -194,6 +194,16 @@ void Initialize( /// The cancellationToken enables graceful cancellation of asynchronous operations /// SurrealQL script as Task Export(ExportOptions? options, CancellationToken cancellationToken); + + /// + /// This method imports data into a SurrealDB database. + /// + /// + /// This method is only supported by SurrealDB v2.0.0 or higher. + /// + /// + /// The cancellationToken enables graceful cancellation of asynchronous operations + Task Import(string input, CancellationToken cancellationToken = default); } public interface ISurrealDbInMemoryEngine : ISurrealDbProviderEngine { } diff --git a/SurrealDb.Net/SurrealDbClient.Base.cs b/SurrealDb.Net/SurrealDbClient.Base.cs index 894ef76c..a1e375e3 100644 --- a/SurrealDb.Net/SurrealDbClient.Base.cs +++ b/SurrealDb.Net/SurrealDbClient.Base.cs @@ -1,8 +1,12 @@ using Dahomey.Cbor; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Semver; +using SurrealDb.Net.Exceptions; using SurrealDb.Net.Extensions.DependencyInjection; using SurrealDb.Net.Internals; +using SurrealDb.Net.Internals.Auth; +using SurrealDb.Net.Internals.Http; namespace SurrealDb.Net; @@ -26,4 +30,57 @@ protected void InitializeProviderEngine( loggerFactory is not null ? new SurrealDbLoggerFactory(loggerFactory) : null ); } + + private async Task CreateCommonHttpWrapperAsync( + CancellationToken cancellationToken + ) + { + SemVersion? version; + string? ns; + string? db; + IAuth? auth; + Action? configureCborOptions; + + switch (_engine) + { + case SurrealDbHttpEngine httpEngine: + // 💡 Ensures underlying engine is started to retrieve some information + await httpEngine.Connect(cancellationToken).ConfigureAwait(false); + + version = httpEngine._version; + ns = httpEngine._config.Ns; + db = httpEngine._config.Db; + auth = httpEngine._config.Auth; + configureCborOptions = httpEngine._configureCborOptions; + break; + case SurrealDbWsEngine wsEngine: + // 💡 Ensures underlying engine is started to retrieve some information + await wsEngine.InternalConnectAsync(true, cancellationToken).ConfigureAwait(false); + + version = wsEngine._version; + ns = wsEngine._config.Ns; + db = wsEngine._config.Db; + auth = wsEngine._config.Auth; + configureCborOptions = wsEngine._configureCborOptions; + break; + default: + throw new SurrealDbException("No underlying engine is started."); + } + + if (string.IsNullOrWhiteSpace(ns)) + { + throw new SurrealDbException("Namespace should be provided to export data."); + } + if (string.IsNullOrWhiteSpace(db)) + { + throw new SurrealDbException("Database should be provided to export data."); + } + + var httpClient = new HttpClient(); + + SurrealDbHttpEngine.SetNsDbHttpClientHeaders(httpClient, version, ns, db); + SurrealDbHttpEngine.SetAuthHttpClientHeaders(httpClient, auth); + + return new CommonHttpWrapper(httpClient, version, configureCborOptions); + } } diff --git a/SurrealDb.Net/SurrealDbClient.Interface.cs b/SurrealDb.Net/SurrealDbClient.Interface.cs index 43b26002..ec2b48a3 100644 --- a/SurrealDb.Net/SurrealDbClient.Interface.cs +++ b/SurrealDb.Net/SurrealDbClient.Interface.cs @@ -166,6 +166,16 @@ Task Export( /// Task Health(CancellationToken cancellationToken = default); + /// + /// This method imports data into a SurrealDB database. + /// + /// + /// This method is only supported by SurrealDB v2.0.0 or higher. + /// + /// + /// The cancellationToken enables graceful cancellation of asynchronous operations + Task Import(string input, CancellationToken cancellationToken = default); + /// /// Retrieves information about the authenticated scope user. /// diff --git a/SurrealDb.Net/SurrealDbClient.Methods.cs b/SurrealDb.Net/SurrealDbClient.Methods.cs index 3a859387..f5a164d9 100644 --- a/SurrealDb.Net/SurrealDbClient.Methods.cs +++ b/SurrealDb.Net/SurrealDbClient.Methods.cs @@ -1,9 +1,6 @@ using System.Collections.Immutable; -using Dahomey.Cbor; -using Semver; -using SurrealDb.Net.Exceptions; +using System.Text; using SurrealDb.Net.Internals; -using SurrealDb.Net.Internals.Auth; using SurrealDb.Net.Models; using SurrealDb.Net.Models.Auth; using SurrealDb.Net.Models.LiveQuery; @@ -99,63 +96,20 @@ public async Task Export( _ => throw new NotImplementedException(), }; - SemVersion? version; - string? ns; - string? db; - IAuth? auth; - Action? configureCborOptions; - - switch (_engine) - { - case SurrealDbHttpEngine httpEngine: - // 💡 Ensures underlying engine is started to retrieve some information - await httpEngine.Connect(cancellationToken).ConfigureAwait(false); - - version = httpEngine._version; - ns = httpEngine._config.Ns; - db = httpEngine._config.Db; - auth = httpEngine._config.Auth; - configureCborOptions = httpEngine._configureCborOptions; - break; - case SurrealDbWsEngine wsEngine: - // 💡 Ensures underlying engine is started to retrieve some information - await wsEngine.InternalConnectAsync(true, cancellationToken).ConfigureAwait(false); - - version = wsEngine._version; - ns = wsEngine._config.Ns; - db = wsEngine._config.Db; - auth = wsEngine._config.Auth; - configureCborOptions = wsEngine._configureCborOptions; - break; - default: - throw new SurrealDbException("No underlying engine is started."); - } - - if (string.IsNullOrWhiteSpace(ns)) - { - throw new SurrealDbException("Namespace should be provided to export data."); - } - if (string.IsNullOrWhiteSpace(db)) - { - throw new SurrealDbException("Database should be provided to export data."); - } + using var wrapper = await CreateCommonHttpWrapperAsync(cancellationToken) + .ConfigureAwait(false); using var httpContent = SurrealDbHttpEngine.CreateBodyContent( null, - configureCborOptions, + wrapper.ConfigureCborOptions, options ?? new() ); - using var httpClient = new HttpClient(); - - SurrealDbHttpEngine.SetNsDbHttpClientHeaders(httpClient, version, ns, db); - SurrealDbHttpEngine.SetAuthHttpClientHeaders(httpClient, auth); - - bool shouldUsePostRequest = version is { Major: >= 2, Minor: >= 1 }; + bool shouldUsePostRequest = wrapper.Version is { Major: >= 2, Minor: >= 1 }; var httpRequestTask = shouldUsePostRequest - ? httpClient.PostAsync(exportUri, httpContent, cancellationToken) - : httpClient.GetAsync(exportUri, cancellationToken); + ? wrapper.HttpClient.PostAsync(exportUri, httpContent, cancellationToken) + : wrapper.HttpClient.GetAsync(exportUri, cancellationToken); using var response = await httpRequestTask.ConfigureAwait(false); response.EnsureSuccessStatusCode(); @@ -172,6 +126,35 @@ public Task Health(CancellationToken cancellationToken = default) return _engine.Health(cancellationToken); } + public async Task Import(string input, CancellationToken cancellationToken = default) + { + if (_engine is ISurrealDbProviderEngine providerEngine) + { + await providerEngine.Import(input, cancellationToken).ConfigureAwait(false); + return; + } + + const string path = "/import"; + + var importUri = Uri.Scheme switch + { + "http" or "https" => new Uri(Uri, path), + "ws" => new Uri(new Uri(Uri.AbsoluteUri.Replace("ws://", "http://")), path), + "wss" => new Uri(new Uri(Uri.AbsoluteUri.Replace("wss://", "https://")), path), + _ => throw new NotImplementedException(), + }; + + using var wrapper = await CreateCommonHttpWrapperAsync(cancellationToken) + .ConfigureAwait(false); + + using var httpContent = new StringContent(input, Encoding.UTF8, "plain/text"); + + using var response = await wrapper + .HttpClient.PostAsync(importUri, httpContent, cancellationToken) + .ConfigureAwait(false); + response.EnsureSuccessStatusCode(); + } + public Task Info(CancellationToken cancellationToken = default) { return _engine.Info(cancellationToken); diff --git a/rust-embedded/shared/src/app/mod.rs b/rust-embedded/shared/src/app/mod.rs index 6575a907..2687f6de 100644 --- a/rust-embedded/shared/src/app/mod.rs +++ b/rust-embedded/shared/src/app/mod.rs @@ -33,6 +33,14 @@ impl SurrealEmbeddedEngines { engine.execute(method, params).await } + pub async fn import(&self, id: i32, input: String) -> Result<(), Error> { + let read_lock = self.0.read().await; + let Some(engine) = read_lock.get(&id) else { + return Err("Engine not found".into()); + }; + engine.import(input).await + } + pub async fn export(&self, id: i32, params: Vec) -> Result, Error> { let read_lock = self.0.read().await; let Some(engine) = read_lock.get(&id) else { @@ -66,7 +74,7 @@ impl SurrealEmbeddedEngine { pub async fn execute(&self, method: Method, params: Vec) -> Result, Error> { let params = crate::cbor::get_params(params) .map_err(|_| "Failed to deserialize params".to_string())?; - let rpc = self.0.write().await; + let rpc = self.0.read().await; let res = RpcContext::execute(&*rpc, None, method, params) .await .map_err(|e| e.to_string())?; @@ -114,6 +122,13 @@ impl SurrealEmbeddedEngine { let out = cbor::res(result).map_err(|e| e.to_string())?; Ok(out) } + + pub async fn import(&self, input: String) -> Result<(), Error> { + let inner = self.0.read().await; + inner.kvs.import(&input, &inner.session()).await?; + + Ok(()) + } } struct SurrealEmbeddedEngineInner { diff --git a/rust-embedded/shared/src/lib.rs b/rust-embedded/shared/src/lib.rs index 8e8c1942..bd9438a9 100644 --- a/rust-embedded/shared/src/lib.rs +++ b/rust-embedded/shared/src/lib.rs @@ -68,6 +68,31 @@ pub unsafe extern "C" fn execute( }); } +/// # Safety +/// +/// Executes the "import" method of a SurrealDB engine (given its id). +#[no_mangle] +pub unsafe extern "C" fn import( + id: i32, + utf16_str: *const u16, + utf16_len: i32, + success: SuccessAction, + failure: FailureAction, +) { + let input = convert_csharp_to_rust_string_utf16(utf16_str, utf16_len); + + get_global_runtime().spawn(async move { + match ENGINES.import(id, input).await { + Ok(_) => { + send_success(vec![], success); + } + Err(error) => { + send_failure(error.as_str(), failure); + } + } + }); +} + /// # Safety /// /// Executes the "export" method of a SurrealDB engine (given its id).