diff --git a/.editorconfig b/.editorconfig index f18b998a..45695edb 100644 --- a/.editorconfig +++ b/.editorconfig @@ -212,5 +212,8 @@ dotnet_diagnostic.CSIsNull001.severity = warning # CSIsNull002: Use `is object` for non-null checks dotnet_diagnostic.CSIsNull002.severity = warning +# NBMsgPack051: Prefer .NET APIs over netstandard ones. +dotnet_diagnostic.NBMsgPack051.severity = silent + [*.sln] indent_style = tab diff --git a/Directory.Packages.props b/Directory.Packages.props index 8fa150db..59d76879 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -21,11 +21,12 @@ + - + diff --git a/docfx/docs/extensibility.md b/docfx/docs/extensibility.md index d4e38284..64079fcc 100644 --- a/docfx/docs/extensibility.md +++ b/docfx/docs/extensibility.md @@ -76,13 +76,20 @@ Interop with other parties is most likely with a UTF-8 text encoding of JSON-RPC StreamJsonRpc includes the following implementations: -1. @StreamJsonRpc.JsonMessageFormatter - Uses Newtonsoft.Json to serialize each JSON-RPC message as actual JSON. +1. - Uses the [`Nerdbank.MessagePack` library][NBMsgPack] to serialize + each message using the very fast and compact binary [MessagePack format][MessagePackFormat]. + This serializer is NativeAOT ready. + Any RPC method parameters and return types that require custom serialization may provide it + with a `MessagePackConverter`-derived class. + All custom converters can be added to the serializer at `NerdbankMessagePackFormatter.UserDataSerializer`. + +1. - Uses Newtonsoft.Json to serialize each JSON-RPC message as actual JSON. The text encoding is configurable via a property. All RPC method parameters and return types must be serializable by Newtonsoft.Json. You can leverage `JsonConverter` and add your custom converters via attributes or by contributing them to the `JsonMessageFormatter.JsonSerializer.Converters` collection. -1. - Uses the [MessagePack-CSharp][MessagePackLibrary] library to serialize each +1. - Uses the [MessagePack-CSharp][MessagePackCSharp] library to serialize each JSON-RPC message using the very fast and compact binary [MessagePack format][MessagePackFormat]. All RPC method parameters and return types must be serializable by `IMessagePackFormatter`. You can contribute your own via `MessagePackFormatter.SetOptions(MessagePackSerializationOptions)`. @@ -100,17 +107,17 @@ Refer to the source code from our built-in formatters to see how to use these he ### Choosing your formatter -#### When to use +#### When to use -The very best performance comes from using the with the . +The very best performance comes from using the with the . This combination is the fastest and produces the most compact serialized format. The [MessagePack format][MessagePackFormat] is a fast, binary serialization format that resembles the structure of JSON. It can be used as a substitute for JSON when both parties agree on the protocol for significant wins in terms of performance and payload size. -Utilizing `MessagePack` for exchanging JSON-RPC messages is incredibly easy. -Check out the `BasicJsonRpc` method in our [MessagePackFormatterTests][MessagePackUsage] class. +The `MessagePackFormatter` is an older formatter that is not NativeAOT ready. +Using it is only advisable for purposes of maintaining serialized format compatibility, since the serialized schema between the two MessagePack formatters varies slightly. #### When to use @@ -128,7 +135,8 @@ It produces JSON text and allows configuring the text encoding, with UTF-8 being This formatter is compatible with remote systems that use when using the default UTF-8 encoding. The remote party must also use the same message handler, such as . -[MessagePackLibrary]: https://github.com/MessagePack-CSharp/MessagePack-CSharp +[NBMsgPack]: https://github.com/AArnott/Nerdbank.MessagePack +[MessagePackCSharp]: https://github.com/MessagePack-CSharp/MessagePack-CSharp [MessagePackUsage]: https://github.com/microsoft/vs-streamjsonrpc/blob/main/test/StreamJsonRpc.Tests/MessagePackFormatterTests.cs [MessagePackFormat]: https://msgpack.org/ [SystemTextJson]: https://learn.microsoft.com/dotnet/standard/serialization/system-text-json/overview diff --git a/docfx/exotic_types/asyncenumerable.md b/docfx/exotic_types/asyncenumerable.md index c71e37e9..22c22fcc 100644 --- a/docfx/exotic_types/asyncenumerable.md +++ b/docfx/exotic_types/asyncenumerable.md @@ -552,6 +552,17 @@ The generator MAY respond with an error if this is done. The generator should never return an empty array of values unless the last value in the sequence has already been returned to the client. +#### Compatibility note + +The `MessagePackFormatter` deviates from this spec by formatting the result object above as an array of values instead. +The example above would instead be formatted as: + +```json +[[4,5,6], false] +``` + +The `NerdbankMessagePackFormatter` does *not* share this spec bug, and thus cannot interoperate with a `MessagePackFormatter` across the wire. + ### Consumer disposes enumerator When the consumer aborts enumeration before the generator has sent `finished: true`, diff --git a/nuget.config b/nuget.config index 35d15c11..6a811cf6 100644 --- a/nuget.config +++ b/nuget.config @@ -5,10 +5,31 @@ + + + + + + + + + + + + + + + + + + + + + diff --git a/src/StreamJsonRpc/FormatterBase.cs b/src/StreamJsonRpc/FormatterBase.cs index 7dd4c647..aa7223e9 100644 --- a/src/StreamJsonRpc/FormatterBase.cs +++ b/src/StreamJsonRpc/FormatterBase.cs @@ -7,6 +7,7 @@ using System.Reflection; using System.Runtime.Serialization; using Nerdbank.Streams; +using PolyType; using StreamJsonRpc.Protocol; using StreamJsonRpc.Reflection; @@ -436,6 +437,7 @@ protected abstract class JsonRpcErrorBase : JsonRpcError, IJsonRpcMessageBufferM /// [Newtonsoft.Json.JsonIgnore] [IgnoreDataMember] + [PropertyShape(Ignore = true)] public TopLevelPropertyBagBase? TopLevelPropertyBag { get; set; } void IJsonRpcMessageBufferManager.DeserializationComplete(JsonRpcMessage message) @@ -480,6 +482,7 @@ protected abstract class JsonRpcResultBase : JsonRpcResult, IJsonRpcMessageBuffe /// [Newtonsoft.Json.JsonIgnore] [IgnoreDataMember] + [PropertyShape(Ignore = true)] public TopLevelPropertyBagBase? TopLevelPropertyBag { get; set; } void IJsonRpcMessageBufferManager.DeserializationComplete(JsonRpcMessage message) diff --git a/src/StreamJsonRpc/JsonMessageFormatter.cs b/src/StreamJsonRpc/JsonMessageFormatter.cs index 8f21c5d0..355dcc9f 100644 --- a/src/StreamJsonRpc/JsonMessageFormatter.cs +++ b/src/StreamJsonRpc/JsonMessageFormatter.cs @@ -1403,7 +1403,7 @@ internal ExceptionConverter(JsonMessageFormatter formatter) } } - return ExceptionSerializationHelpers.Deserialize(this.formatter.JsonRpc, info, this.formatter.JsonRpc?.TraceSource); + return ExceptionSerializationHelpers.Deserialize(this.formatter.JsonRpc, info, this.formatter.JsonRpc.LoadType, this.formatter.JsonRpc?.TraceSource); } finally { diff --git a/src/StreamJsonRpc/JsonRpc.cs b/src/StreamJsonRpc/JsonRpc.cs index 3047c7be..aaad1efa 100644 --- a/src/StreamJsonRpc/JsonRpc.cs +++ b/src/StreamJsonRpc/JsonRpc.cs @@ -8,6 +8,7 @@ using System.Globalization; using System.Reflection; using System.Runtime.CompilerServices; +using System.Runtime.Serialization; using Microsoft.VisualStudio.Threading; using Newtonsoft.Json; using StreamJsonRpc.Protocol; @@ -38,6 +39,11 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR /// private static readonly JsonRpcError DroppedError = new(); + private static readonly ImmutableDictionary DefaultRuntimeDeserializableTypes = ImmutableDictionary.Create() + .Add("System.Exception", new LoadableType(typeof(Exception))) + .Add("System.ArgumentException", new LoadableType(typeof(ArgumentException))) + .Add("System.InvalidOperationException", new LoadableType(typeof(InvalidOperationException))); + #if NET private static readonly MethodInfo ValueTaskAsTaskMethodInfo = typeof(ValueTask<>).GetMethod(nameof(ValueTask.AsTask))!; private static readonly MethodInfo ValueTaskGetResultMethodInfo = typeof(ValueTask<>).GetMethod("get_Result")!; @@ -99,6 +105,8 @@ public class JsonRpc : IDisposableObservable, IJsonRpcFormatterCallbacks, IJsonR /// private ImmutableList remoteRpcTargets = ImmutableList.Empty; + // TODO: make this a custom collection type so it can be shared across JsonRpc instances. + private ImmutableDictionary runtimeDeserializableTypes = DefaultRuntimeDeserializableTypes; private Task? readLinesTask; private long nextId = 1; private int requestsInDispatchCount; @@ -251,6 +259,10 @@ public JsonRpc(IJsonRpcMessageHandler messageHandler) // so that all incoming messages are queued to the threadpool, allowing immediate concurrency. this.SynchronizationContext = new NonConcurrentSynchronizationContext(sticky: false); this.CancellationStrategy = new StandardCancellationStrategy(this); + + this.AddLoadableType(typeof(Exception)); + this.AddLoadableType(typeof(InvalidOperationException)); + this.AddLoadableType(typeof(ArgumentException)); } /// @@ -435,6 +447,11 @@ public enum TraceEvents /// A base-type that does offer the constructor will be instantiated instead. /// ExceptionNotDeserializable, + + /// + /// An error occurred while deserializing a value within an interface. + /// + IFormatterConverterDeserializationFailure, } /// @@ -939,6 +956,24 @@ public void AddLocalRpcMethod(string? rpcMethodName, Delegate handler) return this.rpcTargetInfo.GetJsonRpcMethodAttribute(methodName, parameters); } + /// + /// Gets or sets the set of types that can be deserialized from their name at runtime. + /// + /// + /// This set of types is used by the default implementation of to determine + /// which types can be deserialized when their name is encountered in an RPC message. + /// + public void AddLoadableType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)] Type type) + { + Requires.NotNull(type); + Requires.Argument(type.FullName is not null, nameof(type), Resources.TypeMustHaveFullName); + this.ThrowIfConfigurationLocked(); + lock (this.syncObject) + { + this.runtimeDeserializableTypes = this.runtimeDeserializableTypes.SetItem(type.FullName, new LoadableType(type)); + } + } + /// /// Starts listening to incoming messages. /// @@ -1311,6 +1346,7 @@ internal void AddRpcInterfaceToTargetInternal([DynamicallyAccessedMembers(Dynami /// Implementations should avoid throwing , or other exceptions, preferring to return instead. /// [RequiresUnreferencedCode(RuntimeReasons.LoadType)] + [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)] protected internal virtual Type? LoadType(string typeFullName, string? assemblyName) { Requires.NotNull(typeFullName, nameof(typeFullName)); @@ -1345,6 +1381,25 @@ internal void AddRpcInterfaceToTargetInternal([DynamicallyAccessedMembers(Dynami return runtimeType; } + /// + /// When overridden by a derived type, this attempts to load a type based on its full name and possibly assembly name. + /// + /// The of the type to be loaded. + /// The assemble name that is expected to define the type, if available. This should be parseable by . + /// The loaded , if one could be found; otherwise . + /// + /// + /// This method is used to load types that are strongly referenced by incoming messages during serialization. + /// It is important to not load types that may pose a security threat based on the type and the trust level of the remote party. + /// + /// + /// The default implementation of this method matches types registered with . + /// + /// Implementations should avoid throwing , or other exceptions, preferring to return instead. + /// + [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)] + protected internal virtual Type? LoadTypeTrimSafe(string typeFullName, string? assemblyName) => this.runtimeDeserializableTypes.TryGetValue(typeFullName, out LoadableType type) ? type.Type : null; + /// /// Disposes managed and native resources held by this instance. /// @@ -1392,7 +1447,11 @@ protected virtual JsonRpcError.ErrorDetail CreateErrorDetails(JsonRpcRequest req bool iserializable = this.ExceptionStrategy == ExceptionProcessing.ISerializable; if (!ExceptionSerializationHelpers.IsSerializable(exception)) { - this.TraceSource.TraceEvent(TraceEventType.Warning, (int)TraceEvents.ExceptionNotSerializable, "An exception of type {0} was thrown but is not serializable.", exception.GetType().AssemblyQualifiedName); + if (localRpcEx is null) + { + this.TraceSource.TraceEvent(TraceEventType.Warning, (int)TraceEvents.ExceptionNotSerializable, "An exception of type {0} was thrown but is not serializable.", exception.GetType().AssemblyQualifiedName); + } + iserializable = false; } @@ -1737,7 +1796,7 @@ protected virtual async ValueTask DispatchRequestAsync(JsonRpcRe } /// - /// Sends the JSON-RPC message to intance to be transmitted. + /// Sends the JSON-RPC message to instance to be transmitted. /// /// The message to send. /// A token to cancel the send request. @@ -2825,6 +2884,18 @@ private IJsonRpcClientProxyInternal CreateProxy(Type contractInterface, ReadOnly options?.OnDispose)!; } + private struct LoadableType([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)] Type type) : IEquatable + { + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)] + public Type Type => type; + + public bool Equals(LoadableType other) => this.Type == other.Type; + + public override int GetHashCode() => type.GetHashCode(); + + public override bool Equals(object? obj) => obj is LoadableType other && this.Equals(other); + } + /// /// An object that correlates tokens within and between instances /// within a process that does not use , diff --git a/src/StreamJsonRpc/MessagePackFormatter.cs b/src/StreamJsonRpc/MessagePackFormatter.cs index d2aa723d..b34d073b 100644 --- a/src/StreamJsonRpc/MessagePackFormatter.cs +++ b/src/StreamJsonRpc/MessagePackFormatter.cs @@ -1445,7 +1445,7 @@ public T Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions var resolverWrapper = options.Resolver as ResolverWrapper; Report.If(resolverWrapper is null, "Unexpected resolver type."); - return ExceptionSerializationHelpers.Deserialize(this.formatter.JsonRpc, info, resolverWrapper?.Formatter.JsonRpc?.TraceSource); + return ExceptionSerializationHelpers.Deserialize(this.formatter.JsonRpc, info, this.formatter.JsonRpc.LoadType, resolverWrapper?.Formatter.JsonRpc?.TraceSource); } finally { diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.AsyncEnumerableConverters.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.AsyncEnumerableConverters.cs new file mode 100644 index 00000000..01e7c402 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.AsyncEnumerableConverters.cs @@ -0,0 +1,123 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +#pragma warning disable CA1812 // Avoid uninstantiated internal classes + +using System.ComponentModel; +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType; +using PolyType.Abstractions; +using StreamJsonRpc; +using StreamJsonRpc.Reflection; + +[assembly: TypeShapeExtension(typeof(IAsyncEnumerable<>), AssociatedTypes = [typeof(NerdbankMessagePackFormatter.AsyncEnumerableConverter<>)], Requirements = TypeShapeRequirements.Constructor)] + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// Converts between an enumeration token and . + /// + /// The type of element to be enumerated. + [EditorBrowsable(EditorBrowsableState.Never)] + public class AsyncEnumerableConverter : MessagePackConverter> + { + /// + /// The constant "token", in its various forms. + /// + private static readonly MessagePackString TokenPropertyName = new(MessageFormatterEnumerableTracker.TokenPropertyName); + + /// + /// The constant "values", in its various forms. + /// + private static readonly MessagePackString ValuesPropertyName = new(MessageFormatterEnumerableTracker.ValuesPropertyName); + + /// + public override IAsyncEnumerable? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return default; + } + + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); + + context.DepthStep(); + + RawMessagePack? token = default; + IReadOnlyList? initialElements = null; + int propertyCount = reader.ReadMapHeader(); + for (int i = 0; i < propertyCount; i++) + { + if (TokenPropertyName.TryRead(ref reader)) + { + // The value needs to outlive the reader, so we clone it. + token = reader.ReadRaw(context).ToOwned(); + } + else if (ValuesPropertyName.TryRead(ref reader)) + { + initialElements = context.GetConverter>(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + reader.Skip(context); // Skip the unrecognized key + reader.Skip(context); // and its value. + } + } + + return mainFormatter.EnumerableTracker.CreateEnumerableProxy(token.HasValue ? token.Value : null, initialElements); + } + + /// + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to helper method")] + public override void Write(ref MessagePackWriter writer, in IAsyncEnumerable? value, SerializationContext context) + { + context.DepthStep(); + + NerdbankMessagePackFormatter mainFormatter = context.GetFormatter(); + if (value is null) + { + writer.WriteNil(); + } + else + { + (IReadOnlyList elements, bool finished) = value.TearOffPrefetchedElements(); + long token = mainFormatter.EnumerableTracker.GetToken(value); + + int propertyCount = 0; + if (elements.Count > 0) + { + propertyCount++; + } + + if (!finished) + { + propertyCount++; + } + + writer.WriteMapHeader(propertyCount); + + if (!finished) + { + writer.Write(TokenPropertyName); + writer.Write(token); + } + + if (elements.Count > 0) + { + writer.Write(ValuesPropertyName); + context.GetConverter(mainFormatter.GetUserDataShape(typeof(IReadOnlyList))).WriteObject(ref writer, elements, context); + } + } + } + + /// + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs new file mode 100644 index 00000000..c7cd1db3 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.Constants.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; +using StreamJsonRpc.Protocol; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// The constant "jsonrpc". + /// + private static readonly MessagePackString VersionPropertyName = new(Constants.jsonrpc); + + /// + /// The constant "id". + /// + private static readonly MessagePackString IdPropertyName = new(Constants.id); + + /// + /// The constant "method". + /// + private static readonly MessagePackString MethodPropertyName = new(Constants.Request.method); + + /// + /// The constant "result". + /// + private static readonly MessagePackString ResultPropertyName = new(Constants.Result.result); + + /// + /// The constant "error". + /// + private static readonly MessagePackString ErrorPropertyName = new(Constants.Error.error); + + /// + /// The constant "params". + /// + private static readonly MessagePackString ParamsPropertyName = new(Constants.Request.@params); + + /// + /// The constant "traceparent". + /// + private static readonly MessagePackString TraceParentPropertyName = new(Constants.Request.traceparent); + + /// + /// The constant "tracestate". + /// + private static readonly MessagePackString TraceStatePropertyName = new(Constants.Request.tracestate); + + /// + /// The constant "2.0". + /// + private static readonly MessagePackString Version2 = new("2.0"); +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ExceptionConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ExceptionConverter.cs new file mode 100644 index 00000000..d6f820f6 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ExceptionConverter.cs @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Runtime.Serialization; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// Manages serialization of any -derived type that follows standard rules. + /// + /// + /// A serializable class will: + /// 1. Derive from + /// 2. Be attributed with + /// 3. Declare a constructor with a signature of (, ). + /// + private class ExceptionConverter : MessagePackConverter + { + public static readonly ExceptionConverter Instance = new(); + + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + Assumes.NotNull(formatter.JsonRpc); + + context.DepthStep(); + + if (reader.TryReadNil()) + { + return default; + } + + // We have to guard our own recursion because the serializer has no visibility into inner exceptions. + // Each exception in the russian doll is a new serialization job from its perspective. + formatter.exceptionRecursionCounter.Value++; + try + { + if (formatter.exceptionRecursionCounter.Value > formatter.JsonRpc.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and return null as if there were no inner exception. + // Note that in skipping, the parser may use recursion internally and may still throw if its own limits are exceeded. + reader.Skip(context); + return default; + } + + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter)); + int memberCount = reader.ReadMapHeader(); + for (int i = 0; i < memberCount; i++) + { + string? name = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context) + ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + + // SerializationInfo.GetValue(string, typeof(object)) does not call our formatter, + // so the caller will get a boxed RawMessagePack struct in that case. + // Although we can't do much about *that* in general, we can at least ensure that null values + // are represented as null instead of this boxed struct. + RawMessagePack? value = reader.TryReadNil() ? null : reader.ReadRaw(context); + + info.AddSafeValue(name, value); + } + + return ExceptionSerializationHelpers.Deserialize(formatter.JsonRpc, info, formatter.JsonRpc.LoadTypeTrimSafe, formatter.JsonRpc.TraceSource); + } + finally + { + formatter.exceptionRecursionCounter.Value--; + } + } + + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (value is null) + { + writer.WriteNil(); + return; + } + + formatter.exceptionRecursionCounter.Value++; + try + { + if (formatter.exceptionRecursionCounter.Value > formatter.JsonRpc?.ExceptionOptions.RecursionLimit) + { + // Exception recursion has gone too deep. Skip this value and write null as if there were no inner exception. + writer.WriteNil(); + return; + } + + var info = new SerializationInfo(typeof(T), new MessagePackFormatterConverter(formatter)); + ExceptionSerializationHelpers.Serialize((Exception)(object)value, info); + writer.WriteMapHeader(info.GetSafeMemberCount()); + foreach (SerializationEntry element in info.GetSafeMembers()) + { + writer.Write(element.Name); + if (element.Value is null) + { + writer.WriteNil(); + } + else + { + // We prefer the declared type but will fallback to the runtime type. + context.GetConverter(formatter.TypeShapeProvider.GetShape(NormalizeType(element.ObjectType)) ?? formatter.TypeShapeProvider.Resolve(NormalizeType(element.Value.GetType()))) + .WriteObject(ref writer, element.Value, context); + } + } + } + finally + { + formatter.exceptionRecursionCounter.Value--; + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.IJsonRpcMessagePackRetention.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.IJsonRpcMessagePackRetention.cs new file mode 100644 index 00000000..6e86c7a9 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.IJsonRpcMessagePackRetention.cs @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private interface IJsonRpcMessagePackRetention + { + /// + /// Gets the original msgpack sequence that was deserialized into this message. + /// + /// + /// The buffer is only retained for a short time. If it has already been cleared, the result of this property is an empty sequence. + /// + RawMessagePack OriginalMessagePack { get; } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs new file mode 100644 index 00000000..2b0cfc3e --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.MessagePackFormatterConverter.cs @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics; +using System.Runtime.Serialization; +using Nerdbank.MessagePack; +using PolyType; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private partial class MessagePackFormatterConverter(NerdbankMessagePackFormatter formatter) : IFormatterConverter + { +#pragma warning disable CS8766 // This method may in fact return null, and no one cares. + public object? Convert(object value, Type type) +#pragma warning restore CS8766 + { + // We don't support serializing/deserializing the non-generic IDictionary, + // since it uses untyped keys and values which we cannot securely hash. + if (type == typeof(System.Collections.IDictionary)) + { + // Force us to deserialize into a semi-typed dictionary. + // The string key is a reasonable 99% compatible assumption, and allows us to securely hash the keys. + // The untyped values will be alright because we support the primitives types. + type = typeof(Dictionary); + } + + MessagePackReader reader = this.CreateReader(value); + try + { + return formatter.UserDataSerializer.DeserializeObject(ref reader, formatter.GetUserDataShape(type)); + } + catch (Exception ex) + { + formatter.JsonRpc?.TraceSource.TraceData(TraceEventType.Error, (int)JsonRpc.TraceEvents.ExceptionNotDeserializable, ex); + throw; + } + } + + public object Convert(object value, TypeCode typeCode) => typeCode switch + { + TypeCode.Object => new object(), + _ => ExceptionSerializationHelpers.Convert(this, value, typeCode), + }; + + public bool ToBoolean(object value) => this.CreateReader(value).ReadBoolean(); + + public byte ToByte(object value) => this.CreateReader(value).ReadByte(); + + public char ToChar(object value) => this.CreateReader(value).ReadChar(); + + public DateTime ToDateTime(object value) => this.CreateReader(value).ReadDateTime(); + + public decimal ToDecimal(object value) => formatter.UserDataSerializer.Deserialize((RawMessagePack)value, Witness.ShapeProvider); + + public double ToDouble(object value) => this.CreateReader(value).ReadDouble(); + + public short ToInt16(object value) => this.CreateReader(value).ReadInt16(); + + public int ToInt32(object value) => this.CreateReader(value).ReadInt32(); + + public long ToInt64(object value) => this.CreateReader(value).ReadInt64(); + + public sbyte ToSByte(object value) => this.CreateReader(value).ReadSByte(); + + public float ToSingle(object value) => this.CreateReader(value).ReadSingle(); + + public string? ToString(object value) => value is null ? null : this.CreateReader(value).ReadString(); + + public ushort ToUInt16(object value) => this.CreateReader(value).ReadUInt16(); + + public uint ToUInt32(object value) => this.CreateReader(value).ReadUInt32(); + + public ulong ToUInt64(object value) => this.CreateReader(value).ReadUInt64(); + + private MessagePackReader CreateReader(object value) => new((RawMessagePack)value); + + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor>] + private partial class Witness; + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.PipeConverters.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.PipeConverters.cs new file mode 100644 index 00000000..39afed88 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.PipeConverters.cs @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.IO.Pipelines; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private static class PipeConverters + { + internal class DuplexPipeConverter : MessagePackConverter + { + public static readonly DuplexPipeConverter DefaultInstance = new(); + + public override IDuplexPipe? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (reader.TryReadNil()) + { + return null; + } + + return formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in IDuplexPipe? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (formatter.DuplexPipeTracker.GetULongToken(value) is ulong token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } + + internal class PipeReaderConverter : MessagePackConverter + { + public static readonly PipeReaderConverter DefaultInstance = new(); + + public override PipeReader? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } + + return formatter.DuplexPipeTracker.GetPipeReader(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in PipeReader? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } + + internal class PipeWriterConverter : MessagePackConverter + { + public static readonly PipeWriterConverter DefaultInstance = new(); + + public override PipeWriter? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } + + return formatter.DuplexPipeTracker.GetPipeWriter(reader.ReadUInt64()); + } + + public override void Write(ref MessagePackWriter writer, in PipeWriter? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value) is ulong token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } + + internal class StreamConverter : MessagePackConverter + { + public static readonly StreamConverter DefaultInstance = new(); + + public override Stream? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (reader.TryReadNil()) + { + return null; + } + + return formatter.DuplexPipeTracker.GetPipe(reader.ReadUInt64()).AsStream(); + } + + public override void Write(ref MessagePackWriter writer, in Stream? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + if (formatter.DuplexPipeTracker.GetULongToken(value?.UsePipe()) is { } token) + { + writer.Write(token); + } + else + { + writer.WriteNil(); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ProgressConverters.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ProgressConverters.cs new file mode 100644 index 00000000..818ce534 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ProgressConverters.cs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// Converts a progress token to an or an into a token. + /// + /// The closed interface. + private class FullProgressConverter : MessagePackConverter + { + private Func? progressProxyCtor; + + [return: MaybeNull] + public override TClass? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (reader.TryReadNil()) + { + return default; + } + + Assumes.NotNull(formatter.JsonRpc); + RawMessagePack token = reader.ReadRaw(context).ToOwned(); + bool clientRequiresNamedArgs = formatter.ApplicableMethodAttributeOnDeserializingMethod?.ClientRequiresNamedArguments is true; + + if (this.progressProxyCtor is null) + { + ITypeShape typeShape = context.TypeShapeProvider?.Resolve(typeof(TClass)) ?? throw new InvalidOperationException("No TypeShapeProvider available."); + IObjectTypeShape progressProxyShape = (IObjectTypeShape?)typeShape.GetAssociatedTypeShape(typeof(MessageFormatterProgressTracker.ProgressProxy<>)) ?? throw new InvalidOperationException("Unable to get ProgressProxy associated shape."); + this.progressProxyCtor = (Func?)progressProxyShape.Constructor?.Accept(NonDefaultConstructorVisitor.Instance) ?? throw new InvalidOperationException("Unable to construct IProgress proxy."); + } + + return this.progressProxyCtor(formatter.JsonRpc, token, clientRequiresNamedArgs); + } + + public override void Write(ref MessagePackWriter writer, in TClass? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (value is null) + { + writer.WriteNil(); + } + else + { + long progressId = formatter.FormatterProgressTracker.GetTokenForProgress(value); + writer.Write(progressId); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.RequestIdConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RequestIdConverter.cs new file mode 100644 index 00000000..0ac3303a --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RequestIdConverter.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private class RequestIdConverter : MessagePackConverter + { + internal static readonly RequestIdConverter Instance = new(); + + private RequestIdConverter() + { + } + + public override RequestId Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + if (reader.NextMessagePackType == MessagePackType.Integer) + { + return new RequestId(reader.ReadInt64()); + } + else + { + // Do *not* read as an interned string here because this ID should be unique. + return new RequestId(reader.ReadString()); + } + } + + public override void Write(ref MessagePackWriter writer, in RequestId value, SerializationContext context) + { + context.DepthStep(); + + if (value.Number.HasValue) + { + writer.Write(value.Number.Value); + } + else + { + writer.Write(value.String); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => JsonNode.Parse(""" + { + "type": ["string", "integer"] + } + """)?.AsObject(); + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.RpcMarshalableConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RpcMarshalableConverter.cs new file mode 100644 index 00000000..5417956e --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.RpcMarshalableConverter.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Diagnostics.CodeAnalysis; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + private class RpcMarshalableConverter( + JsonRpcProxyOptions proxyOptions, + JsonRpcTargetOptions targetOptions, + RpcMarshalableAttribute rpcMarshalableAttribute) : MessagePackConverter + ////where T : class // We expect this, but requiring it adds a constraint that some callers cannot statically satisfy. + { + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Reader is passed to rpc context")] + public override T? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + +#pragma warning disable NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods - We need to switch from user data to envelope serializer + MessageFormatterRpcMarshaledContextTracker.MarshalToken? token = formatter + .envelopeSerializer.Deserialize(ref reader, Witness.ShapeProvider, context.CancellationToken); +#pragma warning restore NBMsgPack030 // Converters should not call top-level `MessagePackSerializer` methods + + return token.HasValue + ? (T?)formatter.RpcMarshaledContextTracker.GetObject(typeof(T), token, proxyOptions) + : default; + } + + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to rpc context")] + public override void Write(ref MessagePackWriter writer, in T? value, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + if (value is null) + { + writer.WriteNil(); + } + else + { + MessageFormatterRpcMarshaledContextTracker.MarshalToken token = formatter.RpcMarshaledContextTracker.GetToken(value, targetOptions, typeof(T), rpcMarshalableAttribute); + context.GetConverter(Witness.ShapeProvider).Write(ref writer, token, context); + } + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.ToStringHelper.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ToStringHelper.cs new file mode 100644 index 00000000..9baa5c80 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.ToStringHelper.cs @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + /// + /// A recyclable object that can serialize a message to JSON on demand. + /// + /// + /// In perf traces, creation of this object used to show up as one of the most allocated objects. + /// It is used even when tracing isn't active. So we changed its design to be reused, + /// since its lifetime is only required during a synchronous call to a trace API. + /// + private class ToStringHelper + { + private RawMessagePack? encodedMessage; + private string? jsonString; + + public override string ToString() + { + Verify.Operation(this.encodedMessage.HasValue, "This object has not been activated. It may have already been recycled."); + + return this.jsonString ??= MessagePackSerializer.ConvertToJson(this.encodedMessage.Value); + } + + /// + /// Initializes this object to represent a message. + /// + internal void Activate(RawMessagePack encodedMessage) + { + this.encodedMessage = encodedMessage; + } + + /// + /// Cleans out this object to release memory and ensure throws if someone uses it after deactivation. + /// + internal void Deactivate() + { + this.encodedMessage = null; + this.jsonString = null; + } + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs new file mode 100644 index 00000000..39b31939 --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.TraceParentConverter.cs @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType.Abstractions; +using StreamJsonRpc.Protocol; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +public partial class NerdbankMessagePackFormatter +{ + internal class TraceParentConverter : MessagePackConverter + { + public unsafe override TraceParent Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + if (reader.ReadArrayHeader() != 2) + { + throw new NotSupportedException("Unexpected array length."); + } + + var result = default(TraceParent); + result.Version = reader.ReadByte(); + if (result.Version != 0) + { + throw new NotSupportedException("traceparent version " + result.Version + " is not supported."); + } + + if (reader.ReadArrayHeader() != 3) + { + throw new NotSupportedException("Unexpected array length in version-format."); + } + + ReadOnlySequence bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected traceid not found."); + bytes.CopyTo(new Span(result.TraceId, TraceParent.TraceIdByteCount)); + + bytes = reader.ReadBytes() ?? throw new NotSupportedException("Expected parentid not found."); + bytes.CopyTo(new Span(result.ParentId, TraceParent.ParentIdByteCount)); + + result.Flags = (TraceParent.TraceFlags)reader.ReadByte(); + + return result; + } + + public unsafe override void Write(ref MessagePackWriter writer, in TraceParent value, SerializationContext context) + { + if (value.Version != 0) + { + throw new NotSupportedException("traceparent version " + value.Version + " is not supported."); + } + + context.DepthStep(); + + writer.WriteArrayHeader(2); + + writer.Write(value.Version); + + writer.WriteArrayHeader(3); + + fixed (byte* traceId = value.TraceId) + { + writer.Write(new ReadOnlySpan(traceId, TraceParent.TraceIdByteCount)); + } + + fixed (byte* parentId = value.ParentId) + { + writer.Write(new ReadOnlySpan(parentId, TraceParent.ParentIdByteCount)); + } + + writer.Write((byte)value.Flags); + } + + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } +} diff --git a/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs new file mode 100644 index 00000000..25e9ef5c --- /dev/null +++ b/src/StreamJsonRpc/NerdbankMessagePackFormatter.cs @@ -0,0 +1,1447 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Buffers; +using System.Collections; +using System.Collections.Immutable; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.IO.Pipelines; +using System.Reflection; +using System.Runtime.ExceptionServices; +using System.Runtime.Serialization; +using System.Text; +using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType; +using PolyType.Abstractions; +using StreamJsonRpc.Protocol; +using StreamJsonRpc.Reflection; + +namespace StreamJsonRpc; + +/// +/// Serializes JSON-RPC messages using MessagePack (a fast, compact binary format). +/// +/// +/// +/// The MessagePack implementation used here comes from https://github.com/AArnott/Nerdbank.MessagePack. +/// +/// +/// This formatter prioritizes being trim and NativeAOT safe. As such, it uses instead of to load exception types to be deserialized. +/// This trim-friendly method should be overridden to return types that are particularly interesting to the application. +/// +/// +public partial class NerdbankMessagePackFormatter : FormatterBase, IJsonRpcMessageFormatter, IJsonRpcFormatterTracingCallbacks, IJsonRpcMessageFactory +{ + /// + /// The default serializer to use for user data, and a good basis for any custom values for + /// . + /// + /// + /// + /// This serializer is configured with set to + /// and various and . + /// + /// + /// When deviating from this default, doing so while preserving the converters and converter factories from the default + /// is highly recommended. + /// It should be done once and stored in a field and reused for the lifetime of the application + /// to avoid repeated startup costs associated with building up the converter tree. + /// + /// + public static readonly MessagePackSerializer DefaultSerializer = new MessagePackSerializer() + { + InternStrings = true, + ConverterFactories = [ConverterFactory.Instance], + Converters = + [ + GetRpcMarshalableConverter(), + PipeConverters.PipeReaderConverter.DefaultInstance, + PipeConverters.PipeWriterConverter.DefaultInstance, + PipeConverters.DuplexPipeConverter.DefaultInstance, + PipeConverters.StreamConverter.DefaultInstance, + + // We preset this one in user data because $/cancellation methods can carry RequestId values as arguments. + RequestIdConverter.Instance, + + ExceptionConverter.Instance, + ], + }.WithObjectConverter(); + + /// + /// A cache of property names to declared property types, indexed by their containing parameter object type. + /// + /// + /// All access to this field should be while holding a lock on this member's value. + /// + private static readonly Dictionary> ParameterObjectPropertyTypes = []; + + /// + /// The serializer context to use for top-level RPC messages. + /// + private readonly MessagePackSerializer envelopeSerializer; + + private readonly ToStringHelper serializationToStringHelper = new(); + + private readonly ToStringHelper deserializationToStringHelper = new(); + + /// + /// Tracks recursion count while serializing or deserializing an exception. + /// + /// + /// This is placed here (outside the generic class) + /// so that it's one counter shared across all exception types that may be serialized or deserialized. + /// + private readonly ThreadLocal exceptionRecursionCounter = new(); + + /// + /// The serializer to use for user data (e.g. arguments, return values and errors). + /// + private MessagePackSerializer userDataSerializer; + + /// + /// Initializes a new instance of the class. + /// + public NerdbankMessagePackFormatter() + { + // Set up initial options for our own message types. + this.envelopeSerializer = DefaultSerializer with + { + StartingContext = new SerializationContext() + { + [SerializationContextExtensions.FormatterKey] = this, + }, + }; + + // Create a serializer for user data. + // At the moment, we just reuse the same serializer for envelope and user data. + this.userDataSerializer = this.envelopeSerializer; + } + + /// + /// Gets the shape provider for user data types. + /// + public required ITypeShapeProvider TypeShapeProvider { get; init; } + + /// + /// Gets the configured serializer to use for request arguments, result values and error data. + /// + /// + /// When setting this property, basing the new value on is highly recommended. + /// + public MessagePackSerializer UserDataSerializer + { + get => this.userDataSerializer; + [MemberNotNull(nameof(this.userDataSerializer))] + init + { + Requires.NotNull(value); + + // Customizing the input serializer to set the FormatterKey is necessary for our stateful converters. + // Doing this does NOT destroy the converter graph that may be cached because mutating the StartingContext + // property does not invalidate the graph. + this.userDataSerializer = value with + { + StartingContext = new SerializationContext() + { + [SerializationContextExtensions.FormatterKey] = this, + }, + }; + } + } + + /// + public JsonRpcMessage Deserialize(ReadOnlySequence contentBuffer) + { + JsonRpcMessage message = this.envelopeSerializer.Deserialize(contentBuffer, Witness.ShapeProvider) + ?? throw new MessagePackSerializationException("Failed to deserialize JSON-RPC message."); + + IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; + this.deserializationToStringHelper.Activate((RawMessagePack)contentBuffer); + try + { + tracingCallbacks?.OnMessageDeserialized(message, this.deserializationToStringHelper); + } + finally + { + this.deserializationToStringHelper.Deactivate(); + } + + return message; + } + + /// + public void Serialize(IBufferWriter bufferWriter, JsonRpcMessage message) + { + if (message is Protocol.JsonRpcRequest { ArgumentsList: null, Arguments: not null and not IReadOnlyDictionary } request) + { + // This request contains named arguments, but not using a standard dictionary. + // Convert it to a dictionary so that the parameters can be matched to the method we're invoking. + if (GetParamsObjectDictionary(request.Arguments) is { } namedArgs) + { + request.Arguments = namedArgs.ArgumentValues; + request.NamedArgumentDeclaredTypes = namedArgs.ArgumentTypes; + } + } + + var writer = new MessagePackWriter(bufferWriter); + try + { + this.envelopeSerializer.Serialize(ref writer, message, Witness.ShapeProvider); + writer.Flush(); + } + catch (Exception ex) + { + throw new MessagePackSerializationException(string.Format(CultureInfo.CurrentCulture, Resources.ErrorWritingJsonRpcMessage, ex.GetType().Name, ex.Message), ex); + } + } + + /// + public object GetJsonText(JsonRpcMessage message) => message is IJsonRpcMessagePackRetention retainedMsgPack + ? MessagePackSerializer.ConvertToJson(retainedMsgPack.OriginalMessagePack) + : throw new NotSupportedException(); + + /// + Protocol.JsonRpcRequest IJsonRpcMessageFactory.CreateRequestMessage() => new OutboundJsonRpcRequest(this); + + /// + Protocol.JsonRpcError IJsonRpcMessageFactory.CreateErrorMessage() => new JsonRpcError(this); + + /// + Protocol.JsonRpcResult IJsonRpcMessageFactory.CreateResultMessage() => new JsonRpcResult(this); + + void IJsonRpcFormatterTracingCallbacks.OnSerializationComplete(JsonRpcMessage message, ReadOnlySequence encodedMessage) + { + IJsonRpcTracingCallbacks? tracingCallbacks = this.JsonRpc; + this.serializationToStringHelper.Activate((RawMessagePack)encodedMessage); + try + { + tracingCallbacks?.OnMessageSerialized(message, this.serializationToStringHelper); + } + finally + { + this.serializationToStringHelper.Deactivate(); + } + } + + internal static MessagePackConverter GetRpcMarshalableConverter() + where T : class + { + if (MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType( + typeof(T), + out JsonRpcProxyOptions? proxyOptions, + out JsonRpcTargetOptions? targetOptions, + out RpcMarshalableAttribute? attribute)) + { + return new RpcMarshalableConverter(proxyOptions, targetOptions, attribute); + } + + throw new NotSupportedException($"Type '{typeof(T).FullName}' is not supported for RPC Marshaling."); + } + + /// + /// Extracts a dictionary of property names and values from the specified params object. + /// + /// The params object. + /// A dictionary of argument values and another of declared argument types, or if is null. + /// + /// This method supports DataContractSerializer-compliant types. This includes C# anonymous types. + /// + [return: NotNullIfNotNull(nameof(paramsObject))] + private static (IReadOnlyDictionary ArgumentValues, IReadOnlyDictionary ArgumentTypes)? GetParamsObjectDictionary(object? paramsObject) + { + if (paramsObject is null) + { + return default; + } + + // Look up the argument types dictionary if we saved it before. + Type paramsObjectType = paramsObject.GetType(); + IReadOnlyDictionary? argumentTypes; + lock (ParameterObjectPropertyTypes) + { + ParameterObjectPropertyTypes.TryGetValue(paramsObjectType, out argumentTypes); + } + + // If we couldn't find a previously created argument types dictionary, create a mutable one that we'll build this time. + Dictionary? mutableArgumentTypes = argumentTypes is null ? [] : null; + + var result = new Dictionary(StringComparer.Ordinal); + + TypeInfo paramsTypeInfo = paramsObject.GetType().GetTypeInfo(); + bool isDataContract = paramsTypeInfo.GetCustomAttribute() is not null; + + BindingFlags bindingFlags = BindingFlags.FlattenHierarchy | BindingFlags.Public | BindingFlags.Instance; + if (isDataContract) + { + bindingFlags |= BindingFlags.NonPublic; + } + + bool TryGetSerializationInfo(MemberInfo memberInfo, out string key) + { + key = memberInfo.Name; + if (isDataContract) + { + DataMemberAttribute? dataMemberAttribute = memberInfo.GetCustomAttribute(); + if (dataMemberAttribute is null) + { + return false; + } + + if (!dataMemberAttribute.EmitDefaultValue) + { + throw new NotSupportedException($"(DataMemberAttribute.EmitDefaultValue == false) is not supported but was found on: {memberInfo.DeclaringType!.FullName}.{memberInfo.Name}."); + } + + key = dataMemberAttribute.Name ?? memberInfo.Name; + return true; + } + else + { + return memberInfo.GetCustomAttribute() is null; + } + } + + foreach (PropertyInfo property in paramsTypeInfo.GetProperties(bindingFlags)) + { + if (property.GetMethod is not null) + { + if (TryGetSerializationInfo(property, out string key)) + { + result[key] = property.GetValue(paramsObject); + if (mutableArgumentTypes is not null) + { + mutableArgumentTypes[key] = property.PropertyType; + } + } + } + } + + foreach (FieldInfo field in paramsTypeInfo.GetFields(bindingFlags)) + { + if (TryGetSerializationInfo(field, out string key)) + { + result[key] = field.GetValue(paramsObject); + if (mutableArgumentTypes is not null) + { + mutableArgumentTypes[key] = field.FieldType; + } + } + } + + // If we assembled the argument types dictionary this time, save it for next time. + if (mutableArgumentTypes is not null) + { + lock (ParameterObjectPropertyTypes) + { + if (ParameterObjectPropertyTypes.TryGetValue(paramsObjectType, out IReadOnlyDictionary? lostRace)) + { + // Of the two, pick the winner to use ourselves so we consolidate on one and allow the GC to collect the loser sooner. + argumentTypes = lostRace; + } + else + { + ParameterObjectPropertyTypes.Add(paramsObjectType, argumentTypes = mutableArgumentTypes); + } + } + } + + return (result, argumentTypes!); + } + + /// + /// Reads a string with an optimized path for the value "2.0". + /// + /// The reader to use. + /// The decoded string. + private static string ReadProtocolVersion(ref MessagePackReader reader) + { + // Recognize "2.0" since we expect it and can avoid decoding and allocating a new string for it. + return Version2.TryRead(ref reader) + ? Version2.Value + : reader.ReadString() ?? throw new MessagePackSerializationException(Resources.RequiredArgumentMissing); + } + + /// + /// Writes the JSON-RPC version property name and value in a highly optimized way. + /// + private static void WriteProtocolVersionPropertyAndValue(ref MessagePackWriter writer, string version) + { + writer.Write(VersionPropertyName); + writer.Write(version); + } + + private static void ReadUnknownProperty(ref MessagePackReader reader, in SerializationContext context, ref Dictionary? topLevelProperties) + { + topLevelProperties ??= new Dictionary(StringComparer.Ordinal); + string name = context.GetConverter(Witness.ShapeProvider).Read(ref reader, context) ?? throw new MessagePackSerializationException("Unexpected nil at property name position."); + topLevelProperties.Add(name, reader.ReadRaw(context)); + } + + private static Type NormalizeType(Type type) + { + if (TrackerHelpers.FindIProgressInterfaceImplementedBy(type) is Type iface) + { + type = iface; + } + else if (TrackerHelpers.FindIAsyncEnumerableInterfaceImplementedBy(type) is Type iface2) + { + type = iface2; + } + else if (typeof(IDuplexPipe).IsAssignableFrom(type)) + { + type = typeof(IDuplexPipe); + } + else if (typeof(PipeWriter).IsAssignableFrom(type)) + { + type = typeof(PipeWriter); + } + else if (typeof(PipeReader).IsAssignableFrom(type)) + { + type = typeof(PipeReader); + } + else if (typeof(Stream).IsAssignableFrom(type)) + { + type = typeof(Stream); + } + else if (typeof(Exception).IsAssignableFrom(type)) + { + type = typeof(Exception); + } + + return type; + } + + private static T ActivateAssociatedType(ITypeShape shape, Type associatedType) + where T : class + => (T?)((IObjectTypeShape?)shape.GetAssociatedTypeShape(associatedType))?.GetDefaultConstructor()?.Invoke() ?? throw new InvalidOperationException($"Missing associated type from {shape.Type.FullName} to {associatedType.FullName}."); + + private ITypeShape GetUserDataShape(Type type) + { + type = NormalizeType(type); + + // We prefer to get the shape from the user shape provider, but will fallback to our own for built-in types. + // But if that fails too, try again with Resolve on the user shape provider so that it throws an exception explaining that the user needs to provide it. + return this.TypeShapeProvider.GetShape(type) ?? Witness.ShapeProvider.GetShape(type) ?? this.TypeShapeProvider.Resolve(type); + } + + private void WriteUserData(ref MessagePackWriter writer, object? value, Type? valueType, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + if (valueType == typeof(void) || valueType == typeof(object)) + { + valueType = null; + } + + ITypeShape valueShape = this.GetUserDataShape(valueType ?? value.GetType()); + this.UserDataSerializer.SerializeObject(ref writer, value, valueShape, context.CancellationToken); + } + } + + /// + /// Converts JSON-RPC messages to and from MessagePack format. + /// + internal class JsonRpcMessageConverter : MessagePackConverter + { + /// + /// Reads a JSON-RPC message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC message. + public override JsonRpcMessage? Read(ref MessagePackReader reader, SerializationContext context) + { + context.DepthStep(); + + MessagePackReader readAhead = reader.CreatePeekReader(); + int propertyCount = readAhead.ReadMapHeader(); + + for (int i = 0; i < propertyCount; i++) + { + if (MethodPropertyName.TryRead(ref readAhead)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ResultPropertyName.TryRead(ref readAhead)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ErrorPropertyName.TryRead(ref readAhead)) + { + return context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + + // This property doesn't tell us the message type. + // Skip its name and value. + readAhead.Skip(context); + readAhead.Skip(context); + } + + throw new UnrecognizedJsonRpcMessageException(); + } + + /// + /// Writes a JSON-RPC message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in JsonRpcMessage? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + using (formatter.TrackSerialization(value)) + { + switch (value) + { + case Protocol.JsonRpcRequest request: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, request, context); + break; + case Protocol.JsonRpcResult result: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, result, context); + break; + case Protocol.JsonRpcError error: + context.GetConverter(context.TypeShapeProvider).Write(ref writer, error, context); + break; + default: + throw new NotSupportedException("Unexpected JsonRpcMessage-derived type: " + value.GetType().Name); + } + } + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } + + /// + /// Converts a JSON-RPC request message to and from MessagePack format. + /// + internal class JsonRpcRequestConverter : MessagePackConverter + { + /// + /// Reads a JSON-RPC request message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC request message. + public override Protocol.JsonRpcRequest? Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + var result = new JsonRpcRequest(formatter) + { + OriginalMessagePack = (RawMessagePack)reader.Sequence, + }; + + Dictionary? topLevelProperties = null; + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) + { + if (VersionPropertyName.TryRead(ref reader)) + { + result.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(ref reader)) + { + result.RequestId = context.GetConverter(null).Read(ref reader, context); + } + else if (MethodPropertyName.TryRead(ref reader)) + { + result.Method = context.GetConverter(Witness.ShapeProvider).Read(ref reader, context); + } + else if (ParamsPropertyName.TryRead(ref reader)) + { + SequencePosition paramsTokenStartPosition = reader.Position; + + // Parse out the arguments into a dictionary or array, but don't deserialize them because we don't yet know what types to deserialize them to. + switch (reader.NextMessagePackType) + { + case MessagePackType.Array: + var positionalArgs = new RawMessagePack[reader.ReadArrayHeader()]; + for (int i = 0; i < positionalArgs.Length; i++) + { + positionalArgs[i] = reader.ReadRaw(context); + } + + result.MsgPackPositionalArguments = positionalArgs; + break; + case MessagePackType.Map: + int namedArgsCount = reader.ReadMapHeader(); + var namedArgs = new Dictionary(namedArgsCount, StringComparer.Ordinal); + for (int i = 0; i < namedArgsCount; i++) + { + // Use a string converter so that strings can be interned. + string propertyName = context.GetConverter(Witness.ShapeProvider).Read(ref reader, context) ?? throw new MessagePackSerializationException(Resources.UnexpectedNullValueInMap); + namedArgs.Add(propertyName, reader.ReadRaw(context)); + } + + result.MsgPackNamedArguments = namedArgs; + break; + case MessagePackType.Nil: + result.MsgPackPositionalArguments = []; + reader.ReadNil(); + break; + case MessagePackType type: + throw new MessagePackSerializationException("Expected a map or array of arguments but got " + type); + } + + result.MsgPackArguments = (RawMessagePack)reader.Sequence.Slice(paramsTokenStartPosition, reader.Position); + } + else if (TraceParentPropertyName.TryRead(ref reader)) + { + TraceParent traceParent = context.GetConverter(null).Read(ref reader, context); + result.TraceParent = traceParent.ToString(); + } + else if (TraceStatePropertyName.TryRead(ref reader)) + { + result.TraceState = ReadTraceState(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties); + } + } + + if (topLevelProperties is not null) + { + result.TopLevelPropertyBag = new TopLevelPropertyBag(formatter, topLevelProperties); + } + + formatter.TryHandleSpecialIncomingMessage(result); + + return result; + } + + /// + /// Writes a JSON-RPC request message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC request message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcRequest? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + + int mapElementCount = value.RequestId.IsEmpty ? 3 : 4; + if (value.TraceParent?.Length > 0) + { + mapElementCount++; + if (value.TraceState?.Length > 0) + { + mapElementCount++; + } + } + + mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + if (!value.RequestId.IsEmpty) + { + writer.Write(IdPropertyName); + context.GetConverter(Witness.ShapeProvider).Write(ref writer, value.RequestId, context); + } + + writer.Write(MethodPropertyName); + writer.Write(value.Method); + + writer.Write(ParamsPropertyName); + if (value.ArgumentsList is not null) + { + writer.WriteArrayHeader(value.ArgumentsList.Count); + + for (int i = 0; i < value.ArgumentsList.Count; i++) + { + formatter.WriteUserData(ref writer, value.ArgumentsList[i], value.ArgumentListDeclaredTypes?[i], context); + } + } + else if (value.NamedArguments is not null) + { + writer.WriteMapHeader(value.NamedArguments.Count); + foreach (KeyValuePair entry in value.NamedArguments) + { + writer.Write(entry.Key); + formatter.WriteUserData(ref writer, entry.Value, value.NamedArgumentDeclaredTypes?[entry.Key], context); + } + } + else + { + writer.WriteNil(); + } + + if (value.TraceParent?.Length > 0) + { + writer.Write(TraceParentPropertyName); + context.GetConverter(Witness.ShapeProvider).Write(ref writer, new TraceParent(value.TraceParent), context); + + if (value.TraceState?.Length > 0) + { + writer.Write(TraceStatePropertyName); + WriteTraceState(ref writer, value.TraceState); + } + } + + topLevelPropertyBag?.WriteProperties(ref writer); + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + + private static void WriteTraceState(ref MessagePackWriter writer, string traceState) + { + ReadOnlySpan traceStateChars = traceState.AsSpan(); + + // Count elements first so we can write the header. + int elementCount = 1; + int commaIndex; + while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) + { + elementCount++; + traceStateChars = traceStateChars.Slice(commaIndex + 1); + } + + // For every element, we have a key and value to record. + writer.WriteArrayHeader(elementCount * 2); + + traceStateChars = traceState.AsSpan(); + while ((commaIndex = traceStateChars.IndexOf(',')) >= 0) + { + ReadOnlySpan element = traceStateChars.Slice(0, commaIndex); + WritePair(ref writer, element); + traceStateChars = traceStateChars.Slice(commaIndex + 1); + } + + // Write out the last one. + WritePair(ref writer, traceStateChars); + + static void WritePair(ref MessagePackWriter writer, ReadOnlySpan pair) + { + int equalsIndex = pair.IndexOf('='); + ReadOnlySpan key = pair.Slice(0, equalsIndex); + ReadOnlySpan value = pair.Slice(equalsIndex + 1); + writer.Write(key); + writer.Write(value); + } + } + + private static unsafe string ReadTraceState(ref MessagePackReader reader, SerializationContext context) + { + int elements = reader.ReadArrayHeader(); + if (elements % 2 != 0) + { + throw new NotSupportedException("Odd number of elements not expected."); + } + + // With care, we could probably assemble this string with just two allocations (the string + a char[]). + var resultBuilder = new StringBuilder(); + for (int i = 0; i < elements; i += 2) + { + if (resultBuilder.Length > 0) + { + resultBuilder.Append(','); + } + + // We assume the key is a frequent string, and the value is unique, + // so we optimize whether to use string interning or not on that basis. + resultBuilder.Append(context.GetConverter(Witness.ShapeProvider).Read(ref reader, context)); + resultBuilder.Append('='); + resultBuilder.Append(reader.ReadString()); + } + + return resultBuilder.ToString(); + } + } + + /// + /// Converts a JSON-RPC result message to and from MessagePack format. + /// + internal class JsonRpcResultConverter : MessagePackConverter + { + /// + /// Reads a JSON-RPC result message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC result message. + public override Protocol.JsonRpcResult Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + var result = new JsonRpcResult(formatter) + { + OriginalMessagePack = (RawMessagePack)reader.Sequence, + }; + + Dictionary? topLevelProperties = null; + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIndex = 0; propertyIndex < propertyCount; propertyIndex++) + { + if (VersionPropertyName.TryRead(ref reader)) + { + result.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(ref reader)) + { + result.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ResultPropertyName.TryRead(ref reader)) + { + result.MsgPackResult = reader.ReadRaw(context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties); + } + } + + if (topLevelProperties is not null) + { + result.TopLevelPropertyBag = new TopLevelPropertyBag(formatter, topLevelProperties); + } + + return result; + } + + /// + /// Writes a JSON-RPC result message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC result message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcResult? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + + context.DepthStep(); + + var topLevelPropertyBagMessage = value as IMessageWithTopLevelPropertyBag; + + int mapElementCount = 3; + mapElementCount += (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + writer.Write(IdPropertyName); + context.GetConverter(context.TypeShapeProvider).Write(ref writer, value.RequestId, context); + + writer.Write(ResultPropertyName); + + if (value.Result is null) + { + writer.WriteNil(); + } + else + { + formatter.WriteUserData(ref writer, value.Result, value.ResultDeclaredType, context); + } + + (topLevelPropertyBagMessage?.TopLevelPropertyBag as TopLevelPropertyBag)?.WriteProperties(ref writer); + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } + + /// + /// Converts a JSON-RPC error message to and from MessagePack format. + /// + internal class JsonRpcErrorConverter : MessagePackConverter + { + /// + /// Reads a JSON-RPC error message from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC error message. + public override Protocol.JsonRpcError Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + var error = new JsonRpcError(formatter) + { + OriginalMessagePack = (RawMessagePack)reader.Sequence, + }; + + Dictionary? topLevelProperties = null; + + context.DepthStep(); + int propertyCount = reader.ReadMapHeader(); + for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) + { + if (VersionPropertyName.TryRead(ref reader)) + { + error.Version = ReadProtocolVersion(ref reader); + } + else if (IdPropertyName.TryRead(ref reader)) + { + error.RequestId = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (ErrorPropertyName.TryRead(ref reader)) + { + error.Error = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else + { + ReadUnknownProperty(ref reader, context, ref topLevelProperties); + } + } + + if (topLevelProperties is not null) + { + error.TopLevelPropertyBag = new TopLevelPropertyBag(formatter, topLevelProperties); + } + + return error; + } + + /// + /// Writes a JSON-RPC error message to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC error message to write. + /// The serialization context. + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + var topLevelPropertyBag = (TopLevelPropertyBag?)(value as IMessageWithTopLevelPropertyBag)?.TopLevelPropertyBag; + + context.DepthStep(); + int mapElementCount = 3; + mapElementCount += topLevelPropertyBag?.PropertyCount ?? 0; + writer.WriteMapHeader(mapElementCount); + + WriteProtocolVersionPropertyAndValue(ref writer, value.Version); + + writer.Write(IdPropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.RequestId, context); + + writer.Write(ErrorPropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.Error, context); + + topLevelPropertyBag?.WriteProperties(ref writer); + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } + + /// + /// Converts a JSON-RPC error detail to and from MessagePack format. + /// + internal class JsonRpcErrorDetailConverter : MessagePackConverter + { + private static readonly MessagePackString CodePropertyName = new("code"); + private static readonly MessagePackString MessagePropertyName = new("message"); + private static readonly MessagePackString DataPropertyName = new("data"); + + /// + /// Reads a JSON-RPC error detail from the specified MessagePack reader. + /// + /// The MessagePack reader to read from. + /// The serialization context. + /// The deserialized JSON-RPC error detail. + public override Protocol.JsonRpcError.ErrorDetail Read(ref MessagePackReader reader, SerializationContext context) + { + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + var result = new JsonRpcError.ErrorDetail(formatter); + + int propertyCount = reader.ReadMapHeader(); + for (int propertyIdx = 0; propertyIdx < propertyCount; propertyIdx++) + { + if (CodePropertyName.TryRead(ref reader)) + { + result.Code = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (MessagePropertyName.TryRead(ref reader)) + { + result.Message = context.GetConverter(context.TypeShapeProvider).Read(ref reader, context); + } + else if (DataPropertyName.TryRead(ref reader)) + { + result.MsgPackData = reader.ReadRaw(context); + } + else + { + reader.Skip(context); // skip the key + reader.Skip(context); // skip the value + } + } + + return result; + } + + /// + /// Writes a JSON-RPC error detail to the specified MessagePack writer. + /// + /// The MessagePack writer to write to. + /// The JSON-RPC error detail to write. + /// The serialization context. + [SuppressMessage("Usage", "NBMsgPack031:Converters should read or write exactly one msgpack structure", Justification = "Writer is passed to user data context")] + public override void Write(ref MessagePackWriter writer, in Protocol.JsonRpcError.ErrorDetail? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + NerdbankMessagePackFormatter formatter = context.GetFormatter(); + context.DepthStep(); + + writer.WriteMapHeader(3); + + writer.Write(CodePropertyName); + context.GetConverter(context.TypeShapeProvider) + .Write(ref writer, value.Code, context); + + writer.Write(MessagePropertyName); + writer.Write(value.Message); + + writer.Write(DataPropertyName); + if (value.Data is null) + { + writer.WriteNil(); + } + else + { + // We generally leave error data for the user to provide the shape for. + // But for CommonErrorData, we can take responsibility for that. + // We also take responsibility for Exception serialization (for now). + ITypeShapeProvider provider = value.Data is CommonErrorData or Exception ? Witness.ShapeProvider : formatter.TypeShapeProvider; + Type declaredType = value.Data is Exception ? typeof(Exception) : value.Data.GetType(); + context.GetConverter(declaredType, provider).WriteObject(ref writer, value.Data, context); + } + } + + /// + /// Gets the JSON schema for the specified type. + /// + /// The JSON schema context. + /// The type shape. + /// The JSON schema for the specified type. + public override JsonObject? GetJsonSchema(JsonSchemaContext context, ITypeShape typeShape) => null; + } + + private class TopLevelPropertyBag : TopLevelPropertyBagBase + { + private readonly NerdbankMessagePackFormatter formatter; + private readonly IReadOnlyDictionary? inboundUnknownProperties; + + /// + /// Initializes a new instance of the class + /// for an incoming message. + /// + /// The owning formatter. + /// The map of unrecognized inbound properties. + internal TopLevelPropertyBag(NerdbankMessagePackFormatter formatter, IReadOnlyDictionary inboundUnknownProperties) + : base(isOutbound: false) + { + this.formatter = formatter; + this.inboundUnknownProperties = inboundUnknownProperties; + } + + /// + /// Initializes a new instance of the class + /// for an outbound message. + /// + /// The owning formatter. + internal TopLevelPropertyBag(NerdbankMessagePackFormatter formatter) + : base(isOutbound: true) + { + this.formatter = formatter; + } + + internal int PropertyCount => this.inboundUnknownProperties?.Count ?? this.OutboundProperties?.Count ?? 0; + + /// + /// Writes the properties tracked by this collection to a messagepack writer. + /// + /// The writer to use. + internal void WriteProperties(ref MessagePackWriter writer) + { + if (this.inboundUnknownProperties is not null) + { + // We're actually re-transmitting an incoming message (remote target feature). + // We need to copy all the properties that were in the original message. + // Don't implement this without enabling the tests for the scenario found in JsonRpcRemoteTargetMessagePackFormatterTests.cs. + // The tests fail for reasons even without this support, so there's work to do beyond just implementing this. + throw new NotImplementedException(); + + ////foreach (KeyValuePair> entry in this.inboundUnknownProperties) + ////{ + //// writer.Write(entry.Key); + //// writer.Write(entry.Value); + ////} + } + else + { + foreach (KeyValuePair entry in this.OutboundProperties) + { + writer.Write(entry.Key); + this.formatter.userDataSerializer.SerializeObject(ref writer, entry.Value.Value, this.formatter.TypeShapeProvider.Resolve(entry.Value.DeclaredType)); + } + } + } + + protected internal override bool TryGetTopLevelProperty(string name, [MaybeNull] out T value) + { + if (this.inboundUnknownProperties is null) + { + throw new InvalidOperationException(Resources.InboundMessageOnly); + } + + value = default; + + if (this.inboundUnknownProperties.TryGetValue(name, out RawMessagePack serializedValue) is true) + { + value = this.formatter.userDataSerializer.Deserialize(serializedValue, this.formatter.TypeShapeProvider); + return true; + } + + return false; + } + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + private class OutboundJsonRpcRequest(NerdbankMessagePackFormatter formatter) : JsonRpcRequestBase + { + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(formatter); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + private class JsonRpcRequest : JsonRpcRequestBase, IJsonRpcMessagePackRetention + { + private readonly NerdbankMessagePackFormatter formatter; + + internal JsonRpcRequest(NerdbankMessagePackFormatter formatter) + { + this.formatter = formatter; + } + + public override int ArgumentCount => this.MsgPackNamedArguments?.Count ?? this.MsgPackPositionalArguments?.Count ?? base.ArgumentCount; + + public override IEnumerable? ArgumentNames => this.MsgPackNamedArguments?.Keys; + + public RawMessagePack OriginalMessagePack { get; internal set; } + + internal RawMessagePack MsgPackArguments { get; set; } + + internal IReadOnlyDictionary? MsgPackNamedArguments { get; set; } + + internal IReadOnlyList? MsgPackPositionalArguments { get; set; } + + public override ArgumentMatchResult TryGetTypedArguments(ReadOnlySpan parameters, Span typedArguments) + { + using (this.formatter.TrackDeserialization(this, parameters)) + { + if (parameters.Length == 1 && this.MsgPackNamedArguments is not null) + { + if (this.formatter.ApplicableMethodAttributeOnDeserializingMethod?.UseSingleObjectParameterDeserialization ?? false) + { + var reader = new MessagePackReader(this.MsgPackArguments); + try + { + typedArguments[0] = this.formatter.userDataSerializer.DeserializeObject( + ref reader, + this.formatter.TypeShapeProvider.Resolve(parameters[0].ParameterType)); + + return ArgumentMatchResult.Success; + } + catch (MessagePackSerializationException) + { + return ArgumentMatchResult.ParameterArgumentTypeMismatch; + } + } + } + + return base.TryGetTypedArguments(parameters, typedArguments); + } + } + + public override bool TryGetArgumentByNameOrIndex(string? name, int position, Type? typeHint, out object? value) + { + // If anyone asks us for an argument *after* we've been told deserialization is done, there's something very wrong. + Assumes.True(this.MsgPackNamedArguments is not null || this.MsgPackPositionalArguments is not null); + + RawMessagePack msgpackArgument = default; + if (position >= 0 && this.MsgPackPositionalArguments?.Count > position) + { + msgpackArgument = this.MsgPackPositionalArguments[position]; + } + else if (name is not null && this.MsgPackNamedArguments is not null) + { + this.MsgPackNamedArguments.TryGetValue(name, out msgpackArgument); + } + + if (msgpackArgument.MsgPack.IsEmpty) + { + value = null; + return false; + } + + using (this.formatter.TrackDeserialization(this)) + { + try + { + MessagePackReader reader = new(msgpackArgument); + value = this.formatter.userDataSerializer.DeserializeObject( + ref reader, + this.formatter.GetUserDataShape(typeHint ?? typeof(object))); + + return true; + } + catch (MessagePackSerializationException ex) + { + if (this.formatter.JsonRpc?.TraceSource.Switch.ShouldTrace(TraceEventType.Warning) ?? false) + { + this.formatter.JsonRpc.TraceSource.TraceEvent(TraceEventType.Warning, (int)JsonRpc.TraceEvents.MethodArgumentDeserializationFailure, Resources.FailureDeserializingRpcArgument, name, position, typeHint, ex); + } + + throw new RpcArgumentDeserializationException(name, position, typeHint, ex); + } + } + } + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + this.MsgPackNamedArguments = null; + this.MsgPackPositionalArguments = null; + this.TopLevelPropertyBag = null; + this.MsgPackArguments = default; + this.OriginalMessagePack = default; + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(this.formatter); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + private partial class JsonRpcResult(NerdbankMessagePackFormatter formatter) : JsonRpcResultBase, IJsonRpcMessagePackRetention + { + private Exception? resultDeserializationException; + + public RawMessagePack OriginalMessagePack { get; internal set; } + + internal RawMessagePack MsgPackResult { get; set; } + + public override T GetResult() + { + if (this.resultDeserializationException is not null) + { + ExceptionDispatchInfo.Capture(this.resultDeserializationException).Throw(); + } + + return this.MsgPackResult.MsgPack.IsEmpty + ? (T)this.Result! + : formatter.userDataSerializer.Deserialize(this.MsgPackResult, formatter.TypeShapeProvider) + ?? throw new MessagePackSerializationException("Failed to deserialize result."); + } + + protected internal override void SetExpectedResultType(Type resultType) + { + Verify.Operation(!this.MsgPackResult.MsgPack.IsEmpty, "Result is no longer available or has already been deserialized."); + + try + { + using (formatter.TrackDeserialization(this)) + { + MessagePackReader reader = new(this.MsgPackResult); + this.Result = formatter.userDataSerializer.DeserializeObject(ref reader, formatter.TypeShapeProvider.Resolve(resultType)); + } + + this.MsgPackResult = default; + } + catch (MessagePackSerializationException ex) + { + // This was a best effort anyway. We'll throw again later at a more convenient time for JsonRpc. + this.resultDeserializationException = ex; + } + } + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + this.MsgPackResult = default; + this.OriginalMessagePack = default; + } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(formatter); + } + + [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] + private class JsonRpcError(NerdbankMessagePackFormatter formatter) : JsonRpcErrorBase, IJsonRpcMessagePackRetention + { + public RawMessagePack OriginalMessagePack { get; internal set; } + + protected override TopLevelPropertyBagBase? CreateTopLevelPropertyBag() => new TopLevelPropertyBag(formatter); + + protected override void ReleaseBuffers() + { + base.ReleaseBuffers(); + if (this.Error is ErrorDetail privateDetail) + { + privateDetail.MsgPackData = default; + } + + this.OriginalMessagePack = default; + } + + internal new class ErrorDetail(NerdbankMessagePackFormatter formatter) : Protocol.JsonRpcError.ErrorDetail + { + internal ReadOnlySequence MsgPackData { get; set; } + + public override object? GetData(Type dataType) + { + Requires.NotNull(dataType, nameof(dataType)); + if (this.MsgPackData.IsEmpty) + { + return this.Data; + } + + MessagePackReader reader = new(this.MsgPackData); + try + { + return + (dataType == typeof(Exception) || dataType == typeof(CommonErrorData)) ? formatter.envelopeSerializer.DeserializeObject(ref reader, Witness.ShapeProvider.Resolve(dataType)) : + formatter.userDataSerializer.DeserializeObject(ref reader, formatter.TypeShapeProvider.Resolve(dataType)); + } + catch (MessagePackSerializationException) + { + // Deserialization failed. Try returning array/dictionary based primitive objects. + try + { + reader = new(this.MsgPackData); + return formatter.envelopeSerializer.DeserializePrimitives(ref reader); + } + catch (MessagePackSerializationException) + { + return null; + } + } + } + + protected internal override void SetExpectedDataType(Type dataType) + { + Verify.Operation(!this.MsgPackData.IsEmpty, "Data is no longer available or has already been deserialized."); + + this.Data = this.GetData(dataType); + + // Clear the source now that we've deserialized to prevent GetData from attempting + // deserialization later when the buffer may be recycled on another thread. + this.MsgPackData = default; + } + } + } + + private class ConverterFactory : IMessagePackConverterFactory + { + internal static readonly ConverterFactory Instance = new(); + + private ConverterFactory() + { + } + + public MessagePackConverter? CreateConverter(ITypeShape shape) + => MessageFormatterProgressTracker.CanDeserialize(typeof(T)) || MessageFormatterProgressTracker.CanSerialize(typeof(T)) ? new FullProgressConverter() : + TrackerHelpers.IsIAsyncEnumerable(typeof(T)) ? ActivateAssociatedType>(shape, typeof(AsyncEnumerableConverter<>)) : + TrackerHelpers.FindIAsyncEnumerableInterfaceImplementedBy(typeof(T)) is Type iface ? ActivateAssociatedType>(shape, typeof(AsyncEnumerableConverter<>)) : + MessageFormatterRpcMarshaledContextTracker.TryGetMarshalOptionsForType(typeof(T), out JsonRpcProxyOptions? proxyOptions, out JsonRpcTargetOptions? targetOptions, out RpcMarshalableAttribute? attribute) ? new RpcMarshalableConverter(proxyOptions, targetOptions, attribute) : + typeof(Exception).IsAssignableFrom(typeof(T)) ? new ExceptionConverter() : + null; + } + + private class NonDefaultConstructorVisitor : TypeShapeVisitor + { + internal static readonly NonDefaultConstructorVisitor Instance = new(); + + private NonDefaultConstructorVisitor() + { + } + + public override object? VisitConstructor(IConstructorShape constructorShape, object? state = null) + { + Func argStateCtor = constructorShape.GetArgumentStateConstructor(); + Constructor ctor = constructorShape.GetParameterizedConstructor(); + if (constructorShape.Parameters.Count != 3 || + constructorShape.Parameters[0].ParameterType.Type != typeof(T1) || + constructorShape.Parameters[1].ParameterType.Type != typeof(T2) || + constructorShape.Parameters[2].ParameterType.Type != typeof(T3)) + { + throw new InvalidOperationException("Unexpected constructor parameter types."); + } + + var setter1 = (Setter)constructorShape.Parameters[0].Accept(this)!; + var setter2 = (Setter)constructorShape.Parameters[1].Accept(this)!; + var setter3 = (Setter)constructorShape.Parameters[2].Accept(this)!; + + Func func = (p1, p2, p3) => + { + TArgumentState state = argStateCtor(); + setter1(ref state, p1); + setter2(ref state, p2); + setter3(ref state, p3); + return ctor(ref state); + }; + + return func; + } + + public override object? VisitParameter(IParameterShape parameterShape, object? state = null) => parameterShape.GetSetter(); + } + + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + private partial class Witness; +} diff --git a/src/StreamJsonRpc/Polyfills.cs b/src/StreamJsonRpc/Polyfills.cs new file mode 100644 index 00000000..31072d76 --- /dev/null +++ b/src/StreamJsonRpc/Polyfills.cs @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Text; + +namespace StreamJsonRpc; + +internal static class Polyfills +{ +#if !(NETSTANDARD2_1_OR_GREATER || NET) + internal static unsafe string GetString(this Encoding encoding, ReadOnlySpan utf8Bytes) + { + fixed (byte* pBytes = utf8Bytes) + { + return encoding.GetString(pBytes, utf8Bytes.Length); + } + } +#endif +} diff --git a/src/StreamJsonRpc/Protocol/CommonErrorData.cs b/src/StreamJsonRpc/Protocol/CommonErrorData.cs index 823f4d5a..851fa0d9 100644 --- a/src/StreamJsonRpc/Protocol/CommonErrorData.cs +++ b/src/StreamJsonRpc/Protocol/CommonErrorData.cs @@ -2,6 +2,8 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Runtime.Serialization; +using PolyType; +using NBMsgPack = Nerdbank.MessagePack; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -10,7 +12,8 @@ namespace StreamJsonRpc.Protocol; /// A class that describes useful data that may be found in the JSON-RPC error message's error.data property. /// [DataContract] -public class CommonErrorData +[GenerateShape] +public partial class CommonErrorData { /// /// Initializes a new instance of the class. @@ -39,6 +42,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 0, Name = "type")] [STJ.JsonPropertyName("type"), STJ.JsonPropertyOrder(0)] + [PropertyShape(Name = "type"), NBMsgPack.Key(0)] public string? TypeName { get; set; } /// @@ -46,6 +50,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 1, Name = "message")] [STJ.JsonPropertyName("message"), STJ.JsonPropertyOrder(1)] + [PropertyShape(Name = "message"), NBMsgPack.Key(1)] public string? Message { get; set; } /// @@ -53,6 +58,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 2, Name = "stack")] [STJ.JsonPropertyName("stack"), STJ.JsonPropertyOrder(2)] + [PropertyShape(Name = "stack"), NBMsgPack.Key(2)] public string? StackTrace { get; set; } /// @@ -60,6 +66,7 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 3, Name = "code")] [STJ.JsonPropertyName("code"), STJ.JsonPropertyOrder(3)] + [PropertyShape(Name = "code"), NBMsgPack.Key(3)] public int HResult { get; set; } /// @@ -67,5 +74,6 @@ public CommonErrorData(Exception copyFrom) /// [DataMember(Order = 4, Name = "inner")] [STJ.JsonPropertyName("inner"), STJ.JsonPropertyOrder(4)] + [PropertyShape(Name = "inner"), NBMsgPack.Key(4)] public CommonErrorData? Inner { get; set; } } diff --git a/src/StreamJsonRpc/Protocol/JsonRpcError.cs b/src/StreamJsonRpc/Protocol/JsonRpcError.cs index b673edcb..38ccfa91 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcError.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcError.cs @@ -4,6 +4,8 @@ using System.Diagnostics; using System.Runtime.Serialization; using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType; using StreamJsonRpc.Reflection; using STJ = System.Text.Json.Serialization; @@ -13,14 +15,17 @@ namespace StreamJsonRpc.Protocol; /// Describes the error resulting from a that failed on the server. /// [DataContract] +[GenerateShape] +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcErrorConverter))] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + "}")] -public class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId +public partial class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId { /// /// Gets or sets the detail about the error. /// [DataMember(Name = "error", Order = 2, IsRequired = true)] [STJ.JsonPropertyName("error"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PropertyShape(Name = "error", Order = 2)] public ErrorDetail? Error { get; set; } /// @@ -30,6 +35,7 @@ public class JsonRpcError : JsonRpcMessage, IJsonRpcMessageWithId [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -41,6 +47,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = true, EmitDefaultValue = true)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// @@ -66,7 +73,9 @@ public override string ToString() /// Describes the error. /// [DataContract] - public class ErrorDetail + [GenerateShape] + [MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcErrorDetailConverter))] + public partial class ErrorDetail { /// /// Gets or sets a number that indicates the error type that occurred. @@ -77,6 +86,7 @@ public class ErrorDetail /// [DataMember(Name = "code", Order = 0, IsRequired = true)] [STJ.JsonPropertyName("code"), STJ.JsonPropertyOrder(0), STJ.JsonRequired] + [PropertyShape(Name = "code", Order = 0)] public JsonRpcErrorCode Code { get; set; } /// @@ -87,6 +97,7 @@ public class ErrorDetail /// [DataMember(Name = "message", Order = 1, IsRequired = true)] [STJ.JsonPropertyName("message"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PropertyShape(Name = "message", Order = 1)] public string? Message { get; set; } /// @@ -95,6 +106,7 @@ public class ErrorDetail [DataMember(Name = "data", Order = 2, IsRequired = false)] [Newtonsoft.Json.JsonProperty(DefaultValueHandling = Newtonsoft.Json.DefaultValueHandling.Ignore)] [STJ.JsonPropertyName("data"), STJ.JsonPropertyOrder(2)] + [PropertyShape(Name = "data", Order = 2)] public object? Data { get; set; } /// @@ -129,7 +141,7 @@ public class ErrorDetail /// /// The type that will be used as the generic type argument to . /// - /// Overridding methods in types that retain buffers used to deserialize should deserialize within this method and clear those buffers + /// Overriding methods in types that retain buffers used to deserialize should deserialize within this method and clear those buffers /// to prevent further access to these buffers which may otherwise happen concurrently with a call to /// that would recycle the same buffer being deserialized from. /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs index 84acc937..d9487634 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcMessage.cs @@ -3,6 +3,8 @@ using System.Diagnostics.CodeAnalysis; using System.Runtime.Serialization; +using Nerdbank.MessagePack; +using PolyType; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -14,7 +16,9 @@ namespace StreamJsonRpc.Protocol; [KnownType(typeof(JsonRpcRequest))] [KnownType(typeof(JsonRpcResult))] [KnownType(typeof(JsonRpcError))] -public abstract class JsonRpcMessage +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcMessageConverter))] +[GenerateShape] +public abstract partial class JsonRpcMessage { /// /// Gets or sets the version of the JSON-RPC protocol that this message conforms to. @@ -22,6 +26,7 @@ public abstract class JsonRpcMessage /// Defaults to "2.0". [DataMember(Name = "jsonrpc", Order = 0, IsRequired = true)] [STJ.JsonPropertyName("jsonrpc"), STJ.JsonPropertyOrder(0), STJ.JsonRequired] + [PropertyShape(Name = "jsonrpc", Order = 0)] public string Version { get; set; } = "2.0"; /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs index 904b21d2..185a1c84 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcRequest.cs @@ -5,6 +5,9 @@ using System.Reflection; using System.Runtime.Serialization; using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType; +using JsonNET = Newtonsoft.Json.Linq; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -13,8 +16,10 @@ namespace StreamJsonRpc.Protocol; /// Describes a method to be invoked on the server. /// [DataContract] +[GenerateShape] +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcRequestConverter))] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] -public class JsonRpcRequest : JsonRpcMessage, IJsonRpcMessageWithId +public partial class JsonRpcRequest : JsonRpcMessage, IJsonRpcMessageWithId { /// /// The result of an attempt to match request arguments with a candidate method's parameters. @@ -47,6 +52,7 @@ public enum ArgumentMatchResult /// [DataMember(Name = "method", Order = 2, IsRequired = true)] [STJ.JsonPropertyName("method"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PropertyShape(Name = "method", Order = 2)] public string? Method { get; set; } /// @@ -61,6 +67,7 @@ public enum ArgumentMatchResult /// [DataMember(Name = "params", Order = 3, IsRequired = false, EmitDefaultValue = false)] [STJ.JsonPropertyName("params"), STJ.JsonPropertyOrder(3), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "params", Order = 3)] public object? Arguments { get; set; } /// @@ -70,6 +77,7 @@ public enum ArgumentMatchResult [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -81,6 +89,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = false, EmitDefaultValue = false)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingDefault)] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// @@ -88,6 +97,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public bool IsResponseExpected => !this.RequestId.IsEmpty; /// @@ -95,6 +105,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public bool IsNotification => this.RequestId.IsEmpty; /// @@ -102,6 +113,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public virtual int ArgumentCount => this.NamedArguments?.Count ?? this.ArgumentsList?.Count ?? 0; /// @@ -109,6 +121,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyDictionary? NamedArguments { get => this.Arguments as IReadOnlyDictionary; @@ -127,6 +140,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyDictionary? NamedArgumentDeclaredTypes { get; set; } /// @@ -134,6 +148,7 @@ public object? Id /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] [Obsolete("Use " + nameof(ArgumentsList) + " instead.")] public object?[]? ArgumentsArray { @@ -146,6 +161,7 @@ public object?[]? ArgumentsArray /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyList? ArgumentsList { get => this.Arguments as IReadOnlyList; @@ -166,6 +182,7 @@ public IReadOnlyList? ArgumentsList /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public IReadOnlyList? ArgumentListDeclaredTypes { get; set; } /// @@ -173,6 +190,7 @@ public IReadOnlyList? ArgumentsList /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public virtual IEnumerable? ArgumentNames => this.NamedArguments?.Keys; /// @@ -180,6 +198,7 @@ public IReadOnlyList? ArgumentsList /// [DataMember(Name = "traceparent", EmitDefaultValue = false)] [STJ.JsonPropertyName("traceparent"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "traceparent")] public string? TraceParent { get; set; } /// @@ -187,6 +206,7 @@ public IReadOnlyList? ArgumentsList /// [DataMember(Name = "tracestate", EmitDefaultValue = false)] [STJ.JsonPropertyName("tracestate"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "tracestate")] public string? TraceState { get; set; } /// diff --git a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs index 83ac903c..f4efff65 100644 --- a/src/StreamJsonRpc/Protocol/JsonRpcResult.cs +++ b/src/StreamJsonRpc/Protocol/JsonRpcResult.cs @@ -4,6 +4,9 @@ using System.Diagnostics; using System.Runtime.Serialization; using System.Text.Json.Nodes; +using Nerdbank.MessagePack; +using PolyType; +using JsonNET = Newtonsoft.Json.Linq; using STJ = System.Text.Json.Serialization; namespace StreamJsonRpc.Protocol; @@ -12,14 +15,17 @@ namespace StreamJsonRpc.Protocol; /// Describes the result of a successful method invocation. /// [DataContract] +[GenerateShape] +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.JsonRpcResultConverter))] [DebuggerDisplay("{" + nameof(DebuggerDisplay) + ",nq}")] -public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId +public partial class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId { /// /// Gets or sets the value of the result of an invocation, if any. /// [DataMember(Name = "result", Order = 2, IsRequired = true, EmitDefaultValue = true)] [STJ.JsonPropertyName("result"), STJ.JsonPropertyOrder(2), STJ.JsonRequired] + [PropertyShape(Name = "result", Order = 2)] public object? Result { get; set; } /// @@ -30,6 +36,7 @@ public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId /// [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public Type? ResultDeclaredType { get; set; } /// @@ -39,6 +46,7 @@ public class JsonRpcResult : JsonRpcMessage, IJsonRpcMessageWithId [Obsolete("Use " + nameof(RequestId) + " instead.")] [IgnoreDataMember] [STJ.JsonIgnore] + [PropertyShape(Ignore = true)] public object? Id { get => this.RequestId.ObjectValue; @@ -50,6 +58,7 @@ public object? Id /// [DataMember(Name = "id", Order = 1, IsRequired = true)] [STJ.JsonPropertyName("id"), STJ.JsonPropertyOrder(1), STJ.JsonRequired] + [PropertyShape(Name = "id", Order = 1)] public RequestId RequestId { get; set; } /// diff --git a/src/StreamJsonRpc/Protocol/TraceParent.cs b/src/StreamJsonRpc/Protocol/TraceParent.cs index 3d44d15d..5089ae46 100644 --- a/src/StreamJsonRpc/Protocol/TraceParent.cs +++ b/src/StreamJsonRpc/Protocol/TraceParent.cs @@ -2,9 +2,11 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Diagnostics; +using Nerdbank.MessagePack; namespace StreamJsonRpc.Protocol; +[MessagePackConverter(typeof(NerdbankMessagePackFormatter.TraceParentConverter))] internal unsafe struct TraceParent { internal const int VersionByteCount = 1; diff --git a/src/StreamJsonRpc/Reflection/ExceptionSerializationHelpers.cs b/src/StreamJsonRpc/Reflection/ExceptionSerializationHelpers.cs index 9f11b586..f206645c 100644 --- a/src/StreamJsonRpc/Reflection/ExceptionSerializationHelpers.cs +++ b/src/StreamJsonRpc/Reflection/ExceptionSerializationHelpers.cs @@ -23,11 +23,13 @@ internal static class ExceptionSerializationHelpers private static readonly Type[] DeserializingConstructorParameterTypes = new Type[] { typeof(SerializationInfo), typeof(StreamingContext) }; + [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.NonPublicConstructors)] + internal delegate Type? ExceptionTypeLoader(string typeName, string? assemblyName); + private static StreamingContext Context => new StreamingContext(StreamingContextStates.Remoting); - [RequiresUnreferencedCode(RuntimeReasons.LoadType)] - internal static T Deserialize(JsonRpc jsonRpc, SerializationInfo info, TraceSource? traceSource) - where T : Exception + internal static T Deserialize(JsonRpc jsonRpc, SerializationInfo info, ExceptionTypeLoader typeLoader, TraceSource? traceSource) + ////where T : Exception { if (!TryGetValue(info, "ClassName", out string? runtimeTypeName) || runtimeTypeName is null) { @@ -35,7 +37,7 @@ internal static T Deserialize(JsonRpc jsonRpc, SerializationInfo info, TraceS } TryGetValue(info, AssemblyNameKeyName, out string? runtimeAssemblyName); - Type? runtimeType = jsonRpc.LoadType(runtimeTypeName, runtimeAssemblyName); + Type? runtimeType = typeLoader(runtimeTypeName, runtimeAssemblyName); if (runtimeType is null) { if (traceSource?.Switch.ShouldTrace(TraceEventType.Warning) ?? false) @@ -75,7 +77,7 @@ internal static T Deserialize(JsonRpc jsonRpc, SerializationInfo info, TraceS traceSource?.TraceEvent(TraceEventType.Warning, (int)JsonRpc.TraceEvents.ExceptionNotDeserializable, errorMessage); - runtimeType = runtimeType.BaseType; + runtimeType = runtimeType.BaseType is { FullName: not null } ? typeLoader(runtimeType.BaseType.FullName, runtimeType.BaseType.Assembly.FullName) : null; } if (ctor is null) @@ -83,7 +85,7 @@ internal static T Deserialize(JsonRpc jsonRpc, SerializationInfo info, TraceS throw new NotSupportedException($"{originalRuntimeType.FullName} is not a supported exception type to deserialize and no adequate substitute could be found."); } - return (T)ctor.Invoke(new object?[] { info, Context }); + return (T)ctor.Invoke([info, Context]); } internal static void Serialize(Exception exception, SerializationInfo info) diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs index 24097f4a..bbb1f8a2 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterEnumerableTracker.cs @@ -9,9 +9,12 @@ using System.Threading.Tasks.Dataflow; using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; +using PolyType; using StreamJsonRpc.Protocol; using STJ = System.Text.Json.Serialization; +[assembly: TypeShapeExtension(typeof(IAsyncEnumerable<>), AssociatedTypes = [typeof(StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults<>)])] + namespace StreamJsonRpc.Reflection; /// @@ -219,13 +222,14 @@ private void CleanUpResources(RequestId outboundRequestId) /// The type of element in the enumeration. [DataContract] [EditorBrowsable(EditorBrowsableState.Never)] - public sealed class EnumeratorResults + public class EnumeratorResults { /// /// Gets the slice of values in this segment. /// [DataMember(Name = ValuesPropertyName, Order = 0)] [STJ.JsonPropertyName(ValuesPropertyName), STJ.JsonPropertyOrder(0)] + [PropertyShape(Name = ValuesPropertyName)] public IReadOnlyList? Values { get; init; } /// @@ -233,6 +237,7 @@ public sealed class EnumeratorResults /// [DataMember(Name = FinishedPropertyName, Order = 1)] [STJ.JsonPropertyName(FinishedPropertyName), STJ.JsonPropertyOrder(1)] + [PropertyShape(Name = FinishedPropertyName)] public bool Finished { get; init; } } diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterProgressTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterProgressTracker.cs index bd7c48cc..8fb99467 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterProgressTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterProgressTracker.cs @@ -2,9 +2,13 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System.Collections.Immutable; +using System.ComponentModel; using System.Diagnostics.CodeAnalysis; using System.Reflection; using Microsoft.VisualStudio.Threading; +using PolyType; + +[assembly: TypeShapeExtension(typeof(IProgress<>), AssociatedTypes = [typeof(StreamJsonRpc.Reflection.MessageFormatterProgressTracker.ProgressProxy<>)], Requirements = TypeShapeRequirements.Constructor)] namespace StreamJsonRpc.Reflection; @@ -167,7 +171,7 @@ public bool TryGetProgressObject(long progressId, [NotNullWhen(true)] out Progre /// /// The type of the value to be reported by . - public IProgress CreateProgress(JsonRpc rpc, object token, bool clientRequiresNamedArguments) => new JsonProgress(rpc, token, clientRequiresNamedArguments); + public IProgress CreateProgress(JsonRpc rpc, object token, bool clientRequiresNamedArguments) => new ProgressProxy(rpc, token, clientRequiresNamedArguments); /// /// @@ -192,8 +196,19 @@ public object CreateProgress(JsonRpc rpc, object token, Type valueType, bool cli Requires.NotNull(token, nameof(token)); Requires.NotNull(valueType, nameof(valueType)); - Type progressType = typeof(JsonProgress<>).MakeGenericType(valueType.GenericTypeArguments[0]); - return Activator.CreateInstance(progressType, new object[] { rpc, token, clientRequiresNamedArguments })!; + Type progressType = typeof(ProgressProxy<>).MakeGenericType(valueType.GenericTypeArguments[0]); + return Activator.CreateInstance(progressType, [rpc, token, clientRequiresNamedArguments])!; + } + + /// + /// + /// + /// + /// The closed generic type of . + /// A new instance of . + internal static object CreateJsonProgress(JsonRpc rpc, object token, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type progressProxyType, bool clientRequiresNamedArguments) + { + return Activator.CreateInstance(progressProxyType, [rpc, token, clientRequiresNamedArguments]) ?? throw Assumes.Fail("Activator.CreateInstance failed."); } private void CleanUpResources(RequestId requestId) @@ -278,19 +293,21 @@ public void InvokeReport(object? typedValue) /// /// Class that implements and sends notification when reporting. /// - private class JsonProgress : IProgress + /// The type of the value to be reported by . + [EditorBrowsable(EditorBrowsableState.Never)] + public class ProgressProxy : IProgress { private readonly JsonRpc rpc; private readonly object token; private readonly bool useNamedArguments; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The instance used to send the notification. /// The progress token used to obtain the instance from . /// to use named arguments; to use positional arguments. - public JsonProgress(JsonRpc rpc, object token, bool useNamedArguments) + public ProgressProxy(JsonRpc rpc, object token, bool useNamedArguments) { this.rpc = rpc ?? throw new ArgumentNullException(nameof(rpc)); this.token = token ?? throw new ArgumentNullException(nameof(token)); diff --git a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs index 58162f94..3d701b4f 100644 --- a/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs +++ b/src/StreamJsonRpc/Reflection/MessageFormatterRpcMarshaledContextTracker.cs @@ -9,6 +9,7 @@ using System.Reflection; using System.Runtime.Serialization; using Microsoft.VisualStudio.Threading; +using PolyType; using static System.FormattableString; using STJ = System.Text.Json.Serialization; @@ -17,7 +18,7 @@ namespace StreamJsonRpc.Reflection; /// /// Tracks objects that get marshaled using the general marshaling protocol. /// -internal class MessageFormatterRpcMarshaledContextTracker +internal partial class MessageFormatterRpcMarshaledContextTracker { private static readonly IReadOnlyCollection<(Type ImplicitlyMarshaledType, JsonRpcProxyOptions ProxyOptions, JsonRpcTargetOptions TargetOptions, RpcMarshalableAttribute Attribute)> ImplicitlyMarshaledTypes = new (Type, JsonRpcProxyOptions, JsonRpcTargetOptions, RpcMarshalableAttribute)[] { @@ -462,12 +463,16 @@ private void CleanUpOutboundResources(RequestId requestId, bool successful) } } + /// + /// A token that represents a marshaled object. + /// [DataContract] - internal struct MarshalToken + [GenerateShape] + internal partial struct MarshalToken { [MessagePack.SerializationConstructor] #pragma warning disable SA1313 // Parameter names should begin with lower-case letter - public MarshalToken(int __jsonrpc_marshaled, long handle, string? lifetime, int[]? optionalInterfaces) + public MarshalToken(int __jsonrpc_marshaled, long handle, string? lifetime = null, int[]? optionalInterfaces = null) #pragma warning restore SA1313 // Parameter names should begin with lower-case letter { this.Marshaled = __jsonrpc_marshaled; @@ -478,18 +483,22 @@ public MarshalToken(int __jsonrpc_marshaled, long handle, string? lifetime, int[ [DataMember(Name = "__jsonrpc_marshaled", IsRequired = true)] [STJ.JsonPropertyName("__jsonrpc_marshaled"), STJ.JsonRequired] + [PropertyShape(Name = "__jsonrpc_marshaled")] public int Marshaled { get; set; } [DataMember(Name = "handle", IsRequired = true)] [STJ.JsonPropertyName("handle"), STJ.JsonRequired] + [PropertyShape(Name = "handle")] public long Handle { get; set; } [DataMember(Name = "lifetime", EmitDefaultValue = false)] [STJ.JsonPropertyName("lifetime"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "lifetime")] public string? Lifetime { get; set; } [DataMember(Name = "optionalInterfaces", EmitDefaultValue = false)] [STJ.JsonPropertyName("optionalInterfaces"), STJ.JsonIgnore(Condition = STJ.JsonIgnoreCondition.WhenWritingNull)] + [PropertyShape(Name = "optionalInterfaces")] public int[]? OptionalInterfacesCodes { get; set; } } diff --git a/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs b/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs index 07e5f29d..e2158249 100644 --- a/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs +++ b/src/StreamJsonRpc/Reflection/RpcMarshalableAttribute.cs @@ -7,7 +7,7 @@ namespace StreamJsonRpc; /// Designates an interface that is used in an RPC contract to marshal the object so the receiver can invoke remote methods on it instead of serializing the object to send its data to the remote end. /// /// -/// Learn more about marshable interfaces. +/// Learn more about marshalable interfaces. /// [AttributeUsage(AttributeTargets.Interface, AllowMultiple = false, Inherited = false)] public class RpcMarshalableAttribute : Attribute diff --git a/src/StreamJsonRpc/Resources.resx b/src/StreamJsonRpc/Resources.resx index e87e9b73..15819110 100644 --- a/src/StreamJsonRpc/Resources.resx +++ b/src/StreamJsonRpc/Resources.resx @@ -167,10 +167,6 @@ Error writing JSON RPC Result: {0}: {1} {0} is the exception type, {1} is the exception message. - - Failure deserializing incoming JSON RPC '{0}': {1} - {0} is the JSON RPC, {1} is the exception message. - Deserializing JSON-RPC argument with name "{0}" and position {1} to type "{2}" failed: {3} {0} is a parameter name, {1} is an integer, {2} is a CLR type name and {3} is an exception object. @@ -337,6 +333,9 @@ The length of this list must equal the length of the arguments list. + + Type must have FullName. + Unable to find method '{0}/{1}' on {2} for the following reasons: {3} {0} is the method name, {1} is arity, {2} is target class full name, {3} is the error list. diff --git a/src/StreamJsonRpc/SerializationContextExtensions.cs b/src/StreamJsonRpc/SerializationContextExtensions.cs new file mode 100644 index 00000000..b26eb905 --- /dev/null +++ b/src/StreamJsonRpc/SerializationContextExtensions.cs @@ -0,0 +1,11 @@ +using Nerdbank.MessagePack; + +namespace StreamJsonRpc; + +internal static class SerializationContextExtensions +{ + internal static object FormatterKey { get; } = new(); + + internal static NerdbankMessagePackFormatter GetFormatter(this in SerializationContext context) + => ((NerdbankMessagePackFormatter?)context[FormatterKey]) ?? throw new InvalidOperationException("This converter may only be used within the context of its owning formatter."); +} diff --git a/src/StreamJsonRpc/StreamJsonRpc.csproj b/src/StreamJsonRpc/StreamJsonRpc.csproj index cd1e33d1..657564b6 100644 --- a/src/StreamJsonRpc/StreamJsonRpc.csproj +++ b/src/StreamJsonRpc/StreamJsonRpc.csproj @@ -14,11 +14,15 @@ + + + + diff --git a/src/StreamJsonRpc/SystemTextJsonFormatter.cs b/src/StreamJsonRpc/SystemTextJsonFormatter.cs index 8a4a1d1e..df476051 100644 --- a/src/StreamJsonRpc/SystemTextJsonFormatter.cs +++ b/src/StreamJsonRpc/SystemTextJsonFormatter.cs @@ -1072,7 +1072,7 @@ internal ExceptionConverter(SystemTextJsonFormatter formatter) info.AddSafeValue(property.Key, property.Value); } - return ExceptionSerializationHelpers.Deserialize(this.formatter.JsonRpc, info, this.formatter.JsonRpc?.TraceSource); + return ExceptionSerializationHelpers.Deserialize(this.formatter.JsonRpc, info, this.formatter.JsonRpc.LoadType, this.formatter.JsonRpc?.TraceSource); } finally { diff --git a/src/StreamJsonRpc/net8.0/PublicAPI.Unshipped.txt b/src/StreamJsonRpc/net8.0/PublicAPI.Unshipped.txt index 3df87170..82b4cca3 100644 --- a/src/StreamJsonRpc/net8.0/PublicAPI.Unshipped.txt +++ b/src/StreamJsonRpc/net8.0/PublicAPI.Unshipped.txt @@ -1,7 +1,28 @@ +override StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.GetJsonSchema(Nerdbank.MessagePack.JsonSchemaContext! context, PolyType.Abstractions.ITypeShape! typeShape) -> System.Text.Json.Nodes.JsonObject? +override StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.Read(ref Nerdbank.MessagePack.MessagePackReader reader, Nerdbank.MessagePack.SerializationContext context) -> System.Collections.Generic.IAsyncEnumerable? +override StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.Write(ref Nerdbank.MessagePack.MessagePackWriter writer, in System.Collections.Generic.IAsyncEnumerable? value, Nerdbank.MessagePack.SerializationContext context) -> void +static readonly StreamJsonRpc.NerdbankMessagePackFormatter.DefaultSerializer -> Nerdbank.MessagePack.MessagePackSerializer! +StreamJsonRpc.JsonRpc.AddLoadableType(System.Type! type) -> void StreamJsonRpc.JsonRpc.Attach(System.ReadOnlySpan interfaceTypes, StreamJsonRpc.JsonRpcProxyOptions? options) -> object! +StreamJsonRpc.JsonRpc.TraceEvents.IFormatterConverterDeserializationFailure = 22 -> StreamJsonRpc.JsonRpc.TraceEvents +StreamJsonRpc.NerdbankMessagePackFormatter +StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter +StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.AsyncEnumerableConverter() -> void +StreamJsonRpc.NerdbankMessagePackFormatter.Deserialize(System.Buffers.ReadOnlySequence contentBuffer) -> StreamJsonRpc.Protocol.JsonRpcMessage! +StreamJsonRpc.NerdbankMessagePackFormatter.GetJsonText(StreamJsonRpc.Protocol.JsonRpcMessage! message) -> object! +StreamJsonRpc.NerdbankMessagePackFormatter.NerdbankMessagePackFormatter() -> void +StreamJsonRpc.NerdbankMessagePackFormatter.Serialize(System.Buffers.IBufferWriter! bufferWriter, StreamJsonRpc.Protocol.JsonRpcMessage! message) -> void +StreamJsonRpc.NerdbankMessagePackFormatter.TypeShapeProvider.get -> PolyType.ITypeShapeProvider! +StreamJsonRpc.NerdbankMessagePackFormatter.TypeShapeProvider.init -> void +StreamJsonRpc.NerdbankMessagePackFormatter.UserDataSerializer.get -> Nerdbank.MessagePack.MessagePackSerializer! +StreamJsonRpc.NerdbankMessagePackFormatter.UserDataSerializer.init -> void StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.EnumeratorResults() -> void StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Finished.get -> bool StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Finished.init -> void StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Values.get -> System.Collections.Generic.IReadOnlyList? -StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Values.init -> void \ No newline at end of file +StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Values.init -> void +StreamJsonRpc.Reflection.MessageFormatterProgressTracker.ProgressProxy +StreamJsonRpc.Reflection.MessageFormatterProgressTracker.ProgressProxy.ProgressProxy(StreamJsonRpc.JsonRpc! rpc, object! token, bool useNamedArguments) -> void +StreamJsonRpc.Reflection.MessageFormatterProgressTracker.ProgressProxy.Report(T value) -> void +virtual StreamJsonRpc.JsonRpc.LoadTypeTrimSafe(string! typeFullName, string? assemblyName) -> System.Type? diff --git a/src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt b/src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt index 3df87170..82b4cca3 100644 --- a/src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/StreamJsonRpc/netstandard2.0/PublicAPI.Unshipped.txt @@ -1,7 +1,28 @@ +override StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.GetJsonSchema(Nerdbank.MessagePack.JsonSchemaContext! context, PolyType.Abstractions.ITypeShape! typeShape) -> System.Text.Json.Nodes.JsonObject? +override StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.Read(ref Nerdbank.MessagePack.MessagePackReader reader, Nerdbank.MessagePack.SerializationContext context) -> System.Collections.Generic.IAsyncEnumerable? +override StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.Write(ref Nerdbank.MessagePack.MessagePackWriter writer, in System.Collections.Generic.IAsyncEnumerable? value, Nerdbank.MessagePack.SerializationContext context) -> void +static readonly StreamJsonRpc.NerdbankMessagePackFormatter.DefaultSerializer -> Nerdbank.MessagePack.MessagePackSerializer! +StreamJsonRpc.JsonRpc.AddLoadableType(System.Type! type) -> void StreamJsonRpc.JsonRpc.Attach(System.ReadOnlySpan interfaceTypes, StreamJsonRpc.JsonRpcProxyOptions? options) -> object! +StreamJsonRpc.JsonRpc.TraceEvents.IFormatterConverterDeserializationFailure = 22 -> StreamJsonRpc.JsonRpc.TraceEvents +StreamJsonRpc.NerdbankMessagePackFormatter +StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter +StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.AsyncEnumerableConverter() -> void +StreamJsonRpc.NerdbankMessagePackFormatter.Deserialize(System.Buffers.ReadOnlySequence contentBuffer) -> StreamJsonRpc.Protocol.JsonRpcMessage! +StreamJsonRpc.NerdbankMessagePackFormatter.GetJsonText(StreamJsonRpc.Protocol.JsonRpcMessage! message) -> object! +StreamJsonRpc.NerdbankMessagePackFormatter.NerdbankMessagePackFormatter() -> void +StreamJsonRpc.NerdbankMessagePackFormatter.Serialize(System.Buffers.IBufferWriter! bufferWriter, StreamJsonRpc.Protocol.JsonRpcMessage! message) -> void +StreamJsonRpc.NerdbankMessagePackFormatter.TypeShapeProvider.get -> PolyType.ITypeShapeProvider! +StreamJsonRpc.NerdbankMessagePackFormatter.TypeShapeProvider.init -> void +StreamJsonRpc.NerdbankMessagePackFormatter.UserDataSerializer.get -> Nerdbank.MessagePack.MessagePackSerializer! +StreamJsonRpc.NerdbankMessagePackFormatter.UserDataSerializer.init -> void StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.EnumeratorResults() -> void StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Finished.get -> bool StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Finished.init -> void StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Values.get -> System.Collections.Generic.IReadOnlyList? -StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Values.init -> void \ No newline at end of file +StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Values.init -> void +StreamJsonRpc.Reflection.MessageFormatterProgressTracker.ProgressProxy +StreamJsonRpc.Reflection.MessageFormatterProgressTracker.ProgressProxy.ProgressProxy(StreamJsonRpc.JsonRpc! rpc, object! token, bool useNamedArguments) -> void +StreamJsonRpc.Reflection.MessageFormatterProgressTracker.ProgressProxy.Report(T value) -> void +virtual StreamJsonRpc.JsonRpc.LoadTypeTrimSafe(string! typeFullName, string? assemblyName) -> System.Type? diff --git a/src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt b/src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt index 3df87170..82b4cca3 100644 --- a/src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt +++ b/src/StreamJsonRpc/netstandard2.1/PublicAPI.Unshipped.txt @@ -1,7 +1,28 @@ +override StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.GetJsonSchema(Nerdbank.MessagePack.JsonSchemaContext! context, PolyType.Abstractions.ITypeShape! typeShape) -> System.Text.Json.Nodes.JsonObject? +override StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.Read(ref Nerdbank.MessagePack.MessagePackReader reader, Nerdbank.MessagePack.SerializationContext context) -> System.Collections.Generic.IAsyncEnumerable? +override StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.Write(ref Nerdbank.MessagePack.MessagePackWriter writer, in System.Collections.Generic.IAsyncEnumerable? value, Nerdbank.MessagePack.SerializationContext context) -> void +static readonly StreamJsonRpc.NerdbankMessagePackFormatter.DefaultSerializer -> Nerdbank.MessagePack.MessagePackSerializer! +StreamJsonRpc.JsonRpc.AddLoadableType(System.Type! type) -> void StreamJsonRpc.JsonRpc.Attach(System.ReadOnlySpan interfaceTypes, StreamJsonRpc.JsonRpcProxyOptions? options) -> object! +StreamJsonRpc.JsonRpc.TraceEvents.IFormatterConverterDeserializationFailure = 22 -> StreamJsonRpc.JsonRpc.TraceEvents +StreamJsonRpc.NerdbankMessagePackFormatter +StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter +StreamJsonRpc.NerdbankMessagePackFormatter.AsyncEnumerableConverter.AsyncEnumerableConverter() -> void +StreamJsonRpc.NerdbankMessagePackFormatter.Deserialize(System.Buffers.ReadOnlySequence contentBuffer) -> StreamJsonRpc.Protocol.JsonRpcMessage! +StreamJsonRpc.NerdbankMessagePackFormatter.GetJsonText(StreamJsonRpc.Protocol.JsonRpcMessage! message) -> object! +StreamJsonRpc.NerdbankMessagePackFormatter.NerdbankMessagePackFormatter() -> void +StreamJsonRpc.NerdbankMessagePackFormatter.Serialize(System.Buffers.IBufferWriter! bufferWriter, StreamJsonRpc.Protocol.JsonRpcMessage! message) -> void +StreamJsonRpc.NerdbankMessagePackFormatter.TypeShapeProvider.get -> PolyType.ITypeShapeProvider! +StreamJsonRpc.NerdbankMessagePackFormatter.TypeShapeProvider.init -> void +StreamJsonRpc.NerdbankMessagePackFormatter.UserDataSerializer.get -> Nerdbank.MessagePack.MessagePackSerializer! +StreamJsonRpc.NerdbankMessagePackFormatter.UserDataSerializer.init -> void StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.EnumeratorResults() -> void StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Finished.get -> bool StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Finished.init -> void StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Values.get -> System.Collections.Generic.IReadOnlyList? -StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Values.init -> void \ No newline at end of file +StreamJsonRpc.Reflection.MessageFormatterEnumerableTracker.EnumeratorResults.Values.init -> void +StreamJsonRpc.Reflection.MessageFormatterProgressTracker.ProgressProxy +StreamJsonRpc.Reflection.MessageFormatterProgressTracker.ProgressProxy.ProgressProxy(StreamJsonRpc.JsonRpc! rpc, object! token, bool useNamedArguments) -> void +StreamJsonRpc.Reflection.MessageFormatterProgressTracker.ProgressProxy.Report(T value) -> void +virtual StreamJsonRpc.JsonRpc.LoadTypeTrimSafe(string! typeFullName, string? assemblyName) -> System.Type? diff --git a/test/Benchmarks/InvokeBenchmarks.cs b/test/Benchmarks/InvokeBenchmarks.cs index 78fc982c..48e0b1c6 100644 --- a/test/Benchmarks/InvokeBenchmarks.cs +++ b/test/Benchmarks/InvokeBenchmarks.cs @@ -5,17 +5,19 @@ using BenchmarkDotNet.Attributes; using Microsoft; using Nerdbank.Streams; +using PolyType; using StreamJsonRpc; namespace Benchmarks; [MemoryDiagnoser] -public class InvokeBenchmarks +[GenerateShapeFor] +public partial class InvokeBenchmarks { private JsonRpc clientRpc = null!; private JsonRpc serverRpc = null!; - [Params("JSON", "MessagePack")] + [Params("JSON", "MessagePack", "NerdbankMessagePack")] public string Formatter { get; set; } = null!; [GlobalSetup] @@ -35,6 +37,7 @@ IJsonRpcMessageHandler CreateHandler(IDuplexPipe pipe) { "JSON" => new HeaderDelimitedMessageHandler(pipe, new JsonMessageFormatter()), "MessagePack" => new LengthHeaderMessageHandler(pipe, new MessagePackFormatter()), + "NerdbankMessagePack" => new LengthHeaderMessageHandler(pipe, new NerdbankMessagePackFormatter() { TypeShapeProvider = ShapeProvider }), _ => throw Assumes.NotReachable(), }; } diff --git a/test/Benchmarks/NotifyBenchmarks.cs b/test/Benchmarks/NotifyBenchmarks.cs index 92fe6ef4..74a9c529 100644 --- a/test/Benchmarks/NotifyBenchmarks.cs +++ b/test/Benchmarks/NotifyBenchmarks.cs @@ -3,16 +3,18 @@ using BenchmarkDotNet.Attributes; using Microsoft; +using PolyType; using StreamJsonRpc; namespace Benchmarks; [MemoryDiagnoser] -public class NotifyBenchmarks +[GenerateShapeFor] +public partial class NotifyBenchmarks { private JsonRpc clientRpc = null!; - [Params("JSON", "MessagePack")] + [Params("JSON", "MessagePack", "NerdbankMessagePack")] public string Formatter { get; set; } = null!; [GlobalSetup] @@ -26,6 +28,7 @@ IJsonRpcMessageHandler CreateHandler(Stream pipe) { "JSON" => new HeaderDelimitedMessageHandler(pipe, new JsonMessageFormatter()), "MessagePack" => new LengthHeaderMessageHandler(pipe, pipe, new MessagePackFormatter()), + "NerdbankMessagePack" => new LengthHeaderMessageHandler(pipe, pipe, new NerdbankMessagePackFormatter() { TypeShapeProvider = ShapeProvider }), _ => throw Assumes.NotReachable(), }; } diff --git a/test/Benchmarks/ShortLivedConnectionBenchmarks.cs b/test/Benchmarks/ShortLivedConnectionBenchmarks.cs index fe58eab7..398bceaa 100644 --- a/test/Benchmarks/ShortLivedConnectionBenchmarks.cs +++ b/test/Benchmarks/ShortLivedConnectionBenchmarks.cs @@ -5,16 +5,18 @@ using BenchmarkDotNet.Attributes; using Microsoft; using Nerdbank.Streams; +using PolyType; using StreamJsonRpc; namespace Benchmarks; [MemoryDiagnoser] -public class ShortLivedConnectionBenchmarks +[GenerateShapeFor] +public partial class ShortLivedConnectionBenchmarks { private const int Iterations = 1000; - [Params("JSON", "MessagePack")] + [Params("JSON", "MessagePack", "NerdbankMessagePack")] public string Formatter { get; set; } = null!; [Benchmark(OperationsPerInvoke = Iterations)] @@ -39,6 +41,7 @@ IJsonRpcMessageHandler CreateHandler(IDuplexPipe pipe) { "JSON" => new HeaderDelimitedMessageHandler(pipe, new JsonMessageFormatter()), "MessagePack" => new LengthHeaderMessageHandler(pipe, new MessagePackFormatter()), + "NerdbankMessagePack" => new LengthHeaderMessageHandler(pipe, new NerdbankMessagePackFormatter() { TypeShapeProvider = ShapeProvider }), _ => throw Assumes.NotReachable(), }; } diff --git a/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs b/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs index 4828ab02..dbeef134 100644 --- a/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs +++ b/test/StreamJsonRpc.Tests/AssemblyLoadTests.cs @@ -5,8 +5,9 @@ using System.Diagnostics; using System.Runtime.CompilerServices; using Nerdbank.Streams; +using PolyType; -public class AssemblyLoadTests : TestBase +public partial class AssemblyLoadTests : TestBase { public AssemblyLoadTests(ITestOutputHelper logger) : base(logger) @@ -55,6 +56,27 @@ public void MessagePackDoesNotLoadNewtonsoftJsonUnnecessarily() } } + [Fact] + public void NerdbankMessagePackDoesNotLoadNewtonsoftJsonUnnecessarily() + { + AppDomain testDomain = CreateTestAppDomain(); + try + { + var driver = (AppDomainTestDriver)testDomain.CreateInstanceAndUnwrap(typeof(AppDomainTestDriver).Assembly.FullName, typeof(AppDomainTestDriver).FullName); + + this.PrintLoadedAssemblies(driver); + + driver.CreateNerdbankMessagePackConnection(); + + this.PrintLoadedAssemblies(driver); + driver.ThrowIfAssembliesLoaded("Newtonsoft.Json"); + } + finally + { + AppDomain.Unload(testDomain); + } + } + [Fact] public void MockFormatterDoesNotLoadJsonOrMessagePackUnnecessarily() { @@ -87,7 +109,7 @@ private IEnumerable PrintLoadedAssemblies(AppDomainTestDriver driver) } #pragma warning disable CA1812 // Avoid uninstantiated internal classes - private class AppDomainTestDriver : MarshalByRefObject + private partial class AppDomainTestDriver : MarshalByRefObject #pragma warning restore CA1812 // Avoid uninstantiated internal classes { #pragma warning disable CA1822 // Mark members as static -- all members must be instance for marshalability @@ -134,6 +156,11 @@ internal void CreateMessagePackConnection() var jsonRpc = new JsonRpc(new LengthHeaderMessageHandler(FullDuplexStream.CreatePipePair().Item1, new MessagePackFormatter())); } + internal void CreateNerdbankMessagePackConnection() + { + var jsonRpc = new JsonRpc(new LengthHeaderMessageHandler(FullDuplexStream.CreatePipePair().Item1, new NerdbankMessagePackFormatter() { TypeShapeProvider = Witness.ShapeProvider })); + } + #pragma warning restore CA1822 // Mark members as static private class MockFormatter : IJsonRpcMessageFormatter @@ -153,6 +180,9 @@ public void Serialize(System.Buffers.IBufferWriter bufferWriter, JsonRpcMe throw new NotImplementedException(); } } + + [GenerateShapeFor] + private partial class Witness; } } diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs new file mode 100644 index 00000000..dc69ec38 --- /dev/null +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableNerdbankMessagePackTests.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using PolyType; + +public partial class AsyncEnumerableNerdbankMessagePackTests : AsyncEnumerableTests +{ + public AsyncEnumerableNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + this.serverMessageFormatter = new NerdbankMessagePackFormatter() { TypeShapeProvider = Witness.ShapeProvider }; + this.clientMessageFormatter = new NerdbankMessagePackFormatter() { TypeShapeProvider = Witness.ShapeProvider }; + } + + [GenerateShapeFor>] + [GenerateShapeFor>] + [GenerateShapeFor>] + [GenerateShapeFor>] + [GenerateShapeFor] + private partial class Witness; +} diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs index eb71a095..bd09cd19 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs @@ -11,6 +11,7 @@ using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; using Newtonsoft.Json; +using NBMP = Nerdbank.MessagePack; public abstract class AsyncEnumerableTests : TestBase, IAsyncLifetime { @@ -447,8 +448,10 @@ public async Task NotifyAsync_ThrowsIfAsyncEnumerableSent() // But for a notification there's no guarantee the server handles the message and no way to get an error back, // so it simply should not be allowed since the risk of memory leak is too high. var numbers = new int[] { 1, 2, 3 }.AsAsyncEnumerable(); - await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.PassInNumbersAsync), new object?[] { numbers })); - await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.PassInNumbersAsync), new object?[] { new { e = numbers } })); + Exception ex = await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.PassInNumbersAsync), new object?[] { numbers })); + this.Logger.WriteLine(ex.ToString()); + ex = await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.PassInNumbersAsync), new object?[] { new CompoundEnumerableResult { Enumeration = numbers } })); + this.Logger.WriteLine(ex.ToString()); } [Fact] @@ -618,6 +621,16 @@ private async Task ReturnEnumerable_AutomaticallyReleasedOnErrorF return weakReferenceToSource; } + [DataContract] + protected internal class CompoundEnumerableResult + { + [DataMember] + public string? Message { get; set; } + + [DataMember] + public IAsyncEnumerable? Enumeration { get; set; } + } + protected class Server : IServer { /// @@ -795,18 +808,9 @@ protected class Client : IClient public Task DoSomethingAsync(CancellationToken cancellationToken) => Task.CompletedTask; } - [DataContract] - protected class CompoundEnumerableResult - { - [DataMember] - public string? Message { get; set; } - - [DataMember] - public IAsyncEnumerable? Enumeration { get; set; } - } - [JsonConverter(typeof(ThrowingJsonConverter))] [MessagePackFormatter(typeof(ThrowingMessagePackFormatter))] + [NBMP.MessagePackConverter(typeof(ThrowingMessagePackNerdbankConverter))] protected class UnserializableType { } @@ -836,4 +840,17 @@ public void Serialize(ref MessagePackWriter writer, T value, MessagePackSerializ throw new Exception(); } } + + protected class ThrowingMessagePackNerdbankConverter : NBMP.MessagePackConverter + { + public override T? Read(ref NBMP.MessagePackReader reader, NBMP.SerializationContext context) + { + throw new Exception(); + } + + public override void Write(ref NBMP.MessagePackWriter writer, in T? value, NBMP.SerializationContext context) + { + throw new Exception(); + } + } } diff --git a/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs new file mode 100644 index 00000000..9f96c4b4 --- /dev/null +++ b/test/StreamJsonRpc.Tests/DisposableProxyNerdbankMessagePackTests.cs @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; +using PolyType; + +public partial class DisposableProxyNerdbankMessagePackTests(ITestOutputHelper logger) : DisposableProxyTests(logger) +{ + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter() { TypeShapeProvider = Witness.ShapeProvider }; + + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + private partial class Witness; +} diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs new file mode 100644 index 00000000..a97c3ec6 --- /dev/null +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingNerdbankMessagePackTests.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.IO.Pipelines; +using PolyType; + +public partial class DuplexPipeMarshalingNerdbankMessagePackTests : DuplexPipeMarshalingTests +{ + public DuplexPipeMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override void InitializeFormattersAndHandlers() + { + NerdbankMessagePackFormatter serverFormatter = new() + { + MultiplexingStream = this.serverMx, + TypeShapeProvider = Witness.ShapeProvider, + }; + + NerdbankMessagePackFormatter clientFormatter = new() + { + MultiplexingStream = this.clientMx, + TypeShapeProvider = Witness.ShapeProvider, + }; + + this.serverMessageFormatter = serverFormatter; + this.clientMessageFormatter = clientFormatter; + } + + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + private partial class Witness; +} diff --git a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs index 7e30c229..5b9fc4f3 100644 --- a/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs +++ b/test/StreamJsonRpc.Tests/DuplexPipeMarshalingTests.cs @@ -8,9 +8,10 @@ using System.Text; using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; +using PolyType; using STJ = System.Text.Json.Serialization; -public abstract class DuplexPipeMarshalingTests : TestBase, IAsyncLifetime +public abstract partial class DuplexPipeMarshalingTests : TestBase, IAsyncLifetime { protected readonly Server server = new Server(); protected JsonRpc serverRpc; @@ -75,8 +76,8 @@ public async ValueTask InitializeAsync() this.serverRpc = new JsonRpc(serverHandler, this.server); this.clientRpc = new JsonRpc(clientHandler); - this.serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Information); - this.clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Information); + this.serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose); + this.clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose); this.serverRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); this.clientRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); @@ -338,7 +339,10 @@ public async Task PassStreamWithArgsAsSingleObject() MemoryStream ms = new(); ms.Write(new byte[] { 1, 2, 3 }, 0, 3); ms.Position = 0; - int bytesRead = await this.clientRpc.InvokeWithParameterObjectAsync(nameof(Server.AcceptStreamArgInFirstParam), new { innerStream = ms }, this.TimeoutToken); + int bytesRead = await this.clientRpc.InvokeWithParameterObjectAsync( + nameof(Server.AcceptStreamArgInFirstParam), + new Dictionary { ["innerStream"] = ms }, + this.TimeoutToken); Assert.Equal(ms.Length, bytesRead); } @@ -477,6 +481,22 @@ public async Task ClientCanSendTwoWayStreamToServer(bool serverUsesStream) streamPair.Item1.Dispose(); } + [Theory] + [CombinatorialData] + public async Task ClientCanSendTwoWayStreamToServer_WithExplicitTypes(bool serverUsesStream) + { + (Stream, Stream) streamPair = FullDuplexStream.CreatePair(); + Task twoWayCom = TwoWayTalkAsync(streamPair.Item1, writeOnOdd: true, this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync( + serverUsesStream ? nameof(Server.TwoWayStreamAsArg) : nameof(Server.TwoWayPipeAsArg), + [false, streamPair.Item2], + [typeof(bool), typeof(Stream)], + this.TimeoutToken); + await twoWayCom.WithCancellation(this.TimeoutToken); // rethrow any exceptions. + + streamPair.Item1.Dispose(); + } + [Fact] public async Task PipeRemainsOpenAfterSuccessfulServerResult() { @@ -555,7 +575,7 @@ public async Task NotifyWithPipe_IsRejectedAtClient() { (IDuplexPipe, IDuplexPipe) duplexPipes = FullDuplexStream.CreatePipePair(); var ex = await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.AcceptReadableStream), "fileName", duplexPipes.Item2)); - Assert.IsType(ex.InnerException); + Assert.IsType(ex.GetBaseException()); } /// @@ -743,7 +763,9 @@ public async Task OverloadedMethod(bool writeOnOdd, IDuplexPipe pipe, string mes public void OverloadedMethod(bool foo, int value, string[] values) => Assert.NotNull(values); } - protected class Server +#pragma warning disable SA1202 // Elements should be ordered by access + public class Server +#pragma warning restore SA1202 // Elements should be ordered by access { internal Task? ChatLaterTask { get; private set; } @@ -1058,7 +1080,10 @@ protected override void Dispose(bool disposing) } [DataContract] - protected class StreamContainingClass + [GenerateShape] +#pragma warning disable SA1202 // Elements should be ordered by access + public partial class StreamContainingClass +#pragma warning restore SA1202 // Elements should be ordered by access { [DataMember] private Stream innerStream; @@ -1069,6 +1094,7 @@ public StreamContainingClass(Stream innerStream) } [STJ.JsonPropertyName("innerStream")] + [PropertyShape(Name = "innerStream")] public Stream InnerStream => this.innerStream; } } diff --git a/test/StreamJsonRpc.Tests/JsonRpcMessagePackCSharpLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcMessagePackCSharpLengthTests.cs new file mode 100644 index 00000000..db8b6d30 --- /dev/null +++ b/test/StreamJsonRpc.Tests/JsonRpcMessagePackCSharpLengthTests.cs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using MessagePack; +using MessagePack.Formatters; +using MessagePack.Resolvers; + +public class JsonRpcMessagePackCSharpLengthTests(ITestOutputHelper logger) : JsonRpcMessagePackLengthTests(logger) +{ + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + protected override void InitializeFormattersAndHandlers( + Stream serverStream, + Stream clientStream, + out IJsonRpcMessageFormatter serverMessageFormatter, + out IJsonRpcMessageFormatter clientMessageFormatter, + out IJsonRpcMessageHandler serverMessageHandler, + out IJsonRpcMessageHandler clientMessageHandler, + bool controlledFlushingClient) + { + serverMessageFormatter = new MessagePackFormatter(); + clientMessageFormatter = new MessagePackFormatter(); + + var options = MessagePackFormatter.DefaultUserDataSerializationOptions + .WithResolver(CompositeResolver.Create( + new IMessagePackFormatter[] { new UnserializableTypeFormatter(), new TypeThrowsWhenDeserializedFormatter(), new CustomExtensionFormatter() }, + new IFormatterResolver[] { StandardResolverAllowPrivate.Instance })); + ((MessagePackFormatter)serverMessageFormatter).SetMessagePackSerializerOptions(options); + ((MessagePackFormatter)clientMessageFormatter).SetMessagePackSerializerOptions(options); + + serverMessageHandler = new LengthHeaderMessageHandler(serverStream, serverStream, serverMessageFormatter); + clientMessageHandler = controlledFlushingClient + ? new DelayedFlushingHandler(clientStream, clientMessageFormatter) + : new LengthHeaderMessageHandler(clientStream, clientStream, clientMessageFormatter); + } + + private class CustomExtensionFormatter : IMessagePackFormatter + { + public CustomExtensionType? Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + { + if (reader.TryReadNil()) + { + return null; + } + + if (reader.ReadExtensionFormat() is { Header: { TypeCode: 1, Length: 0 } }) + { + return new(); + } + + throw new Exception("Unexpected extension header."); + } + + public void Serialize(ref MessagePackWriter writer, CustomExtensionType? value, MessagePackSerializerOptions options) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + writer.WriteExtensionFormat(new ExtensionResult(1, default(Memory))); + } + } + } + + private class UnserializableTypeFormatter : IMessagePackFormatter + { + public CustomSerializedType Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + { + return new CustomSerializedType { Value = reader.ReadString() }; + } + + public void Serialize(ref MessagePackWriter writer, CustomSerializedType value, MessagePackSerializerOptions options) + { + writer.Write(value?.Value); + } + } + + private class TypeThrowsWhenDeserializedFormatter : IMessagePackFormatter + { + public TypeThrowsWhenDeserialized Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + { + throw CreateExceptionToBeThrownByDeserializer(); + } + + public void Serialize(ref MessagePackWriter writer, TypeThrowsWhenDeserialized value, MessagePackSerializerOptions options) + { + writer.WriteArrayHeader(0); + } + } +} diff --git a/test/StreamJsonRpc.Tests/JsonRpcMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcMessagePackLengthTests.cs index bcb2ea2f..255f2415 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcMessagePackLengthTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcMessagePackLengthTests.cs @@ -4,16 +4,11 @@ using System.Runtime.CompilerServices; using MessagePack; using MessagePack.Formatters; -using MessagePack.Resolvers; using Microsoft.VisualStudio.Threading; +using PolyType; -public class JsonRpcMessagePackLengthTests : JsonRpcTests +public abstract partial class JsonRpcMessagePackLengthTests(ITestOutputHelper logger) : JsonRpcTests(logger) { - public JsonRpcMessagePackLengthTests(ITestOutputHelper logger) - : base(logger) - { - } - internal interface IMessagePackServer { Task ReturnUnionTypeAsync(CancellationToken cancellationToken); @@ -29,8 +24,6 @@ internal interface IMessagePackServer Task IsExtensionArgNonNull(CustomExtensionType extensionValue); } - protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); - [Fact] public override async Task CanPassAndCallPrivateMethodsObjects() { @@ -47,7 +40,7 @@ public async Task ExceptionControllingErrorData() IDictionary? data = (IDictionary?)exception.ErrorData; Assert.NotNull(data); - object myCustomData = data["myCustomData"]; + object myCustomData = data["MyCustomData"]; string actual = (string)myCustomData; Assert.Equal("hi", actual); } @@ -384,105 +377,28 @@ public async Task VerboseLoggingDoesNotFailWhenArgsDoNotDeserializePrimitively(b Assert.True(await clientProxy.IsExtensionArgNonNull(new CustomExtensionType())); } - protected override void InitializeFormattersAndHandlers( - Stream serverStream, - Stream clientStream, - out IJsonRpcMessageFormatter serverMessageFormatter, - out IJsonRpcMessageFormatter clientMessageFormatter, - out IJsonRpcMessageHandler serverMessageHandler, - out IJsonRpcMessageHandler clientMessageHandler, - bool controlledFlushingClient) - { - serverMessageFormatter = new MessagePackFormatter(); - clientMessageFormatter = new MessagePackFormatter(); - - var options = MessagePackFormatter.DefaultUserDataSerializationOptions - .WithResolver(CompositeResolver.Create( - new IMessagePackFormatter[] { new UnserializableTypeFormatter(), new TypeThrowsWhenDeserializedFormatter(), new CustomExtensionFormatter() }, - new IFormatterResolver[] { StandardResolverAllowPrivate.Instance })); - ((MessagePackFormatter)serverMessageFormatter).SetMessagePackSerializerOptions(options); - ((MessagePackFormatter)clientMessageFormatter).SetMessagePackSerializerOptions(options); - - serverMessageHandler = new LengthHeaderMessageHandler(serverStream, serverStream, serverMessageFormatter); - clientMessageHandler = controlledFlushingClient - ? new DelayedFlushingHandler(clientStream, clientMessageFormatter) - : new LengthHeaderMessageHandler(clientStream, clientStream, clientMessageFormatter); - } - protected override object[] CreateFormatterIntrinsicParamsObject(string arg) => []; [MessagePackObject] [Union(0, typeof(UnionDerivedClass))] - public abstract class UnionBaseClass + [GenerateShape] + [DerivedTypeShape(typeof(UnionDerivedClass))] + public abstract partial class UnionBaseClass { } + [GenerateShape] [MessagePackObject] - public class UnionDerivedClass : UnionBaseClass + public partial class UnionDerivedClass : UnionBaseClass { } - internal class CustomExtensionType + [GenerateShape] + internal partial class CustomExtensionType { } - private class CustomExtensionFormatter : IMessagePackFormatter - { - public CustomExtensionType? Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) - { - if (reader.TryReadNil()) - { - return null; - } - - if (reader.ReadExtensionFormat() is { Header: { TypeCode: 1, Length: 0 } }) - { - return new(); - } - - throw new Exception("Unexpected extension header."); - } - - public void Serialize(ref MessagePackWriter writer, CustomExtensionType? value, MessagePackSerializerOptions options) - { - if (value is null) - { - writer.WriteNil(); - } - else - { - writer.WriteExtensionFormat(new ExtensionResult(1, default(Memory))); - } - } - } - - private class UnserializableTypeFormatter : IMessagePackFormatter - { - public CustomSerializedType Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) - { - return new CustomSerializedType { Value = reader.ReadString() }; - } - - public void Serialize(ref MessagePackWriter writer, CustomSerializedType value, MessagePackSerializerOptions options) - { - writer.Write(value?.Value); - } - } - - private class TypeThrowsWhenDeserializedFormatter : IMessagePackFormatter - { - public TypeThrowsWhenDeserialized Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) - { - throw CreateExceptionToBeThrownByDeserializer(); - } - - public void Serialize(ref MessagePackWriter writer, TypeThrowsWhenDeserialized value, MessagePackSerializerOptions options) - { - writer.WriteArrayHeader(0); - } - } - - private class MessagePackServer : IMessagePackServer + internal class MessagePackServer : IMessagePackServer { internal UnionBaseClass? ReceivedValue { get; private set; } @@ -516,7 +432,7 @@ public async IAsyncEnumerable GetAsyncEnumerableOfUnionType([Enu public Task IsExtensionArgNonNull(CustomExtensionType extensionValue) => Task.FromResult(extensionValue is not null); } - private class DelayedFlushingHandler : LengthHeaderMessageHandler, IControlledFlushHandler + protected class DelayedFlushingHandler : LengthHeaderMessageHandler, IControlledFlushHandler { public DelayedFlushingHandler(Stream stream, IJsonRpcMessageFormatter formatter) : base(stream, stream, formatter) diff --git a/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs new file mode 100644 index 00000000..1acb9cf0 --- /dev/null +++ b/test/StreamJsonRpc.Tests/JsonRpcNerdbankMessagePackLengthTests.cs @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; +using PolyType; + +public partial class JsonRpcNerdbankMessagePackLengthTests(ITestOutputHelper logger) : JsonRpcMessagePackLengthTests(logger) +{ + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + protected override void InitializeFormattersAndHandlers( + Stream serverStream, + Stream clientStream, + out IJsonRpcMessageFormatter serverMessageFormatter, + out IJsonRpcMessageFormatter clientMessageFormatter, + out IJsonRpcMessageHandler serverMessageHandler, + out IJsonRpcMessageHandler clientMessageHandler, + bool controlledFlushingClient) + { + NerdbankMessagePackFormatter serverFormatter = new() + { + TypeShapeProvider = Witness.ShapeProvider, + UserDataSerializer = (NerdbankMessagePackFormatter.DefaultSerializer with + { + Converters = + [ + ..NerdbankMessagePackFormatter.DefaultSerializer.Converters, + new CustomExtensionConverter(), + new UnserializableTypeConverter(), + new TypeThrowsWhenDeserializedConverter(), + ], + }).WithGuidConverter(), + }; + + NerdbankMessagePackFormatter clientFormatter = new() + { + TypeShapeProvider = Witness.ShapeProvider, + UserDataSerializer = serverFormatter.UserDataSerializer, + }; + + serverMessageFormatter = serverFormatter; + clientMessageFormatter = clientFormatter; + + serverMessageHandler = new LengthHeaderMessageHandler(serverStream, serverStream, serverMessageFormatter); + clientMessageHandler = controlledFlushingClient + ? new DelayedFlushingHandler(clientStream, clientMessageFormatter) + : new LengthHeaderMessageHandler(clientStream, clientStream, clientMessageFormatter); + } + + protected override object[] CreateFormatterIntrinsicParamsObject(string arg) => []; + + internal class CustomExtensionConverter : MessagePackConverter + { + public override CustomExtensionType? Read(ref MessagePackReader reader, SerializationContext context) + { + if (reader.TryReadNil()) + { + return null; + } + + if (reader.ReadExtensionHeader() is { TypeCode: 1, Length: 0 }) + { + return new(); + } + + throw new Exception("Unexpected extension header."); + } + + public override void Write(ref MessagePackWriter writer, in CustomExtensionType? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + } + else + { + writer.Write(new Extension(1, default(Memory))); + } + } + } + + private class UnserializableTypeConverter : MessagePackConverter + { + public override CustomSerializedType Read(ref MessagePackReader reader, SerializationContext context) + { + return new CustomSerializedType { Value = reader.ReadString() }; + } + + public override void Write(ref MessagePackWriter writer, in CustomSerializedType? value, SerializationContext context) + { + writer.Write(value?.Value); + } + } + + private class TypeThrowsWhenDeserializedConverter : MessagePackConverter + { + public override TypeThrowsWhenDeserialized Read(ref MessagePackReader reader, SerializationContext context) + { + throw CreateExceptionToBeThrownByDeserializer(); + } + + public override void Write(ref MessagePackWriter writer, in TypeThrowsWhenDeserialized? value, SerializationContext context) + { + writer.WriteArrayHeader(0); + } + } + + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor>] + [GenerateShapeFor>] + [GenerateShapeFor>] + [GenerateShapeFor>] + [GenerateShapeFor>] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + private partial class Witness; +} diff --git a/test/StreamJsonRpc.Tests/JsonRpcTests.cs b/test/StreamJsonRpc.Tests/JsonRpcTests.cs index 97513375..9e6f9319 100644 --- a/test/StreamJsonRpc.Tests/JsonRpcTests.cs +++ b/test/StreamJsonRpc.Tests/JsonRpcTests.cs @@ -9,11 +9,11 @@ using System.Text; using Microsoft.VisualStudio.Threading; using Nerdbank.Streams; -using Newtonsoft.Json.Linq; +using PolyType; using JsonNET = Newtonsoft.Json; using STJ = System.Text.Json.Serialization; -public abstract class JsonRpcTests : TestBase +public abstract partial class JsonRpcTests : TestBase { #pragma warning disable SA1310 // Field names should not contain underscore protected const int COR_E_UNAUTHORIZEDACCESS = unchecked((int)0x80070005); @@ -109,21 +109,23 @@ public async Task AddLocalRpcTarget_OfT_InterfaceOnly() this.serverRpc = new JsonRpc(this.serverMessageHandler); this.serverRpc.AddLocalRpcTarget(this.server, null); - this.serverRpc.StartListening(); - this.clientRpc = new JsonRpc(this.clientMessageHandler); + + this.AddTracing(); + + this.serverRpc.StartListening(); this.clientRpc.StartListening(); // Verify that members on the interface and base interfaces are callable. - await this.clientRpc.InvokeAsync("AnotherName", new object[] { "my -name" }); - await this.clientRpc.InvokeAsync(nameof(IServerDerived.MethodOnDerived)); + await this.clientRpc.InvokeAsync("AnotherName", new object[] { "my -name" }).WithCancellation(this.TimeoutToken); + await this.clientRpc.InvokeAsync(nameof(IServerDerived.MethodOnDerived)).WithCancellation(this.TimeoutToken); // Verify that explicitly interface implementations of members on the interface are callable. - Assert.Equal(3, await this.clientRpc.InvokeAsync(nameof(IServer.Add_ExplicitInterfaceImplementation), 1, 2)); + Assert.Equal(3, await this.clientRpc.InvokeAsync(nameof(IServer.Add_ExplicitInterfaceImplementation), 1, 2).WithCancellation(this.TimeoutToken)); // Verify that members NOT on the interface are not callable, whether public or internal. - await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync(nameof(Server.AsyncMethod), new object[] { "my-name" })); - await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync(nameof(Server.InternalMethod))); + await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync(nameof(Server.AsyncMethod), new object[] { "my-name" })).WithCancellation(this.TimeoutToken); + await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync(nameof(Server.InternalMethod))).WithCancellation(this.TimeoutToken); } [Fact] @@ -439,6 +441,7 @@ public async Task CanCallAsyncMethodThatThrowsExceptionWithoutDeserializingConst this.serverRpc.AllowModificationWhileListening = true; this.clientRpc.ExceptionStrategy = exceptionStrategy; this.serverRpc.ExceptionStrategy = exceptionStrategy; + this.clientRpc.AddLoadableType(typeof(ExceptionMissingDeserializingConstructor)); RemoteInvocationException exception = await Assert.ThrowsAnyAsync(() => this.clientRpc.InvokeAsync(nameof(Server.AsyncMethodThatThrowsAnExceptionWithoutDeserializingConstructor))); var errorData = Assert.IsType(exception.DeserializedErrorData); @@ -478,6 +481,7 @@ public async Task ThrowCustomExceptionThatImplementsISerializableProperly() this.serverRpc.AllowModificationWhileListening = true; this.clientRpc.ExceptionStrategy = ExceptionProcessing.ISerializable; this.serverRpc.ExceptionStrategy = ExceptionProcessing.ISerializable; + this.clientRpc.AddLoadableType(typeof(PrivateSerializableException)); RemoteInvocationException exception = await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync(nameof(Server.ThrowPrivateSerializableException))); Assert.IsType(exception.InnerException); @@ -497,7 +501,7 @@ public async Task CanCallOverloadedMethod() public async Task ThrowsIfCannotFindMethod() { await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync("missingMethod", 50)); - await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync(nameof(Server.OverloadedMethod), new { X = 100 })); + await Assert.ThrowsAsync(() => this.clientRpc.InvokeAsync(nameof(Server.OverloadedMethod), new XAndYProperties { x = 100 })); } [Fact] @@ -829,9 +833,9 @@ public async Task NullableParameters() [Fact] public async Task NullableReturnType() { - int? result = await this.clientRpc.InvokeAsync(nameof(Server.MethodReturnsNullableInt), 0); + int? result = await this.clientRpc.InvokeAsync(nameof(Server.MethodReturnsNullableInt), 0).WithCancellation(this.TimeoutToken); Assert.Null(result); - result = await this.clientRpc.InvokeAsync(nameof(Server.MethodReturnsNullableInt), 5); + result = await this.clientRpc.InvokeAsync(nameof(Server.MethodReturnsNullableInt), 5).WithCancellation(this.TimeoutToken); Assert.Equal(5, result); } @@ -1384,7 +1388,7 @@ public async Task ProgressParameterHasStableCompletionRelativeToRpcTask() received = n; }); - Task result = this.clientRpc.InvokeAsync(nameof(Server.MethodWithProgressParameter), [progress]); + Task result = this.clientRpc.InvokeWithCancellationAsync(nameof(Server.MethodWithProgressParameter), [progress], [typeof(IProgress)], this.TimeoutToken); await Assert.ThrowsAsync(() => result.WithCancellation(ExpectedTimeoutToken)); evt.Set(); await result.WithCancellation(this.TimeoutToken); @@ -1394,7 +1398,7 @@ public async Task ProgressParameterHasStableCompletionRelativeToRpcTask() public async Task ReportProgressWithUnserializableData_LeavesTraceEvidence() { var progress = new Progress(); - await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.MethodWithUnserializableProgressType), new object[] { progress }, cancellationToken: this.TimeoutToken); + await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.MethodWithUnserializableProgressType), new object[] { progress }, [typeof(IProgress)], cancellationToken: this.TimeoutToken); // Verify that the trace explains what went wrong with the original exception message. while (!this.serverTraces.Messages.Any(m => m.Contains("Can't touch this"))) @@ -2255,6 +2259,8 @@ public async Task ExceptionTreeThrownFromServerIsDeserializedAtClient(ExceptionP this.serverRpc.AllowModificationWhileListening = true; this.clientRpc.ExceptionStrategy = exceptionStrategy; this.serverRpc.ExceptionStrategy = exceptionStrategy; + this.clientRpc.AddLoadableType(typeof(FileNotFoundException)); + this.clientRpc.AddLoadableType(typeof(ApplicationException)); var exception = await Assert.ThrowsAnyAsync(() => this.clientRpc.InvokeAsync(nameof(Server.MethodThatThrowsDeeplyNestedExceptions))); @@ -2408,6 +2414,9 @@ public async Task DisposeOnDisconnect_VsThreadingAsyncDisposable(bool throwFromD [Fact] public async Task SerializableExceptions() { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLoadableType(typeof(FileNotFoundException)); + // Create a full exception with inner exceptions. We have to throw so that its stacktrace is initialized. Exception? exceptionToSend; try @@ -2514,6 +2523,8 @@ public async Task ExceptionCanDeserializeExtensibility() { this.serverRpc.AllowModificationWhileListening = true; this.serverRpc.ExceptionOptions = new ExceptionFilter(ExceptionSettings.UntrustedData.RecursionLimit); + this.serverRpc.AddLoadableType(typeof(TaskCanceledException)); + this.serverRpc.AddLoadableType(typeof(OperationCanceledException)); Exception originalException = new TaskCanceledException(); await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.SendException), new[] { originalException }, new[] { typeof(Exception) }, this.TimeoutToken); @@ -2525,6 +2536,9 @@ public async Task ExceptionCanDeserializeExtensibility() [Fact] public async Task ArgumentOutOfRangeException_WithNullArgValue() { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLoadableType(typeof(ArgumentOutOfRangeException)); + Exception? exceptionToSend = new ArgumentOutOfRangeException("t", "msg"); await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.SendException), new[] { exceptionToSend }, new[] { typeof(Exception) }, this.TimeoutToken); @@ -2538,6 +2552,9 @@ public async Task ArgumentOutOfRangeException_WithNullArgValue() [Fact] public async Task ArgumentOutOfRangeException_WithStringArgValue() { + this.serverRpc.AllowModificationWhileListening = true; + this.serverRpc.AddLoadableType(typeof(ArgumentOutOfRangeException)); + Exception? exceptionToSend = new ArgumentOutOfRangeException("t", "argValue", "msg"); await this.clientRpc.InvokeWithCancellationAsync(nameof(Server.SendException), new[] { exceptionToSend }, new[] { typeof(Exception) }, this.TimeoutToken); @@ -2545,7 +2562,7 @@ public async Task ArgumentOutOfRangeException_WithStringArgValue() // Make sure the exception is its own unique (deserialized) instance, but equal by value. Assert.NotSame(this.server.ReceivedException, exceptionToSend); - if (this.clientMessageFormatter is MessagePackFormatter) + if (this.clientMessageFormatter is MessagePackFormatter or NerdbankMessagePackFormatter) { // MessagePack cannot (safely) deserialize a typeless value like ArgumentOutOfRangeException.ActualValue, // So assert that a placeholder was put there instead. @@ -2605,14 +2622,7 @@ public async Task SerializableExceptions_RedirectType() this.serverRpc.AddLocalRpcTarget(this.server); - this.serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose | SourceLevels.ActivityTracing); - this.clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose | SourceLevels.ActivityTracing); - - this.serverRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); - this.clientRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); - - this.serverRpc.TraceSource.Listeners.Add(this.serverTraces = new CollectingTraceListener()); - this.clientRpc.TraceSource.Listeners.Add(this.clientTraces = new CollectingTraceListener()); + this.AddTracing(); this.serverRpc.StartListening(); this.clientRpc.StartListening(); @@ -3172,6 +3182,11 @@ protected void ReinitializeRpcWithoutListening(bool controlledFlushingClient = f this.serverRpc = new JsonRpc(this.serverMessageHandler, this.server); this.clientRpc = new JsonRpc(this.clientMessageHandler); + this.AddTracing(); + } + + protected void AddTracing() + { this.serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose | SourceLevels.ActivityTracing); this.clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose | SourceLevels.ActivityTracing); @@ -3291,8 +3306,8 @@ private async Task NotifyAsyncWithProgressParameter_NoMemoryLeakC WeakReference weakRef = new WeakReference(progress); - var ex = await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.MethodWithProgressParameter), new { p = progress })); - Assert.IsType(ex.InnerException); + var ex = await Assert.ThrowsAnyAsync(() => this.clientRpc.NotifyAsync(nameof(Server.MethodWithProgressParameter), new XAndYPropertiesWithProgress { p = progress })); + Assert.IsType(ex.GetBaseException()); await progress.WaitAsync(); @@ -3316,7 +3331,7 @@ public class BaseClass } #pragma warning disable CA1801 // use all parameters - public class Server : BaseClass, IServerDerived + public partial class Server : BaseClass, IServerDerived { internal const string ExceptionMessage = "some message"; internal const string ThrowAfterCancellationMessage = "Throw after cancellation"; @@ -3815,7 +3830,7 @@ public void MethodWithRefParameter(ref int i) public void ThrowLocalRpcException() { - throw new LocalRpcException { ErrorCode = 2, ErrorData = new { myCustomData = "hi" } }; + throw new LocalRpcException { ErrorCode = 2, ErrorData = new CustomErrorData("hi") }; } public void SendException(Exception? ex) @@ -3861,6 +3876,9 @@ internal void InternalMethodWithAttribute() internal void InternalIgnoredMethod() { } + + [GenerateShape, MessagePack.MessagePackObject(keyAsPropertyName: true)] + internal partial record CustomErrorData(string MyCustomData); } #pragma warning restore CA1801 // use all parameters @@ -3907,10 +3925,12 @@ public ValueTask DisposeAsync() } [DataContract] - public class ParamsObjectWithCustomNames + [GenerateShape] + public partial class ParamsObjectWithCustomNames { [DataMember(Name = "argument")] [STJ.JsonPropertyName("argument")] + [PropertyShape(Name = "argument")] public string? TheArgument { get; set; } } @@ -3967,27 +3987,32 @@ public void IgnoredMethod() } [DataContract] - public class Foo + [GenerateShape] + public partial class Foo { [DataMember(Order = 0, IsRequired = true)] [STJ.JsonRequired, STJ.JsonPropertyOrder(0)] + [PropertyShape(Order = 0, IsRequired = true)] public string? Bar { get; set; } [DataMember(Order = 1)] [STJ.JsonPropertyOrder(1)] + [PropertyShape(Order = 1)] public int Bazz { get; set; } } - public class CustomSerializedType + [GenerateShape] + public partial class CustomSerializedType { // Ignore this so default serializers will drop it, proving that custom serializers were used if the value propagates. [JsonNET.JsonIgnore] [IgnoreDataMember] + [PropertyShape(Ignore = true)] public string? Value { get; set; } } - [Serializable, DataContract] - public class CustomISerializableData : ISerializable + [Serializable, DataContract, GenerateShape] + public partial class CustomISerializableData : ISerializable { [MessagePack.SerializationConstructor] public CustomISerializableData(int major) @@ -4025,7 +4050,8 @@ public class TypeThrowsWhenDeserialized } [DataContract] - public class XAndYProperties + [GenerateShape] + public partial class XAndYProperties { // We disable SA1300 because we must use lowercase members as required to match the parameter names. #pragma warning disable SA1300 // Accessible properties should begin with upper-case letter @@ -4300,6 +4326,16 @@ public JsonRpcThatSubstitutesType(IJsonRpcMessageHandler messageHandler) return base.LoadType(typeFullName, assemblyName); } + + protected override Type? LoadTypeTrimSafe(string typeFullName, string? assemblyName) + { + if (typeFullName == typeof(ArgumentOutOfRangeException).FullName) + { + return typeof(ArgumentException); + } + + return base.LoadTypeTrimSafe(typeFullName, assemblyName); + } } private class ControlledProgress(Action reported) : IProgress diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs new file mode 100644 index 00000000..97e0d02f --- /dev/null +++ b/test/StreamJsonRpc.Tests/MarshalableProxyNerdbankMessagePackTests.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using Nerdbank.MessagePack; +using PolyType; + +public partial class MarshalableProxyNerdbankMessagePackTests : MarshalableProxyTests +{ + public MarshalableProxyNerdbankMessagePackTests(ITestOutputHelper logger) + : base(logger) + { + } + + protected override Type FormatterExceptionType => typeof(MessagePackSerializationException); + + protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter() { TypeShapeProvider = Witness.ShapeProvider }; + + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor>] + [GenerateShapeFor>] + [GenerateShapeFor>>] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor] + [GenerateShapeFor>] + [GenerateShapeFor>] + [GenerateShapeFor>] + private partial class Witness; +} diff --git a/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs b/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs index a69ddef7..dafac241 100644 --- a/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs +++ b/test/StreamJsonRpc.Tests/MarshalableProxyTests.cs @@ -6,8 +6,10 @@ using System.Runtime.Serialization; using MessagePack; using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; using Nerdbank.Streams; using Newtonsoft.Json; +using PolyType; /// /// Tests the proxying of interfaces marked with . @@ -42,8 +44,22 @@ protected MarshalableProxyTests(ITestOutputHelper logger) [RpcMarshalable] [JsonConverter(typeof(MarshalableConverter))] [MessagePackFormatter(typeof(MarshalableFormatter))] + [MessagePackConverter(typeof(MarshalableNerdbankConverter))] public interface IMarshalableAndSerializable : IMarshalable { + internal class MarshalableNerdbankConverter : Nerdbank.MessagePack.MessagePackConverter + { + public override IMarshalableAndSerializable? Read(ref Nerdbank.MessagePack.MessagePackReader reader, Nerdbank.MessagePack.SerializationContext context) + { + throw new NotImplementedException(); + } + + public override void Write(ref Nerdbank.MessagePack.MessagePackWriter writer, in IMarshalableAndSerializable? value, Nerdbank.MessagePack.SerializationContext context) + { + throw new NotImplementedException(); + } + } + private class MarshalableConverter : JsonConverter { public override bool CanConvert(Type objectType) @@ -64,12 +80,12 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer private class MarshalableFormatter : MessagePack.Formatters.IMessagePackFormatter { - public IMarshalableAndSerializable Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) + public IMarshalableAndSerializable Deserialize(ref MessagePack.MessagePackReader reader, MessagePackSerializerOptions options) { throw new NotImplementedException(); } - public void Serialize(ref MessagePackWriter writer, IMarshalableAndSerializable value, MessagePackSerializerOptions options) + public void Serialize(ref MessagePack.MessagePackWriter writer, IMarshalableAndSerializable value, MessagePackSerializerOptions options) { throw new NotImplementedException(); } @@ -962,6 +978,7 @@ public async Task RpcMarshalable_CallScopedLifetime_AsyncEnumerableThrown() this.serverRpc.AllowModificationWhileListening = true; this.clientRpc.ExceptionStrategy = ExceptionProcessing.ISerializable; this.serverRpc.ExceptionStrategy = ExceptionProcessing.ISerializable; + this.clientRpc.AddLoadableType(typeof(ExceptionWithAsyncEnumerable)); MarshalableAndSerializable marshaled = new(); var outerException = await Assert.ThrowsAsync(() => this.client.CallScopedMarshalableThrowsWithAsyncEnumerable(marshaled)); @@ -1243,8 +1260,10 @@ public Data(Action? disposeAction) [DataMember] public int Value { get; set; } + [PropertyShape(Ignore = true)] public bool IsDisposed { get; private set; } + [PropertyShape(Ignore = true)] public bool DoSomethingCalled { get; private set; } public void Dispose() diff --git a/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs new file mode 100644 index 00000000..cfd284c4 --- /dev/null +++ b/test/StreamJsonRpc.Tests/NerdbankMessagePackFormatterTests.cs @@ -0,0 +1,478 @@ +using System.Diagnostics; +using System.Runtime.Serialization; +using Microsoft.VisualStudio.Threading; +using Nerdbank.MessagePack; +using Nerdbank.Streams; +using PolyType; +using PolyType.Abstractions; +using PolyType.ReflectionProvider; + +public partial class NerdbankMessagePackFormatterTests : FormatterTestBase +{ + private static readonly MessagePackSerializer AnonymousTypeSerializer = new() + { + Converters = [new ObjectConverter()], + }; + + public NerdbankMessagePackFormatterTests(ITestOutputHelper logger) + : base(logger) + { + } + + [Fact] + public void JsonRpcRequest_PositionalArgs() + { + var original = new JsonRpcRequest + { + RequestId = new RequestId(5), + Method = "test", + ArgumentsList = new object[] { 5, "hi", new CustomType { Age = 8 } }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Method, actual.Method); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 0, typeof(int), out object? actualArg0)); + Assert.Equal(original.ArgumentsList[0], actualArg0); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 1, typeof(string), out object? actualArg1)); + Assert.Equal(original.ArgumentsList[1], actualArg1); + + Assert.True(actual.TryGetArgumentByNameOrIndex(null, 2, typeof(CustomType), out object? actualArg2)); + Assert.Equal(((CustomType?)original.ArgumentsList[2])!.Age, ((CustomType)actualArg2!).Age); + } + + [Fact] + public void JsonRpcRequest_NamedArgs() + { + var original = new JsonRpcRequest + { + RequestId = new RequestId(5), + Method = "test", + NamedArguments = new Dictionary + { + { "Number", 5 }, + { "Message", "hi" }, + { "Custom", new CustomType { Age = 8 } }, + }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Method, actual.Method); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Number", -1, typeof(int), out object? actualArg0)); + Assert.Equal(original.NamedArguments["Number"], actualArg0); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Message", -1, typeof(string), out object? actualArg1)); + Assert.Equal(original.NamedArguments["Message"], actualArg1); + + Assert.True(actual.TryGetArgumentByNameOrIndex("Custom", -1, typeof(CustomType), out object? actualArg2)); + Assert.Equal(((CustomType?)original.NamedArguments["Custom"])!.Age, ((CustomType)actualArg2!).Age); + } + + [Fact] + public void JsonRpcResult() + { + var original = new JsonRpcResult + { + RequestId = new RequestId(5), + Result = new CustomType { Age = 7 }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(((CustomType?)original.Result)!.Age, actual.GetResult().Age); + } + + [Fact] + public void JsonRpcError() + { + var original = new JsonRpcError + { + RequestId = new RequestId(5), + Error = new JsonRpcError.ErrorDetail + { + Code = JsonRpcErrorCode.InvocationError, + Message = "Oops", + Data = new CustomType { Age = 15 }, + }, + }; + + var actual = this.Roundtrip(original); + Assert.Equal(original.RequestId, actual.RequestId); + Assert.Equal(original.Error.Code, actual.Error!.Code); + Assert.Equal(original.Error.Message, actual.Error.Message); + Assert.Equal(((CustomType)original.Error.Data).Age, actual.Error.GetData().Age); + } + + [Fact] + public async Task BasicJsonRpc() + { + var (clientStream, serverStream) = FullDuplexStream.CreatePair(); + var clientFormatter = new NerdbankMessagePackFormatter { TypeShapeProvider = Witness.ShapeProvider }; + var serverFormatter = new NerdbankMessagePackFormatter { TypeShapeProvider = Witness.ShapeProvider }; + + var clientHandler = new LengthHeaderMessageHandler(clientStream.UsePipe(cancellationToken: TestContext.Current.CancellationToken), clientFormatter); + var serverHandler = new LengthHeaderMessageHandler(serverStream.UsePipe(cancellationToken: TestContext.Current.CancellationToken), serverFormatter); + + var clientRpc = new JsonRpc(clientHandler); + var serverRpc = new JsonRpc(serverHandler, new Server()); + + serverRpc.TraceSource = new TraceSource("Server", SourceLevels.Verbose); + clientRpc.TraceSource = new TraceSource("Client", SourceLevels.Verbose); + + serverRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); + clientRpc.TraceSource.Listeners.Add(new XunitTraceListener(this.Logger)); + + clientRpc.StartListening(); + serverRpc.StartListening(); + + int result = await clientRpc.InvokeAsync(nameof(Server.Add), 3, 5).WithCancellation(this.TimeoutToken); + Assert.Equal(8, result); + } + + [Fact] + public void Resolver_RequestArgInArray() + { + var originalArg = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + ArgumentsList = new object[] { originalArg }, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(null, 0, typeof(TypeRequiringCustomFormatter), out object? roundtripArgObj)); + var roundtripArg = (TypeRequiringCustomFormatter)roundtripArgObj!; + Assert.Equal(originalArg.Prop1, roundtripArg.Prop1); + Assert.Equal(originalArg.Prop2, roundtripArg.Prop2); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_AnonymousType() + { + this.Formatter = new() + { + TypeShapeProvider = Witness.ShapeProvider, + UserDataSerializer = this.Formatter.UserDataSerializer with { Converters = [.. this.Formatter.UserDataSerializer.Converters, new CustomConverter()] }, + }; + + var originalArg = new { Prop1 = 3, Prop2 = 5 }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.Prop1), -1, typeof(int), out object? prop1)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.Prop2), -1, typeof(int), out object? prop2)); + Assert.Equal(originalArg.Prop1, prop1); + Assert.Equal(originalArg.Prop2, prop2); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_DataContractObject() + { + this.Formatter = new() + { + TypeShapeProvider = Witness.ShapeProvider, + UserDataSerializer = this.Formatter.UserDataSerializer with { Converters = [.. this.Formatter.UserDataSerializer.Converters, new CustomConverter()] }, + }; + + var originalArg = new DataContractWithSubsetOfMembersIncluded { ExcludedField = "A", ExcludedProperty = "B", IncludedField = "C", IncludedProperty = "D" }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedProperty), -1, typeof(string), out object? _)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.IncludedField), -1, typeof(string), out object? includedField)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.IncludedProperty), -1, typeof(string), out object? includedProperty)); + Assert.Equal(originalArg.IncludedProperty, includedProperty); + Assert.Equal(originalArg.IncludedField, includedField); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_NonDataContractObject() + { + this.Formatter = new() + { + TypeShapeProvider = Witness.ShapeProvider, + UserDataSerializer = this.Formatter.UserDataSerializer with { Converters = [.. this.Formatter.UserDataSerializer.Converters, new CustomConverter()] }, + }; + + var originalArg = new NonDataContractWithExcludedMembers { ExcludedField = "A", ExcludedProperty = "B", InternalField = "C", InternalProperty = "D", PublicField = "E", PublicProperty = "F" }; + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = originalArg, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.ExcludedProperty), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.InternalField), -1, typeof(string), out object? _)); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.InternalProperty), -1, typeof(string), out object? _)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.PublicField), -1, typeof(string), out object? publicField)); + Assert.True(roundtripRequest.TryGetArgumentByNameOrIndex(nameof(originalArg.PublicProperty), -1, typeof(string), out object? publicProperty)); + Assert.Equal(originalArg.PublicProperty, publicProperty); + Assert.Equal(originalArg.PublicField, publicField); + } + + [Fact] + public void Resolver_RequestArgInNamedArgs_NullObject() + { + var originalRequest = new JsonRpcRequest + { + RequestId = new RequestId(1), + Method = "Eat", + Arguments = null, + }; + var roundtripRequest = this.Roundtrip(originalRequest); + Assert.Null(roundtripRequest.Arguments); + Assert.False(roundtripRequest.TryGetArgumentByNameOrIndex("AnythingReally", -1, typeof(string), out object? _)); + } + + [Fact] + public void Resolver_Result() + { + this.Formatter = new() + { + TypeShapeProvider = Witness.ShapeProvider, + UserDataSerializer = this.Formatter.UserDataSerializer with { Converters = [.. this.Formatter.UserDataSerializer.Converters, new CustomConverter()] }, + }; + + var originalResultValue = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalResult = new JsonRpcResult + { + RequestId = new RequestId(1), + Result = originalResultValue, + }; + var roundtripResult = this.Roundtrip(originalResult); + var roundtripResultValue = roundtripResult.GetResult(); + Assert.Equal(originalResultValue.Prop1, roundtripResultValue.Prop1); + Assert.Equal(originalResultValue.Prop2, roundtripResultValue.Prop2); + } + + [Fact] + public void Resolver_ErrorData() + { + var originalErrorData = new TypeRequiringCustomFormatter { Prop1 = 3, Prop2 = 5 }; + var originalError = new JsonRpcError + { + RequestId = new RequestId(1), + Error = new JsonRpcError.ErrorDetail + { + Data = originalErrorData, + }, + }; + var roundtripError = this.Roundtrip(originalError); + var roundtripErrorData = roundtripError.Error!.GetData(); + Assert.Equal(originalErrorData.Prop1, roundtripErrorData.Prop1); + Assert.Equal(originalErrorData.Prop2, roundtripErrorData.Prop2); + } + + [Fact] + public void CanDeserializeWithExtraProperty_JsonRpcRequest() + { + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = new object[] { "hi" }, + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.method, request.Method); + Assert.Equal(dynamic.@params.Length, request.ArgumentCount); + Assert.True(request.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg)); + Assert.Equal(dynamic.@params[0], arg); + } + + [Fact] + public void CanDeserializeWithExtraProperty_JsonRpcResult() + { + var dynamic = new + { + jsonrpc = "2.0", + id = 2, + extra = (object?)null, + result = "hi", + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.id, request.RequestId.Number); + Assert.Equal(dynamic.result, request.GetResult()); + } + + [Fact] + public void CanDeserializeWithExtraProperty_JsonRpcError() + { + var dynamic = new + { + jsonrpc = "2.0", + id = 2, + extra = (object?)null, + error = new { extra = 2, code = 5 }, + }; + var request = this.Read(dynamic); + Assert.Equal(dynamic.jsonrpc, request.Version); + Assert.Equal(dynamic.id, request.RequestId.Number); + Assert.Equal(dynamic.error.code, (int?)request.Error?.Code); + } + + [Fact] + public void StringsInUserDataAreInterned() + { + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = new object[] { "hi" }, + }; + var request1 = this.Read(dynamic); + var request2 = this.Read(dynamic); + Assert.True(request1.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg1)); + Assert.True(request2.TryGetArgumentByNameOrIndex(null, 0, typeof(string), out object? arg2)); + Assert.Same(arg2, arg1); // reference equality to ensure it was interned. + } + + [Fact] + public void StringValuesOfStandardPropertiesAreInterned() + { + var dynamic = new + { + jsonrpc = "2.0", + method = "something", + extra = (object?)null, + @params = Array.Empty(), + }; + var request1 = this.Read(dynamic); + var request2 = this.Read(dynamic); + Assert.Same(request1.Method, request2.Method); // reference equality to ensure it was interned. + } + + protected override NerdbankMessagePackFormatter CreateFormatter() + { + NerdbankMessagePackFormatter formatter = new() + { + TypeShapeProvider = Witness.ShapeProvider, + }; + + return formatter; + } + + private T Read(object anonymousObject) + where T : JsonRpcMessage + { + var sequence = new Sequence(); + var writer = new MessagePackWriter(sequence); + AnonymousTypeSerializer.SerializeObject(ref writer, anonymousObject, ReflectionTypeShapeProvider.Default.Resolve(anonymousObject.GetType())); + writer.Flush(); + this.Logger.WriteLine(MessagePackSerializer.ConvertToJson(sequence)); + return (T)this.Formatter.Deserialize(sequence); + } + + [DataContract] + [GenerateShape] + public partial class DataContractWithSubsetOfMembersIncluded + { + [PropertyShape(Ignore = true)] + public string? ExcludedField; + + [DataMember] + internal string? IncludedField; + + [PropertyShape(Ignore = true)] + public string? ExcludedProperty { get; set; } + + [DataMember] + internal string? IncludedProperty { get; set; } + } + + [GenerateShape] + public partial class NonDataContractWithExcludedMembers + { + [IgnoreDataMember] + [PropertyShape(Ignore = true)] + public string? ExcludedField; + + public string? PublicField; + + internal string? InternalField; + + [IgnoreDataMember] + [PropertyShape(Ignore = true)] + public string? ExcludedProperty { get; set; } + + public string? PublicProperty { get; set; } + + internal string? InternalProperty { get; set; } + } + + [MessagePackConverter(typeof(CustomConverter))] + [GenerateShape] + public partial class TypeRequiringCustomFormatter + { + internal int Prop1 { get; set; } + + internal int Prop2 { get; set; } + } + + internal class CustomConverter : MessagePackConverter + { + public override TypeRequiringCustomFormatter Read(ref MessagePackReader reader, SerializationContext context) + { + Assert.Equal(2, reader.ReadArrayHeader()); + return new TypeRequiringCustomFormatter + { + Prop1 = reader.ReadInt32(), + Prop2 = reader.ReadInt32(), + }; + } + + public override void Write(ref MessagePackWriter writer, in TypeRequiringCustomFormatter? value, SerializationContext context) + { + Requires.NotNull(value!, nameof(value)); + + writer.WriteArrayHeader(2); + writer.Write(value.Prop1); + writer.Write(value.Prop2); + } + } + + private class Server + { + public int Add(int a, int b) => a + b; + } + + private class ObjectConverter : MessagePackConverter + { + public override object? Read(ref MessagePackReader reader, SerializationContext context) + { + throw new NotSupportedException(); + } + + public override void Write(ref MessagePackWriter writer, in object? value, SerializationContext context) + { + if (value is null) + { + writer.WriteNil(); + return; + } + + context.GetConverter(value.GetType(), null).WriteObject(ref writer, value, context); + } + } + + [GenerateShapeFor] + private partial class Witness; +} diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs new file mode 100644 index 00000000..493c6c57 --- /dev/null +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingNerdbankMessagePackTests.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using PolyType; + +public partial class ObserverMarshalingNerdbankMessagePackTests(ITestOutputHelper logger) : ObserverMarshalingTests(logger) +{ + protected override IJsonRpcMessageFormatter CreateFormatter() => new NerdbankMessagePackFormatter { TypeShapeProvider = Witness.ShapeProvider }; + + [GenerateShapeFor] + [GenerateShapeFor>] + private partial class Witness; +} diff --git a/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs b/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs index a6e88900..f82e569f 100644 --- a/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs +++ b/test/StreamJsonRpc.Tests/ObserverMarshalingTests.cs @@ -83,6 +83,9 @@ public async Task ReturnThenPushSequence() [Fact] public async Task FaultImmediately() { + this.clientRpc.AllowModificationWhileListening = true; + this.clientRpc.AddLoadableType(typeof(ApplicationException)); + var observer = new MockObserver(); await Task.Run(() => this.client.FaultImmediately(observer), TestContext.Current.CancellationToken).WithCancellation(this.TimeoutToken); var ex = await Assert.ThrowsAnyAsync(() => observer.Completion); diff --git a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj index e637384f..7479a266 100644 --- a/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj +++ b/test/StreamJsonRpc.Tests/StreamJsonRpc.Tests.csproj @@ -9,6 +9,8 @@ to avoid proliferation of #if sections in our test code. So suppress it unless we're targeting the oldest framework among our targets. --> $(NoWarn);xUnit1051 + + $(NoWarn);VSTHRD012 @@ -16,30 +18,40 @@ + + + + + + + + + + diff --git a/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs new file mode 100644 index 00000000..59b800be --- /dev/null +++ b/test/StreamJsonRpc.Tests/TargetObjectEventsNerdbankMessagePackTests.cs @@ -0,0 +1,17 @@ +using PolyType; + +public partial class TargetObjectEventsNerdbankMessagePackTests(ITestOutputHelper logger) : TargetObjectEventsTests(logger) +{ + protected override void InitializeFormattersAndHandlers() + { + NerdbankMessagePackFormatter serverMessageFormatter = new() { TypeShapeProvider = Witness.ShapeProvider }; + NerdbankMessagePackFormatter clientMessageFormatter = new() { TypeShapeProvider = Witness.ShapeProvider }; + + this.serverMessageHandler = new LengthHeaderMessageHandler(this.serverStream, this.serverStream, serverMessageFormatter); + this.clientMessageHandler = new LengthHeaderMessageHandler(this.clientStream, this.clientStream, clientMessageFormatter); + } + + [GenerateShapeFor] + [GenerateShapeFor>] + private partial class Witness; +} diff --git a/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs b/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs index cd98caf4..8303c396 100644 --- a/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs +++ b/test/StreamJsonRpc.Tests/TargetObjectEventsTests.cs @@ -2,8 +2,9 @@ using System.Runtime.Serialization; using Microsoft.VisualStudio.Threading; using Nerdbank; +using PolyType; -public abstract class TargetObjectEventsTests : TestBase +public abstract partial class TargetObjectEventsTests : TestBase { protected IJsonRpcMessageHandler serverMessageHandler = null!; protected IJsonRpcMessageHandler clientMessageHandler = null!; @@ -31,7 +32,9 @@ public TargetObjectEventsTests(ITestOutputHelper logger) } [MessagePack.Union(key: 0, typeof(Fruit))] - public interface IFruit + [GenerateShape] + [DerivedTypeShape(typeof(Fruit), Tag = 0)] + public partial interface IFruit { string Name { get; } } @@ -356,8 +359,10 @@ private void ReinitializeRpcWithoutListening() } [DataContract] - public class Fruit : IFruit + [GenerateShape] + public partial class Fruit : IFruit { + [ConstructorShape] internal Fruit(string name) { this.Name = name; @@ -367,6 +372,22 @@ internal Fruit(string name) public string Name { get; } } + [DataContract] + [GenerateShape] + protected internal partial class CustomEventArgs : EventArgs + { + [DataMember] + public int Seeds { get; set; } + } + + [DataContract] + protected internal class MessageEventArgs : EventArgs + where T : class + { + [DataMember] + public T? Message { get; set; } + } + protected class Client { internal Action? ServerEventRaised { get; set; } @@ -485,19 +506,4 @@ protected class ServerWithIncompatibleEvents public event MyDelegate? MyEvent; #pragma warning restore CS0067 } - - [DataContract] - protected class CustomEventArgs : EventArgs - { - [DataMember] - public int Seeds { get; set; } - } - - [DataContract] - protected class MessageEventArgs : EventArgs - where T : class - { - [DataMember] - public T? Message { get; set; } - } } diff --git a/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs b/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs new file mode 100644 index 00000000..acd5b8be --- /dev/null +++ b/test/StreamJsonRpc.Tests/WebSocketMessageHandlerNerdbankMessagePackTests.cs @@ -0,0 +1,12 @@ +using PolyType; + +public partial class WebSocketMessageHandlerNerdbankMessagePackTests : WebSocketMessageHandlerTests +{ + public WebSocketMessageHandlerNerdbankMessagePackTests(ITestOutputHelper logger) + : base(new NerdbankMessagePackFormatter { TypeShapeProvider = Witness.ShapeProvider }, logger) + { + } + + [GenerateShapeFor] + private partial class Witness; +}