diff --git a/.github/workflows/csharp.yml b/.github/workflows/csharp.yml index e7916a82e8..d6e63eef9d 100644 --- a/.github/workflows/csharp.yml +++ b/.github/workflows/csharp.yml @@ -2,7 +2,7 @@ name: C# tests on: push: - branches: ["main"] +# branches: ["main"] paths: - csharp/** - glide-core/src/** diff --git a/benchmarks/csharp/Program.cs b/benchmarks/csharp/Program.cs index 11df0e36be..fd7db07424 100644 --- a/benchmarks/csharp/Program.cs +++ b/benchmarks/csharp/Program.cs @@ -14,6 +14,8 @@ using StackExchange.Redis; +using static Glide.ConnectionConfiguration; + public static class MainClass { private enum ChosenAction { GET_NON_EXISTING, GET_EXISTING, SET }; @@ -292,7 +294,11 @@ private static async Task run_with_parameters(int total_commands, { var clients = await createClients(clientCount, () => { - var glide_client = new AsyncClient(host, PORT, useTLS); + var config = new StandaloneClientConfigurationBuilder() + .WithAddress(host, PORT) + .WithTlsMode(useTLS ? TlsMode.SecureTls : TlsMode.NoTls) + .Build(); + var glide_client = new AsyncClient(config); return Task.FromResult<(Func>, Func, Action)>( (async (key) => await glide_client.GetAsync(key), async (key, value) => await glide_client.SetAsync(key, value), diff --git a/csharp/lib/AsyncClient.cs b/csharp/lib/AsyncClient.cs index 83e3d4c39b..e4d063a029 100644 --- a/csharp/lib/AsyncClient.cs +++ b/csharp/lib/AsyncClient.cs @@ -4,35 +4,68 @@ using System.Runtime.InteropServices; +using static Glide.ConnectionConfiguration; +using static Glide.Errors; + namespace Glide; public class AsyncClient : IDisposable { #region public methods - public AsyncClient(string host, UInt32 port, bool useTLS) + public enum RequestType : uint + { + // copied from redis_request.proto + CustomCommand = 1, + GetString = 2, + SetString = 3, + Ping = 4, + Info = 5, + // to be continued ... + } + + public AsyncClient(StandaloneClientConfiguration config) { successCallbackDelegate = SuccessCallback; var successCallbackPointer = Marshal.GetFunctionPointerForDelegate(successCallbackDelegate); failureCallbackDelegate = FailureCallback; var failureCallbackPointer = Marshal.GetFunctionPointerForDelegate(failureCallbackDelegate); - clientPointer = CreateClientFfi(host, port, useTLS, successCallbackPointer, failureCallbackPointer); - if (clientPointer == IntPtr.Zero) + var configPtr = Marshal.AllocHGlobal(Marshal.SizeOf(typeof(ConnectionRequest))); + Marshal.StructureToPtr(config.ToRequest(), configPtr, false); + var responsePtr = CreateClientFfi(configPtr, successCallbackPointer, failureCallbackPointer); + Marshal.FreeHGlobal(configPtr); + var response = (ConnectionResponse?)Marshal.PtrToStructure(responsePtr, typeof(ConnectionResponse)); + + if (response == null) { - throw new Exception("Failed creating a client"); + throw new DisconnectedException("Failed creating a client"); + } + clientPointer = response?.Client ?? IntPtr.Zero; + FreeConnectionResponse(responsePtr); + + if (clientPointer == IntPtr.Zero || !string.IsNullOrEmpty(response?.Error)) + { + throw new DisconnectedException(response?.Error ?? "Failed creating a client"); } } public async Task SetAsync(string key, string value) { var message = messageContainer.GetMessageForCall(key, value); - SetFfi(clientPointer, (ulong)message.Index, message.KeyPtr, message.ValuePtr); + Command(clientPointer, (ulong)message.Index, RequestType.SetString, (ulong)message.Args.Length, message.Args); await message; } + public async Task Custom(string[] args) + { + var message = messageContainer.GetMessageForCall(args); + Command(clientPointer, (ulong)message.Index, RequestType.CustomCommand, (ulong)args.Length, message.Args); + return await message; + } + public async Task GetAsync(string key) { - var message = messageContainer.GetMessageForCall(key, null); - GetFfi(clientPointer, (ulong)message.Index, message.KeyPtr); + var message = messageContainer.GetMessageForCall(key); + Command(clientPointer, (ulong)message.Index, RequestType.GetString, (ulong)message.Args.Length, message.Args); return await message; } @@ -62,14 +95,12 @@ private void SuccessCallback(ulong index, IntPtr str) }); } - private void FailureCallback(ulong index) + private void FailureCallback(ulong index, IntPtr error_msg_ptr, ErrorType error_type) { + var error = error_msg_ptr == IntPtr.Zero ? null : Marshal.PtrToStringAnsi(error_msg_ptr); // Work needs to be offloaded from the calling thread, because otherwise we might starve the client's thread pool. - Task.Run(() => - { - var message = messageContainer.GetMessage((int)index); - message.SetException(new Exception("Operation failed")); - }); + _ = Task.Run(() => messageContainer.GetMessage((int)index) + .SetException(Errors.MakeException(error_type, error))); } ~AsyncClient() => Dispose(); @@ -95,19 +126,53 @@ private void FailureCallback(ulong index) #region FFI function declarations private delegate void StringAction(ulong index, IntPtr str); - private delegate void FailureAction(ulong index); - [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "get")] - private static extern void GetFfi(IntPtr client, ulong index, IntPtr key); + /// + /// Glide request failure callback. + /// + /// Request ID + /// Error message + /// Error type + private delegate void FailureAction(ulong index, IntPtr error_msg_ptr, ErrorType errorType); - [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "set")] - private static extern void SetFfi(IntPtr client, ulong index, IntPtr key, IntPtr value); + [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "command")] + private static extern void Command(IntPtr client, ulong index, RequestType requestType, ulong argCount, IntPtr[] args); private delegate void IntAction(IntPtr arg); - [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "create_client")] - private static extern IntPtr CreateClientFfi(String host, UInt32 port, bool useTLS, IntPtr successCallback, IntPtr failureCallback); + [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "create_client_using_config")] + private static extern IntPtr CreateClientFfi(IntPtr config, IntPtr successCallback, IntPtr failureCallback); [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "close_client")] private static extern void CloseClientFfi(IntPtr client); + [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "free_connection_response")] + private static extern void FreeConnectionResponse(IntPtr connectionResponsePtr); + + internal enum ErrorType : uint + { + /// + /// Represented by for user + /// + Unspecified = 0, + /// + /// Represented by for user + /// + ExecAbort = 1, + /// + /// Represented by for user + /// + Timeout = 2, + /// + /// Represented by for user + /// + Disconnect = 3, + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + internal struct ConnectionResponse + { + public IntPtr Client; + public string Error; + public ErrorType ErrorType; + } #endregion } diff --git a/csharp/lib/ConnectionConfiguration.cs b/csharp/lib/ConnectionConfiguration.cs new file mode 100644 index 0000000000..101ca8119e --- /dev/null +++ b/csharp/lib/ConnectionConfiguration.cs @@ -0,0 +1,563 @@ +/** + * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 + */ + +using System.Runtime.InteropServices; + +namespace Glide; + +public abstract class ConnectionConfiguration +{ + #region Structs and Enums definitions + /// + /// A mirror of ConnectionRequest from connection_request.proto. + /// + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + internal struct ConnectionRequest + { + public nuint AddressCount; + public IntPtr Addresses; // ** NodeAddress - array pointer + public TlsMode TlsMode; + public bool ClusterMode; + public uint RequestTimeout; + public ReadFrom ReadFrom; + public RetryStrategy ConnectionRetryStrategy; + public AuthenticationInfo AuthenticationInfo; + public uint DatabaseId; + public Protocol Protocol; + [MarshalAs(UnmanagedType.LPStr)] + public string? ClientName; + } + + /// + /// Represents the address and port of a node in the cluster. + /// + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + internal struct NodeAddress + { + [MarshalAs(UnmanagedType.LPStr)] + public string Host; + public ushort Port; + } + + /// + /// Represents the strategy used to determine how and when to reconnect, in case of connection + /// failures. The time between attempts grows exponentially, to the formula rand(0 ... factor * + /// (exponentBase ^ N)), where N is the number of failed attempts. + /// + /// Once the maximum value is reached, that will remain the time between retry attempts until a + /// reconnect attempt is successful. The client will attempt to reconnect indefinitely. + /// + /// + [StructLayout(LayoutKind.Sequential)] + public struct RetryStrategy + { + /// + /// Number of retry attempts that the client should perform when disconnected from the server, + /// where the time between retries increases. Once the retries have reached the maximum value, the + /// time between retries will remain constant until a reconnect attempt is successful. + /// + public uint NumberOfRetries; + /// + /// The multiplier that will be applied to the waiting time between each retry. + /// + public uint Factor; + /// + /// The exponent base configured for the strategy. + /// + public uint ExponentBase; + + public RetryStrategy(uint number_of_retries, uint factor, uint exponent_base) + { + NumberOfRetries = number_of_retries; + Factor = factor; + ExponentBase = exponent_base; + } + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + internal struct AuthenticationInfo + { + [MarshalAs(UnmanagedType.LPStr)] + public string? Username; + [MarshalAs(UnmanagedType.LPStr)] + public string Password; + + public AuthenticationInfo(string? username, string password) + { + Username = username; + Password = password; + } + } + + // TODO doc + public enum TlsMode : uint + { + NoTls = 0, + SecureTls = 1, + //InsecureTls = 2, + } + + /// + /// Represents the client's read from strategy. + /// + public enum ReadFrom : uint + { + /// + /// Always get from primary, in order to get the freshest data. + /// + Primary = 0, + /// + /// Spread the requests between all replicas in a round-robin manner. If no replica is available, route the requests to the primary. + /// + PreferReplica = 1, + // TODO: doc or comment out/remove + //LowestLatency = 2, + //AZAffinity = 3, + } + + /// + /// Represents the communication protocol with the server. + /// + public enum Protocol : uint + { + /// + /// Use RESP3 to communicate with the server nodes. + /// + RESP3 = 0, + /// + /// Use RESP2 to communicate with the server nodes. + /// + RESP2 = 1, + } + #endregion + + private static readonly string DEFAULT_HOST = "localhost"; + private static readonly ushort DEFAULT_PORT = 6379; + + /// + /// Basic class which holds common configuration for all types of clients.
+ /// Refer to derived classes for more details: and . + ///
+ public abstract class BaseClientConfiguration + { + internal ConnectionRequest Request; + + internal ConnectionRequest ToRequest() => Request; + } + + /// + /// Configuration for a standalone client. Use to create an instance. + /// + public sealed class StandaloneClientConfiguration : BaseClientConfiguration + { + internal StandaloneClientConfiguration() { } + } + + /// + /// Configuration for a cluster client. Use to create an instance. + /// + public sealed class ClusterClientConfiguration : BaseClientConfiguration + { + internal ClusterClientConfiguration() { } + } + + /// + /// Builder for configuration of common parameters for standalone and cluster client. + /// + /// Derived builder class + public abstract class ClientConfigurationBuilder : IDisposable + where T : ClientConfigurationBuilder, new() + { + internal ConnectionRequest Config; + + protected ClientConfigurationBuilder(bool cluster_mode) + { + Config = new ConnectionRequest { ClusterMode = cluster_mode }; + } + + #region address + private readonly List addresses = new(); + + /// + /// Add a new address to the list.
+ /// See also . + // + // + + protected (string? host, ushort? port) Address + { + set + { + addresses.Add(new NodeAddress + { + Host = value.host ?? DEFAULT_HOST, + Port = value.port ?? DEFAULT_PORT + }); + } + } + + /// + public T WithAddress((string? host, ushort? port) address) + { + Address = (address.host, address.port); + return (T)this; + } + + /// + public T WithAddress((string host, ushort port) address) + { + Address = (address.host, address.port); + return (T)this; + } + + /// + public T WithAddress(string? host, ushort? port) + { + Address = (host, port); + return (T)this; + } + + /// + public T WithAddress(string host, ushort port) + { + Address = (host, port); + return (T)this; + } + + /// + /// Add a new address to the list with default port. + /// + public T WithAddress(string host) + { + Address = (host, DEFAULT_PORT); + return (T)this; + } + + /// + /// Add a new address to the list with default host. + /// + public T WithAddress(ushort port) + { + Address = (DEFAULT_HOST, port); + return (T)this; + } + + /// + /// Syntax sugar helper class for adding addresses. + /// + public sealed class AddressBuilder + { + private readonly ClientConfigurationBuilder owner; + + internal AddressBuilder(ClientConfigurationBuilder owner) + { + this.owner = owner; + } + + /// + public static AddressBuilder operator +(AddressBuilder builder, (string? host, ushort? port) address) + { + builder.owner.WithAddress(address); + return builder; + } + + /// + public static AddressBuilder operator +(AddressBuilder builder, (string host, ushort port) address) + { + builder.owner.WithAddress(address); + return builder; + } + + /// + public static AddressBuilder operator +(AddressBuilder builder, string host) + { + builder.owner.WithAddress(host); + return builder; + } + + /// + public static AddressBuilder operator +(AddressBuilder builder, ushort port) + { + builder.owner.WithAddress(port); + return builder; + } + } + + /// + /// DNS Addresses and ports of known nodes in the cluster. If the server is in cluster mode the + /// list can be partial, as the client will attempt to map out the cluster and find all nodes. If + /// the server is in standalone mode, only nodes whose addresses were provided will be used by the + /// client. + /// + /// For example: + /// [ + /// ("sample-address-0001.use1.cache.amazonaws.com", 6378), + /// ("sample-address-0002.use2.cache.amazonaws.com"), + /// ("sample-address-0002.use3.cache.amazonaws.com", 6380) + /// ] + /// + public AddressBuilder Addresses + { + get + { + return new AddressBuilder(this); + } + set { } // needed for += + } + // TODO possible options : list and array + #endregion + #region TLS + /// + /// Configure whether communication with the server should use Transport Level Security.
+ /// Should match the TLS configuration of the server/cluster, otherwise the connection attempt will fail. + ///
+ public TlsMode TlsMode + { + set + { + Config.TlsMode = value; + } + } + /// + public T WithTlsMode(TlsMode tls_mode) + { + TlsMode = tls_mode; + return (T)this; + } + /// + public T With(TlsMode tls_mode) + { + return WithTlsMode(tls_mode); + } + #endregion + #region Request Timeout + /// + /// The duration in milliseconds that the client should wait for a request to complete. This + /// duration encompasses sending the request, awaiting for a response from the server, and any + /// required reconnections or retries. If the specified timeout is exceeded for a pending request, + /// it will result in a timeout error. If not set, a default value will be used. + /// + public uint RequestTimeout + { + set + { + Config.RequestTimeout = value; + } + } + /// + public T WithRequestTimeout(uint request_timeout) + { + RequestTimeout = request_timeout; + return (T)this; + } + #endregion + #region Read From + /// + /// Configure the client's read from strategy. If not set, will be used. + /// + public ReadFrom ReadFrom + { + set + { + Config.ReadFrom = value; + } + } + /// + public T WithReadFrom(ReadFrom read_from) + { + ReadFrom = read_from; + return (T)this; + } + /// + public T With(ReadFrom read_from) + { + return WithReadFrom(read_from); + } + #endregion + #region Authentication + /// + /// Configure credentials for authentication process. If none are set, the client will not authenticate itself with the server. + /// + /// + /// username The username that will be used for authenticating connections to the Redis servers. If not supplied, "default" will be used.
+ /// password The password that will be used for authenticating connections to the Redis servers. + ///
+ public (string? username, string password) Authentication + { + set + { + Config.AuthenticationInfo = new AuthenticationInfo + ( + value.username, + value.password + ); + } + } + /// + /// Configure credentials for authentication process. If none are set, the client will not authenticate itself with the server. + /// + /// The username that will be used for authenticating connections to the Redis servers. If not supplied, "default" will be used.> + /// The password that will be used for authenticating connections to the Redis servers. + public T WithAuthentication(string? username, string password) + { + Authentication = (username, password); + return (T)this; + } + /// + public T WithAuthentication((string? username, string password) credentials) + { + return WithAuthentication(credentials.username, credentials.password); + } + #endregion + #region Protocol + /// + /// Configure the protocol version to use. If not set, will be used.
+ /// See also . + ///
+ public Protocol ProtocolVersion + { + set + { + Config.Protocol = value; + } + } + + /// + public T WithProtocolVersion(Protocol protocol) + { + ProtocolVersion = protocol; + return (T)this; + } + + /// + public T With(Protocol protocol) + { + ProtocolVersion = protocol; + return (T)this; + } + #endregion + #region Client Name + /// + /// Client name to be used for the client. Will be used with CLIENT SETNAME command during connection establishment. + /// + public string? ClientName + { + set + { + Config.ClientName = value; + } + } + + /// + public T WithClientName(string? clientName) + { + ClientName = clientName; + return (T)this; + } + #endregion + + public void Dispose() => Clean(); + + private void Clean() + { + if (Config.Addresses != IntPtr.Zero) + { + Marshal.FreeHGlobal(Config.Addresses); + Config.Addresses = IntPtr.Zero; + } + } + + internal ConnectionRequest Build() + { + Clean(); // memory leak protection on rebuilding a config from the builder + Config.AddressCount = (uint)addresses.Count; + var address_size = Marshal.SizeOf(typeof(NodeAddress)); + Config.Addresses = Marshal.AllocHGlobal(address_size * addresses.Count); + for (int i = 0; i < addresses.Count; i++) + { + Marshal.StructureToPtr(addresses[i], Config.Addresses + i * address_size, false); + } + return Config; + } + } + + /// + /// Represents the configuration settings for a Standalone Redis client. + /// + public class StandaloneClientConfigurationBuilder : ClientConfigurationBuilder + { + public StandaloneClientConfigurationBuilder() : base(false) { } + + /// + /// Complete the configuration with given settings. + /// + public new StandaloneClientConfiguration Build() => new() { Request = base.Build() }; + + #region DataBase ID + // TODO: not used + /// + /// Index of the logical database to connect to. + /// + public uint DataBaseId + { + set + { + Config.DatabaseId = value; + } + } + /// + public StandaloneClientConfigurationBuilder WithDataBaseId(uint dataBaseId) + { + DataBaseId = dataBaseId; + return this; + } + #endregion + #region Connection Retry Strategy + /// + /// Strategy used to determine how and when to reconnect, in case of connection failures.
+ /// See also + ///
+ public RetryStrategy ConnectionRetryStrategy + { + set + { + Config.ConnectionRetryStrategy = value; + } + } + /// + public StandaloneClientConfigurationBuilder WithConnectionRetryStrategy(RetryStrategy connection_retry_strategy) + { + ConnectionRetryStrategy = connection_retry_strategy; + return this; + } + /// + public StandaloneClientConfigurationBuilder With(RetryStrategy connection_retry_strategy) + { + return WithConnectionRetryStrategy(connection_retry_strategy); + } + /// + /// + /// + /// + public StandaloneClientConfigurationBuilder WithConnectionRetryStrategy(uint number_of_retries, uint factor, uint exponent_base) + { + return WithConnectionRetryStrategy(new RetryStrategy(number_of_retries, factor, exponent_base)); + } + #endregion + } + + /// + /// Represents the configuration settings for a Cluster Redis client.
+ /// Notes: Currently, the reconnection strategy in cluster mode is not configurable, and exponential backoff with fixed values is used. + ///
+ public class ClusterClientConfigurationBuilder : ClientConfigurationBuilder + { + public ClusterClientConfigurationBuilder() : base(true) { } + + /// + /// Complete the configuration with given settings. + /// + public new ClusterClientConfiguration Build() => new() { Request = base.Build() }; + } +} diff --git a/csharp/lib/Errors.cs b/csharp/lib/Errors.cs new file mode 100644 index 0000000000..8b781cee29 --- /dev/null +++ b/csharp/lib/Errors.cs @@ -0,0 +1,43 @@ +/** +* Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 +*/ + +using static Glide.AsyncClient; + +namespace Glide; + +public abstract class Errors +{ + public abstract class RedisError : Exception + { + internal RedisError(string? message) : base(message) { } + } + + public sealed class UnspecifiedException : RedisError + { + internal UnspecifiedException(string? message) : base(message) { } + } + + public sealed class ExecutionAbortedException : RedisError + { + internal ExecutionAbortedException(string? message) : base(message) { } + } + + public sealed class DisconnectedException : RedisError + { + internal DisconnectedException(string? message) : base(message) { } + } + + public sealed class TimeoutException : RedisError + { + internal TimeoutException(string? message) : base(message) { } + } + + internal static RedisError MakeException(ErrorType type, string? message) => type switch + { + ErrorType.ExecAbort => new ExecutionAbortedException(message), + ErrorType.Disconnect => new DisconnectedException(message), + ErrorType.Timeout => new TimeoutException(message), + _ => new UnspecifiedException(message), + }; +} diff --git a/csharp/lib/Logger.cs b/csharp/lib/Logger.cs index 7edc16f16c..363447446d 100644 --- a/csharp/lib/Logger.cs +++ b/csharp/lib/Logger.cs @@ -56,7 +56,7 @@ internal static void Log(Level logLevel, string logIdentifier, string message) SetLoggerConfig(logLevel); } if (!(logLevel <= Logger.loggerLevel)) return; - log(Convert.ToInt32(logLevel), Encoding.UTF8.GetBytes(logIdentifier), Encoding.UTF8.GetBytes(message)); + log(logLevel, logIdentifier, message); } #endregion internal methods @@ -69,17 +69,16 @@ internal static void Log(Level logLevel, string logIdentifier, string message) // the filename argument is optional - if provided the target of the logs will be the file mentioned, else will be the console public static void SetLoggerConfig(Level? level, string? filename = null) { - var buffer = filename is null ? null : Encoding.UTF8.GetBytes(filename); - Logger.loggerLevel = InitInternalLogger(Convert.ToInt32(level), buffer); + Logger.loggerLevel = InitInternalLogger(level ?? Level.Error, filename); } #endregion public methods #region FFI function declaration [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "log")] - private static extern void log(Int32 logLevel, byte[] logIdentifier, byte[] message); + private static extern void log(Level logLevel, string logIdentifier, string message); - [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "init")] - private static extern Level InitInternalLogger(Int32 level, byte[]? filename); + [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "init_logging")] + private static extern Level InitInternalLogger(Level level, string? filename); #endregion } diff --git a/csharp/lib/Message.cs b/csharp/lib/Message.cs index c0d4c7f07b..b2ea68757c 100644 --- a/csharp/lib/Message.cs +++ b/csharp/lib/Message.cs @@ -8,6 +8,8 @@ using Glide; +using static Glide.Errors; + /// Reusable source of ValueTask. This object can be allocated once and then reused /// to create multiple asynchronous operations, as long as each call to CreateTask /// is awaited to completion before the next call begins. @@ -17,11 +19,9 @@ internal class Message : INotifyCompletion /// know how to find the message and set its result. public int Index { get; } - /// The pointer to the unmanaged memory that contains the operation's key. - public IntPtr KeyPtr { get; private set; } + /// The pointers to the unmanaged memory that contains the command arguments + public IntPtr[] Args { get; private set; } - /// The pointer to the unmanaged memory that contains the operation's key. - public IntPtr ValuePtr { get; private set; } private readonly MessageContainer container; public Message(int index, MessageContainer container) @@ -29,6 +29,7 @@ public Message(int index, MessageContainer container) Index = index; continuation = () => { }; this.container = container; + Args = new IntPtr[0]; } private Action? continuation; @@ -37,7 +38,7 @@ public Message(int index, MessageContainer container) const int COMPLETION_STAGE_CONTINUATION_EXECUTED = 2; private int completionState; private T? result; - private Exception? exception; + private RedisError? exception; /// Triggers a succesful completion of the task returned from the latest call /// to CreateTask. @@ -49,7 +50,7 @@ public void SetResult(T? result) /// Triggers a failure completion of the task returned from the latest call to /// CreateTask. - public void SetException(Exception exc) + public void SetException(RedisError exc) { this.exception = exc; FinishSet(); @@ -84,31 +85,22 @@ private void CheckRaceAndCallContinuation() /// This returns a task that will complete once SetException / SetResult are called, /// and ensures that the internal state of the message is set-up before the task is created, /// and cleaned once it is complete. - public void StartTask(string? key, string? value, object client) + public void StartTask(string?[] args, object client) { continuation = null; this.completionState = COMPLETION_STAGE_STARTED; this.result = default(T); this.exception = null; this.client = client; - this.KeyPtr = key is null ? IntPtr.Zero : Marshal.StringToHGlobalAnsi(key); - this.ValuePtr = value is null ? IntPtr.Zero : Marshal.StringToHGlobalAnsi(value); + this.Args = args.Select(arg => Marshal.StringToHGlobalAnsi(arg)).ToArray(); } // This function isn't thread-safe. Access to it should be from a single thread, and only once per operation. // For the sake of performance, this responsibility is on the caller, and the function doesn't contain any safety measures. private void FreePointers() { - if (KeyPtr != IntPtr.Zero) - { - Marshal.FreeHGlobal(KeyPtr); - KeyPtr = IntPtr.Zero; - } - if (ValuePtr != IntPtr.Zero) - { - Marshal.FreeHGlobal(ValuePtr); - ValuePtr = IntPtr.Zero; - } + foreach (var arg in Args.Where(arg => arg != IntPtr.Zero)) + Marshal.FreeHGlobal(arg); client = null; } diff --git a/csharp/lib/MessageContainer.cs b/csharp/lib/MessageContainer.cs index faa1b5a277..1c6632c07b 100644 --- a/csharp/lib/MessageContainer.cs +++ b/csharp/lib/MessageContainer.cs @@ -4,6 +4,8 @@ using System.Collections.Concurrent; +using static Glide.Errors; + namespace Glide; @@ -11,10 +13,10 @@ internal class MessageContainer { internal Message GetMessage(int index) => messages[index]; - internal Message GetMessageForCall(string? key, string? value) + internal Message GetMessageForCall(params string?[] args) { var message = GetFreeMessage(); - message.StartTask(key, value, this); + message.StartTask(args, this); return message; } @@ -42,7 +44,7 @@ internal void DisposeWithError(Exception? error) { try { - message.SetException(new TaskCanceledException("Client closed", error)); + message.SetException(new ExecutionAbortedException($"Client closed: {error}")); } catch (Exception) { } } diff --git a/csharp/lib/src/lib.rs b/csharp/lib/src/lib.rs index 495d959598..dccf0b1fa1 100644 --- a/csharp/lib/src/lib.rs +++ b/csharp/lib/src/lib.rs @@ -1,228 +1,7 @@ /** * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */ -use glide_core::connection_request; -use glide_core::{client::Client as GlideClient, connection_request::NodeAddress}; -use redis::{Cmd, FromRedisValue, RedisResult}; -use std::{ - ffi::{c_void, CStr, CString}, - os::raw::c_char, -}; -use tokio::runtime::Builder; -use tokio::runtime::Runtime; -pub enum Level { - Error = 0, - Warn = 1, - Info = 2, - Debug = 3, - Trace = 4, -} - -pub struct Client { - client: GlideClient, - success_callback: unsafe extern "C" fn(usize, *const c_char) -> (), - failure_callback: unsafe extern "C" fn(usize) -> (), // TODO - add specific error codes - runtime: Runtime, -} - -fn create_connection_request( - host: String, - port: u32, - use_tls: bool, -) -> connection_request::ConnectionRequest { - let mut address_info = NodeAddress::new(); - address_info.host = host.to_string().into(); - address_info.port = port; - let addresses_info = vec![address_info]; - let mut connection_request = connection_request::ConnectionRequest::new(); - connection_request.addresses = addresses_info; - connection_request.tls_mode = if use_tls { - connection_request::TlsMode::SecureTls - } else { - connection_request::TlsMode::NoTls - } - .into(); - - connection_request -} - -fn create_client_internal( - host: *const c_char, - port: u32, - use_tls: bool, - success_callback: unsafe extern "C" fn(usize, *const c_char) -> (), - failure_callback: unsafe extern "C" fn(usize) -> (), -) -> RedisResult { - let host_cstring = unsafe { CStr::from_ptr(host as *mut c_char) }; - let host_string = host_cstring.to_str()?.to_string(); - let request = create_connection_request(host_string, port, use_tls); - let runtime = Builder::new_multi_thread() - .enable_all() - .thread_name("GLIDE for Redis C# thread") - .build()?; - let _runtime_handle = runtime.enter(); - let client = runtime.block_on(GlideClient::new(request)).unwrap(); // TODO - handle errors. - Ok(Client { - client, - success_callback, - failure_callback, - runtime, - }) -} - -/// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns. All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool. -#[no_mangle] -pub extern "C" fn create_client( - host: *const c_char, - port: u32, - use_tls: bool, - success_callback: unsafe extern "C" fn(usize, *const c_char) -> (), - failure_callback: unsafe extern "C" fn(usize) -> (), -) -> *const c_void { - match create_client_internal(host, port, use_tls, success_callback, failure_callback) { - Err(_) => std::ptr::null(), // TODO - log errors - Ok(client) => Box::into_raw(Box::new(client)) as *const c_void, - } -} - -#[no_mangle] -pub extern "C" fn close_client(client_ptr: *const c_void) { - let client_ptr = unsafe { Box::from_raw(client_ptr as *mut Client) }; - let _runtime_handle = client_ptr.runtime.enter(); - drop(client_ptr); -} - -/// Expects that key and value will be kept valid until the callback is called. -#[no_mangle] -pub extern "C" fn set( - client_ptr: *const c_void, - callback_index: usize, - key: *const c_char, - value: *const c_char, -) { - let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) }; - // The safety of this needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed. - let ptr_address = client_ptr as usize; - - let key_cstring = unsafe { CStr::from_ptr(key as *mut c_char) }; - let value_cstring = unsafe { CStr::from_ptr(value as *mut c_char) }; - let mut client_clone = client.client.clone(); - client.runtime.spawn(async move { - let key_bytes = key_cstring.to_bytes(); - let value_bytes = value_cstring.to_bytes(); - let mut cmd = Cmd::new(); - cmd.arg("SET").arg(key_bytes).arg(value_bytes); - let result = client_clone.send_command(&cmd, None).await; - unsafe { - let client = Box::leak(Box::from_raw(ptr_address as *mut Client)); - match result { - Ok(_) => (client.success_callback)(callback_index, std::ptr::null()), // TODO - should return "OK" string. - Err(_) => (client.failure_callback)(callback_index), // TODO - report errors - }; - } - }); -} - -/// Expects that key will be kept valid until the callback is called. If the callback is called with a string pointer, the pointer must -/// be used synchronously, because the string will be dropped after the callback. -#[no_mangle] -pub extern "C" fn get(client_ptr: *const c_void, callback_index: usize, key: *const c_char) { - let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) }; - // The safety of this needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed. - let ptr_address = client_ptr as usize; - - let key_cstring = unsafe { CStr::from_ptr(key as *mut c_char) }; - let mut client_clone = client.client.clone(); - client.runtime.spawn(async move { - let key_bytes = key_cstring.to_bytes(); - let mut cmd = Cmd::new(); - cmd.arg("GET").arg(key_bytes); - let result = client_clone.send_command(&cmd, None).await; - let client = unsafe { Box::leak(Box::from_raw(ptr_address as *mut Client)) }; - let value = match result { - Ok(value) => value, - Err(_) => { - unsafe { (client.failure_callback)(callback_index) }; // TODO - report errors, - return; - } - }; - let result = Option::::from_owned_redis_value(value); - - unsafe { - match result { - Ok(None) => (client.success_callback)(callback_index, std::ptr::null()), - Ok(Some(c_str)) => (client.success_callback)(callback_index, c_str.as_ptr()), - Err(_) => (client.failure_callback)(callback_index), // TODO - report errors - }; - } - }); -} - -impl From for Level { - fn from(level: logger_core::Level) -> Self { - match level { - logger_core::Level::Error => Level::Error, - logger_core::Level::Warn => Level::Warn, - logger_core::Level::Info => Level::Info, - logger_core::Level::Debug => Level::Debug, - logger_core::Level::Trace => Level::Trace, - } - } -} - -impl From for logger_core::Level { - fn from(level: Level) -> logger_core::Level { - match level { - Level::Error => logger_core::Level::Error, - Level::Warn => logger_core::Level::Warn, - Level::Info => logger_core::Level::Info, - Level::Debug => logger_core::Level::Debug, - Level::Trace => logger_core::Level::Trace, - } - } -} - -#[no_mangle] -#[allow(improper_ctypes_definitions)] -/// # Safety -/// Unsafe function because creating string from pointer -pub unsafe extern "C" fn log( - log_level: Level, - log_identifier: *const c_char, - message: *const c_char, -) { - unsafe { - logger_core::log( - log_level.into(), - CStr::from_ptr(log_identifier) - .to_str() - .expect("Can not read log_identifier argument."), - CStr::from_ptr(message) - .to_str() - .expect("Can not read message argument."), - ); - } -} - -#[no_mangle] -#[allow(improper_ctypes_definitions)] -/// # Safety -/// Unsafe function because creating string from pointer -pub unsafe extern "C" fn init(level: Option, file_name: *const c_char) -> Level { - let file_name_as_str; - unsafe { - file_name_as_str = if file_name.is_null() { - None - } else { - Some( - CStr::from_ptr(file_name) - .to_str() - .expect("Can not read string argument."), - ) - }; - - let logger_level = logger_core::init(level.map(|level| level.into()), file_name_as_str); - logger_level.into() - } -} +// rustc blames empty file or file with a comment only +#[allow(unused_imports)] +use glide_core; diff --git a/csharp/tests/Integration/GetAndSet.cs b/csharp/tests/Integration/GetAndSet.cs index ed37512337..c4fb8ee9bf 100644 --- a/csharp/tests/Integration/GetAndSet.cs +++ b/csharp/tests/Integration/GetAndSet.cs @@ -6,6 +6,7 @@ namespace tests.Integration; using Glide; +using static Glide.ConnectionConfiguration; using static tests.Integration.IntegrationTestBase; public class GetAndSet @@ -19,10 +20,18 @@ private async Task GetAndSetRandomValues(AsyncClient client) Assert.That(result, Is.EqualTo(value)); } + private StandaloneClientConfiguration GetConfig() + { + return new StandaloneClientConfigurationBuilder() + .WithAddress((ushort)TestConfiguration.STANDALONE_PORTS[0]) + .WithTlsMode(TlsMode.NoTls) + .Build(); + } + [Test] public async Task GetReturnsLastSet() { - using (var client = new AsyncClient("localhost", TestConfiguration.STANDALONE_PORTS[0], false)) + using (var client = new AsyncClient(GetConfig())) { await GetAndSetRandomValues(client); } @@ -31,7 +40,7 @@ public async Task GetReturnsLastSet() [Test] public async Task GetAndSetCanHandleNonASCIIUnicode() { - using (var client = new AsyncClient("localhost", TestConfiguration.STANDALONE_PORTS[0], false)) + using (var client = new AsyncClient(GetConfig())) { var key = Guid.NewGuid().ToString(); var value = "שלום hello 汉字"; @@ -44,7 +53,7 @@ public async Task GetAndSetCanHandleNonASCIIUnicode() [Test] public async Task GetReturnsNull() { - using (var client = new AsyncClient("localhost", TestConfiguration.STANDALONE_PORTS[0], false)) + using (var client = new AsyncClient(GetConfig())) { var result = await client.GetAsync(Guid.NewGuid().ToString()); Assert.That(result, Is.EqualTo(null)); @@ -54,7 +63,7 @@ public async Task GetReturnsNull() [Test] public async Task GetReturnsEmptyString() { - using (var client = new AsyncClient("localhost", TestConfiguration.STANDALONE_PORTS[0], false)) + using (var client = new AsyncClient(GetConfig())) { var key = Guid.NewGuid().ToString(); var value = ""; @@ -67,7 +76,7 @@ public async Task GetReturnsEmptyString() [Test] public async Task HandleVeryLargeInput() { - using (var client = new AsyncClient("localhost", TestConfiguration.STANDALONE_PORTS[0], false)) + using (var client = new AsyncClient(GetConfig())) { var key = Guid.NewGuid().ToString(); var value = Guid.NewGuid().ToString(); @@ -87,7 +96,7 @@ public async Task HandleVeryLargeInput() [Test] public void ConcurrentOperationsWork() { - using (var client = new AsyncClient("localhost", TestConfiguration.STANDALONE_PORTS[0], false)) + using (var client = new AsyncClient(GetConfig())) { var operations = new List();