From 7446c2704eef3061913dc21e685e06dcf89e0a7e Mon Sep 17 00:00:00 2001 From: Jasper Park Date: Mon, 27 Jan 2025 09:05:04 +0100 Subject: [PATCH 1/6] feat: add FunicularSwitch.Generator code for generic base type --- Source/DocSamples/ReadmeSamples.cs | 6 +- .../RoslynExtensions.cs | 32 ++++ .../EnumType/Generator.cs | 5 +- .../Generation/Indent.cs | 15 +- .../UnionType/Generator.cs | 140 ++++++++++++++---- .../UnionType/Parser.cs | 50 ++++--- .../UnionType/UnionTypeSchema.cs | 6 +- .../ExampleResult.cs | 16 ++ .../UnionTypeMethods.cs | 14 ++ .../UnionTypeGeneratorSpecs.cs | 20 +++ 10 files changed, 241 insertions(+), 63 deletions(-) diff --git a/Source/DocSamples/ReadmeSamples.cs b/Source/DocSamples/ReadmeSamples.cs index 693bab7..dd92234 100644 --- a/Source/DocSamples/ReadmeSamples.cs +++ b/Source/DocSamples/ReadmeSamples.cs @@ -110,12 +110,12 @@ static IEnumerable CheckFruit(Fruit fruit) } var salad = - await ingredients - .Select(ingredient => + await ingredients.Select(ingredient => stock .Where(fruit => fruit.Name == ingredient) .FirstOk(CheckFruit, onEmpty: () => $"No {ingredient} in stock") - ) + ) + .Aggregate() .Bind(fruits => CutIntoPieces(fruits, cookSkillLevel)) .Map(Serve); diff --git a/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs b/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs index de5216a..052b7b6 100644 --- a/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs +++ b/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs @@ -99,6 +99,38 @@ public static QualifiedTypeName QualifiedName(this BaseTypeDeclarationSyntax dec return new(dec.Name(), typeNames); } + public static QualifiedTypeName QualifiedNameWithGenerics(this BaseTypeDeclarationSyntax dec) + { + var current = dec.Parent as BaseTypeDeclarationSyntax; + var typeNames = new Stack(); + while (current != null) + { + typeNames.Push(current.Name() + FormatTypeParameters(current.GetTypeParameterList())); + current = current.Parent as BaseTypeDeclarationSyntax; + } + + return new(dec.Name() + FormatTypeParameters(dec.GetTypeParameterList()), typeNames); + } + + public static EquatableArray GetTypeParameterList(this BaseTypeDeclarationSyntax dec) + { + if (dec is not TypeDeclarationSyntax tds) + { + return []; + } + + return tds.TypeParameterList?.Parameters.Select(tps => tps.Identifier.Text).ToImmutableArray() ?? ImmutableArray.Empty; + } + + public static string FormatTypeParameters(EquatableArray typeParameters) + { + if (typeParameters.Length == 0) + { + return string.Empty; + } + + return "<" + string.Join(", ", typeParameters) + ">"; + } public static string Name(this BaseTypeDeclarationSyntax declaration) => declaration.Identifier.ToString(); diff --git a/Source/FunicularSwitch.Generators/EnumType/Generator.cs b/Source/FunicularSwitch.Generators/EnumType/Generator.cs index 6d632bf..2377d10 100644 --- a/Source/FunicularSwitch.Generators/EnumType/Generator.cs +++ b/Source/FunicularSwitch.Generators/EnumType/Generator.cs @@ -1,4 +1,5 @@ -using FunicularSwitch.Generators.Common; +using System.Collections.Immutable; +using FunicularSwitch.Generators.Common; using FunicularSwitch.Generators.Generation; using Microsoft.CodeAnalysis; @@ -55,7 +56,7 @@ void BlankLine() } builder.WriteLine("#pragma warning restore 1591"); - return (enumTypeSchema.FullTypeName.ToMatchExtensionFilename(), builder.ToString()); + return (enumTypeSchema.FullTypeName.ToMatchExtensionFilename(ImmutableArray.Empty), builder.ToString()); } static void GenerateMatchMethod(CSharpBuilder builder, EnumTypeSchema enumTypeSchema, string t) diff --git a/Source/FunicularSwitch.Generators/Generation/Indent.cs b/Source/FunicularSwitch.Generators/Generation/Indent.cs index 75c68fe..031b40f 100644 --- a/Source/FunicularSwitch.Generators/Generation/Indent.cs +++ b/Source/FunicularSwitch.Generators/Generation/Indent.cs @@ -1,4 +1,5 @@ -using FunicularSwitch.Generators.Common; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; +using FunicularSwitch.Generators.Common; namespace FunicularSwitch.Generators.Generation; @@ -242,5 +243,15 @@ public static string TrimBaseTypeName(this string value, string baseTypeName) return value; } - public static string ToMatchExtensionFilename(this string fullTypeName) => $"{fullTypeName.Replace(".", "")}MatchExtension.g.cs"; + public static string ToMatchExtensionFilename(this string fullTypeName, EquatableArray typeParameters) => $"{fullTypeName.Replace(".", "")}{FormatTypeParameterForFileName(typeParameters)}MatchExtension.g.cs"; + + public static string FormatTypeParameterForFileName(EquatableArray typeParameters) + { + if (typeParameters.Length == 0) + { + return string.Empty; + } + + return "Of" + string.Join("_", typeParameters); + } } \ No newline at end of file diff --git a/Source/FunicularSwitch.Generators/UnionType/Generator.cs b/Source/FunicularSwitch.Generators/UnionType/Generator.cs index 57a1021..9121b2f 100644 --- a/Source/FunicularSwitch.Generators/UnionType/Generator.cs +++ b/Source/FunicularSwitch.Generators/UnionType/Generator.cs @@ -1,4 +1,5 @@ using System.Collections.Immutable; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; using FunicularSwitch.Generators.Common; using FunicularSwitch.Generators.Generation; using Microsoft.CodeAnalysis; @@ -21,19 +22,30 @@ public static (string filename, string source) Emit(UnionTypeSchema unionTypeSch using (unionTypeSchema.Namespace != null ? builder.Namespace(unionTypeSchema.Namespace) : null) { - WriteMatchExtension(unionTypeSchema, builder); - if (unionTypeSchema is { IsPartial: true, StaticFactoryInfo: not null }) - { - builder.WriteLine(""); - WritePartialWithStaticFactories(unionTypeSchema, builder); - } - } + if (unionTypeSchema.TypeParameters.Length > 0) + { + if (unionTypeSchema.IsPartial) + { + WritePartialWithMatchMethods(unionTypeSchema, builder); + } + } + else + { + WriteMatchExtension(unionTypeSchema, builder); + } + + if (unionTypeSchema is { IsPartial: true, StaticFactoryInfo: not null }) + { + builder.WriteLine(""); + WritePartialWithStaticFactories(unionTypeSchema, builder); + } + } builder.WriteLine("#pragma warning restore 1591"); - return (unionTypeSchema.FullTypeName.ToMatchExtensionFilename(), builder.ToString()); + return (unionTypeSchema.FullTypeName.ToMatchExtensionFilename(unionTypeSchema.TypeParameters), builder.ToString()); } - static void WriteMatchExtension(UnionTypeSchema unionTypeSchema, CSharpBuilder builder) + private static void WriteMatchExtension(UnionTypeSchema unionTypeSchema, CSharpBuilder builder) { using (builder.StaticPartialClass($"{unionTypeSchema.TypeName.Replace(".", "_")}MatchExtension", unionTypeSchema.IsInternal ? "internal" : "public")) @@ -73,18 +85,45 @@ void WriteBodyForAsyncTaskExtension(string matchMethodName) => builder.WriteLine WriteBodyForAsyncTaskExtension(VoidMatchMethodName); } + return; + void BlankLine() { builder.WriteLine(""); } } - static void WritePartialWithStaticFactories(UnionTypeSchema unionTypeSchema, CSharpBuilder builder) + private static void WritePartialWithMatchMethods(UnionTypeSchema unionTypeSchema, CSharpBuilder builder) + { + var typeParameters = RoslynExtensions.FormatTypeParameters(unionTypeSchema.TypeParameters); + var unusedTypeParameter = GetUnusedTypeParameter(unionTypeSchema.TypeParameters); + var typeKind = GetTypeKind(unionTypeSchema); + builder.WriteLine($"{(unionTypeSchema.Modifiers.ToSeparatedString(" "))} {typeKind} {unionTypeSchema.TypeName}{typeParameters}"); + using (builder.Scope()) + { + GenerateMatchMethod(builder, unionTypeSchema, returnType: unusedTypeParameter, t: unusedTypeParameter, asExtension: false); + BlankLine(); + + GenerateSwitchMethod(builder, unionTypeSchema, isAsync: false, asExtension: false); + BlankLine(); + GenerateSwitchMethod(builder, unionTypeSchema, isAsync: true, asExtension: false); + } + + return; + + void BlankLine() + { + builder.WriteLine(""); + } + } + + private static void WritePartialWithStaticFactories(UnionTypeSchema unionTypeSchema, CSharpBuilder builder) { var info = unionTypeSchema.StaticFactoryInfo!; + var typeParameters = RoslynExtensions.FormatTypeParameters(unionTypeSchema.TypeParameters); - var typeKind = unionTypeSchema.TypeKind switch { UnionTypeTypeKind.Class => "class", UnionTypeTypeKind.Interface => "interface", UnionTypeTypeKind.Record => "record", _ => throw new ArgumentException($"Unknown type kind: {unionTypeSchema.TypeKind}") }; - builder.WriteLine($"{(info.Modifiers.ToSeparatedString(" "))} {typeKind} {unionTypeSchema.TypeName}"); + var typeKind = GetTypeKind(unionTypeSchema); + builder.WriteLine($"{(unionTypeSchema.Modifiers.ToSeparatedString(" "))} {typeKind} {unionTypeSchema.TypeName}{typeParameters}"); using (builder.Scope()) { foreach (var derivedType in unionTypeSchema.Cases) @@ -93,7 +132,7 @@ static void WritePartialWithStaticFactories(UnionTypeSchema unionTypeSchema, CSh var derivedTypeName = nameParts[nameParts.Length - 1]; var methodName = derivedType.StaticFactoryMethodName; - if ($"{unionTypeSchema.FullTypeName}.{methodName}" == derivedType.FullTypeName) //union case is nested type without underscores, so factory method name would conflict with type name + if ($"{unionTypeSchema.FullTypeNameWithTypeParameters}.{methodName}" == derivedType.FullTypeName) //union case is nested type without underscores, so factory method name would conflict with type name continue; var constructors = derivedType.Constructors; @@ -124,19 +163,32 @@ static void WritePartialWithStaticFactories(UnionTypeSchema unionTypeSchema, CSh var arguments = constructor.Parameters.ToSeparatedString(); var constructorInvocation = $"new {derivedType.FullTypeName}({(constructor.Parameters.Select(p => p.Name).ToSeparatedString())})"; - builder.WriteLine($"{(isInternal ? "internal" : "public")} static {unionTypeSchema.FullTypeName} {methodName}({arguments}) => {constructorInvocation};"); + builder.WriteLine($"{(isInternal ? "internal" : "public")} static {unionTypeSchema.FullTypeName}{typeParameters} {methodName}({arguments}) => {constructorInvocation};"); } } } } - static void GenerateMatchMethod(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, string t) + private static string GetTypeKind(UnionTypeSchema unionTypeSchema) + { + var typeKind = unionTypeSchema.TypeKind switch { UnionTypeTypeKind.Class => "class", UnionTypeTypeKind.Interface => "interface", UnionTypeTypeKind.Record => "record", _ => throw new ArgumentException($"Unknown type kind: {unionTypeSchema.TypeKind}") }; + return typeKind; + } + + static void GenerateMatchMethod(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, string returnType, string t = "T", bool asExtension = true) { var thisParameterType = unionTypeSchema.FullTypeName; var thisParameter = ThisParameter(unionTypeSchema, thisParameterType); var thisParameterName = thisParameter.Name; - WriteMatchSignature(builder, unionTypeSchema, thisParameter, t); - builder.WriteLine($"{thisParameterName} switch"); + var thisStatement = asExtension ? thisParameterName : "this"; + WriteMatchSignature( + builder: builder, + unionTypeSchema: unionTypeSchema, + thisParameter: asExtension ? thisParameter : null, + returnType: returnType, + t: t, + modifiers: asExtension ? "public static" : "public"); + builder.WriteLine($"{thisStatement} switch"); using (builder.ScopeWithSemicolon()) { var caseIndex = 0; @@ -148,19 +200,20 @@ static void GenerateMatchMethod(CSharpBuilder builder, UnionTypeSchema unionType } builder.WriteLine( - $"_ => throw new global::System.ArgumentException($\"Unknown type derived from {unionTypeSchema.FullTypeName}: {{{thisParameterName}.GetType().Name}}\")"); + $"_ => throw new global::System.ArgumentException($\"Unknown type derived from {unionTypeSchema.FullTypeName}: {{{thisStatement}.GetType().Name}}\")"); } } - static void GenerateSwitchMethod(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, bool isAsync) + static void GenerateSwitchMethod(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, bool isAsync, bool asExtension = true) { var thisParameterType = unionTypeSchema.FullTypeName; var thisParameter = ThisParameter(unionTypeSchema, thisParameterType); var thisParameterName = thisParameter.Name; - WriteSwitchSignature(builder, unionTypeSchema, thisParameter, isAsync); + var thisStatement = asExtension ? thisParameterName : "this"; + WriteSwitchSignature(builder, unionTypeSchema, asExtension ? thisParameter : null, isAsync, modifiers: asExtension ? "public static" : "public"); using (builder.Scope()) { - builder.WriteLine($"switch ({thisParameterName})"); + builder.WriteLine($"switch ({thisStatement})"); using (builder.Scope()) { var caseIndex = 0; @@ -182,30 +235,51 @@ static void GenerateSwitchMethod(CSharpBuilder builder, UnionTypeSchema unionTyp builder.WriteLine("default:"); using (builder.Indent()) { - builder.WriteLine($"throw new global::System.ArgumentException($\"Unknown type derived from {unionTypeSchema.FullTypeName}: {{{thisParameterName}.GetType().Name}}\");"); + builder.WriteLine($"throw new global::System.ArgumentException($\"Unknown type derived from {unionTypeSchema.FullTypeName}: {{{thisStatement}.GetType().Name}}\");"); } } } - } + } + + private static string GetUnusedTypeParameter(EquatableArray typeParameters) + { + return Enumerable.Range(0, 20) + .Select(i => Check(new string('_', i) + "TMatchResult")) + .FirstOrDefault(s => s is not null) ?? "T" + Guid.NewGuid().ToString("N"); - static Parameter ThisParameter(UnionTypeSchema unionTypeSchema, string thisParameterType) => new($"this {thisParameterType}", unionTypeSchema.TypeName.ToParameterName()); + string? Check(string typeName) + { + if (!typeParameters.Contains(typeName)) + { + return typeName; + } + + return null; + } + } + + static Parameter ThisParameter(UnionTypeSchema unionTypeSchema, string thisParameterType) => new($"this {thisParameterType}", unionTypeSchema.TypeName.ToParameterName()); static void WriteMatchSignature(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, - Parameter thisParameter, string returnType, string? handlerReturnType = null, string modifiers = "public static") + Parameter? thisParameter, string returnType, string? handlerReturnType = null, string modifiers = "public static", string t = "T") { handlerReturnType ??= returnType; var handlerParameters = unionTypeSchema.Cases .Select(c => new Parameter($"global::System.Func<{c.FullTypeName}, {handlerReturnType}>", c.ParameterName)); - builder.WriteMethodSignature( + if (thisParameter is not null) + { + handlerParameters = handlerParameters.Prepend(thisParameter); + } + builder.WriteMethodSignature( modifiers: modifiers, returnType: returnType, - methodName: "Match", parameters: new[] { thisParameter }.Concat(handlerParameters), + methodName: "Match<" + t + ">", parameters: handlerParameters, lambda: true); } static void WriteSwitchSignature(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, - Parameter thisParameter, bool isAsync, bool? asyncReturn = null, bool lambda = false) + Parameter? thisParameter, bool isAsync, string modifiers = "public static", bool? asyncReturn = null, bool lambda = false) { var returnType = asyncReturn ?? isAsync ? "async global::System.Threading.Tasks.Task" : "void"; var handlerParameters = unionTypeSchema.Cases @@ -218,13 +292,15 @@ static void WriteSwitchSignature(CSharpBuilder builder, UnionTypeSchema unionTyp parameterType, c.ParameterName); }); + if (thisParameter is not null) + { + handlerParameters = handlerParameters.Prepend(thisParameter); + } - string modifiers = "public static"; - - builder.WriteMethodSignature( + builder.WriteMethodSignature( modifiers: modifiers, returnType: returnType, - methodName: VoidMatchMethodName, parameters: new[] { thisParameter }.Concat(handlerParameters), + methodName: VoidMatchMethodName, parameters: handlerParameters, lambda: lambda); } } \ No newline at end of file diff --git a/Source/FunicularSwitch.Generators/UnionType/Parser.cs b/Source/FunicularSwitch.Generators/UnionType/Parser.cs index 44a67df..8905cd2 100644 --- a/Source/FunicularSwitch.Generators/UnionType/Parser.cs +++ b/Source/FunicularSwitch.Generators/UnionType/Parser.cs @@ -1,4 +1,5 @@ using System.Collections.Immutable; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; using FunicularSwitch.Generators.Common; using FunicularSwitch.Generators.Generation; using Microsoft.CodeAnalysis; @@ -17,7 +18,10 @@ public static GenerationResult GetUnionTypeSchema(Compilation c { var semanticModel = compilation.GetSemanticModel(unionTypeClass.SyntaxTree); + var typeParameters = unionTypeClass.GetTypeParameterList(); + var fullTypeName = unionTypeSymbol.FullTypeNameWithNamespace(); + var fullTypeNameWithTypeParameters = fullTypeName + RoslynExtensions.FormatTypeParameters(typeParameters); var acc = unionTypeSymbol.DeclaredAccessibility; if (acc is Accessibility.Private or Accessibility.Protected) { @@ -45,30 +49,32 @@ public static GenerationResult GetUnionTypeSchema(Compilation c return ToOrderedCases(caseOrder, derivedTypes, compilation, generateFactoryMethods, unionTypeSymbol.Name) - .Map(cases => - new UnionTypeSchema( - Namespace: fullNamespace, - TypeName: unionTypeSymbol.Name, - FullTypeName: fullTypeName, - Cases: cases, - IsInternal: acc is Accessibility.NotApplicable or Accessibility.Internal, - IsPartial: isPartial, - TypeKind: unionTypeClass switch - { - RecordDeclarationSyntax => UnionTypeTypeKind.Record, - InterfaceDeclarationSyntax => UnionTypeTypeKind.Interface, - _ => UnionTypeTypeKind.Class - }, - StaticFactoryInfo: generateFactoryMethods - ? BuildFactoryInfo(unionTypeClass, compilation) - : null - )); + .Map(cases => new UnionTypeSchema( + Namespace: fullNamespace, + TypeName: unionTypeSymbol.Name, + FullTypeName: fullTypeName, + FullTypeNameWithTypeParameters: fullTypeNameWithTypeParameters, + Cases: cases, + TypeParameters: typeParameters, + IsInternal: acc is Accessibility.NotApplicable or Accessibility.Internal, + IsPartial: isPartial, + TypeKind: unionTypeClass switch + { + RecordDeclarationSyntax => UnionTypeTypeKind.Record, + InterfaceDeclarationSyntax => UnionTypeTypeKind.Interface, + _ => UnionTypeTypeKind.Class + }, + Modifiers: unionTypeClass.Modifiers.ToEquatableModifiers(), + StaticFactoryInfo: generateFactoryMethods + ? BuildFactoryInfo(unionTypeClass, compilation) + : null + )); static GenerationResult Error(Diagnostic diagnostic) => GenerationResult.Empty.AddDiagnostics(diagnostic); } - static (string parameterName, string methodName) DeriveParameterAndStaticMethodName(string typeName, + private static (string parameterName, string methodName) DeriveParameterAndStaticMethodName(string typeName, string baseTypeName) { var candidates = ImmutableList.Empty; @@ -113,7 +119,7 @@ PropertyDeclarationSyntax p when p.Modifiers.HasModifier(SyntaxKind.StaticKeywor }) .ToImmutableArray(); - return new(staticMethods, staticFields, unionTypeClass.Modifiers.ToEquatableModifiers()); + return new(staticMethods, staticFields); } static GenerationResult> ToOrderedCases(CaseOrder caseOrder, @@ -123,7 +129,7 @@ static GenerationResult> ToOrderedCases(CaseOrder ca var ordered = derivedTypes.OrderByDescending(d => d.numberOfConctreteBaseTypes); ordered = caseOrder switch { - CaseOrder.Alphabetic => ordered.ThenBy(d => d.node.QualifiedName().Name), + CaseOrder.Alphabetic => ordered.ThenBy(d => d.node.QualifiedNameWithGenerics().Name), CaseOrder.AsDeclared => ordered.ThenBy(d => d.node.SyntaxTree.FilePath) .ThenBy(d => d.node.Span.Start), CaseOrder.Explicit => ordered.ThenBy(d => d.caseIndex), @@ -166,7 +172,7 @@ static GenerationResult> ToOrderedCases(CaseOrder ca var derived = result.Select(d => { - var qualifiedTypeName = d.node.QualifiedName(); + var qualifiedTypeName = d.node.QualifiedNameWithGenerics(); var fullNamespace = d.symbol.GetFullNamespace(); IEnumerable? constructors = null; if (getConstructors) diff --git a/Source/FunicularSwitch.Generators/UnionType/UnionTypeSchema.cs b/Source/FunicularSwitch.Generators/UnionType/UnionTypeSchema.cs index c5deda9..aa302fb 100644 --- a/Source/FunicularSwitch.Generators/UnionType/UnionTypeSchema.cs +++ b/Source/FunicularSwitch.Generators/UnionType/UnionTypeSchema.cs @@ -6,10 +6,13 @@ namespace FunicularSwitch.Generators.UnionType; public sealed record UnionTypeSchema(string? Namespace, string TypeName, string FullTypeName, + string FullTypeNameWithTypeParameters, EquatableArray Cases, + EquatableArray TypeParameters, bool IsInternal, bool IsPartial, UnionTypeTypeKind TypeKind, + EquatableArray Modifiers, StaticFactoryMethodsInfo? StaticFactoryInfo); public enum UnionTypeTypeKind @@ -21,8 +24,7 @@ public enum UnionTypeTypeKind public record StaticFactoryMethodsInfo( EquatableArray ExistingStaticMethods, - EquatableArray ExistingStaticFields, - EquatableArray Modifiers + EquatableArray ExistingStaticFields ); public sealed record DerivedType diff --git a/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer.Dependency/ExampleResult.cs b/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer.Dependency/ExampleResult.cs index c9a1cfc..939a41c 100644 --- a/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer.Dependency/ExampleResult.cs +++ b/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer.Dependency/ExampleResult.cs @@ -61,4 +61,20 @@ internal abstract record InternalUnionType { public sealed record First(string Text) : InternalUnionType; public sealed record Second(string Text) : InternalUnionType; +} + +[UnionType(CaseOrder = CaseOrder.AsDeclared)] +public abstract partial record GenericUnionType +{ + public sealed record First_(T Value) : GenericUnionType; + + public sealed record Second_ : GenericUnionType; +} + +[UnionType(CaseOrder = CaseOrder.AsDeclared)] +public abstract partial record MultiGenericUnionType +{ + public sealed record One_(TFirst First, TSecond Second) : MultiGenericUnionType; + + public sealed record Two_(TThird Third) : MultiGenericUnionType; } \ No newline at end of file diff --git a/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer/UnionTypeMethods.cs b/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer/UnionTypeMethods.cs index 4298437..b8320dc 100644 --- a/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer/UnionTypeMethods.cs +++ b/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer/UnionTypeMethods.cs @@ -94,4 +94,18 @@ public void NestedUnionType_OtherCase_IsNotFirstCase() Action(() => result.Should().BeDerivedNestedUnionType()) .Should().Throw(); } + + [Fact] + public void GenericUnionType_FirstCase_IsCase() + { + // ARRANGE + var union = GenericUnionType.First(5); + + // ASSERT + union.Match( + first => first.Value, + second => -1) + .Should() + .Be(5); + } } \ No newline at end of file diff --git a/Source/Tests/FunicularSwitch.Generators.Test/UnionTypeGeneratorSpecs.cs b/Source/Tests/FunicularSwitch.Generators.Test/UnionTypeGeneratorSpecs.cs index 38f43d6..6197a7e 100644 --- a/Source/Tests/FunicularSwitch.Generators.Test/UnionTypeGeneratorSpecs.cs +++ b/Source/Tests/FunicularSwitch.Generators.Test/UnionTypeGeneratorSpecs.cs @@ -422,4 +422,24 @@ public abstract partial record NodeMessage(string NodeInstanceId) return Verify(code); } + + [TestMethod] + public Task For_union_type_with_generic_base_class() + { + var code = """ + using FunicularSwitch.Generators; + + namespace FunicularSwitch.Test; + + [UnionType(CaseOrder = CaseOrder.AsDeclared)] + public abstract partial record BaseType(string Value) + { + public sealed record Deriving(string Value, T Other) : BaseType(Value); + + public sealed record Deriving2(string Value) : BaseType(Value); + } + """; + + return Verify(code); + } } \ No newline at end of file From e5c5771cf326111c5927d50c0df8f8deeaa3b469 Mon Sep 17 00:00:00 2001 From: Jasper Park Date: Mon, 27 Jan 2025 13:28:09 +0100 Subject: [PATCH 2/6] feat: adapt FluentAssertionsGenerator to work with generic union types --- Source/DocSamples/DocSamples.csproj | 2 +- .../RoslynExtensions.cs | 12 +++ .../SymbolWrapper.cs | 1 + .../MyAssertionExtensions_UnionType.cs | 2 +- .../MyUnionTypeAssertions_DerivedUnionType.cs | 2 +- .../FluentAssertionMethods/Generator.cs | 19 +++-- .../ResultTypeSchema.cs | 13 +-- .../FluentAssertionMethods/UnionTypeSchema.cs | 13 ++- .../Generation/Indent.cs | 12 +-- .../UnionTypeMethods.cs | 85 +++++++++++++++++-- ...ric_base_class#Attributes.g.00.verified.cs | 40 +++++++++ ...ric_base_class#Attributes.g.01.verified.cs | 28 ++++++ ...ric_base_class#Attributes.g.02.verified.cs | 61 +++++++++++++ ...estBaseTypeOfTMatchExtension.g.verified.cs | 50 +++++++++++ 14 files changed, 299 insertions(+), 41 deletions(-) create mode 100644 Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.00.verified.cs create mode 100644 Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.01.verified.cs create mode 100644 Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.02.verified.cs create mode 100644 Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#FunicularSwitchTestBaseTypeOfTMatchExtension.g.verified.cs diff --git a/Source/DocSamples/DocSamples.csproj b/Source/DocSamples/DocSamples.csproj index 20a22cb..5c831d4 100644 --- a/Source/DocSamples/DocSamples.csproj +++ b/Source/DocSamples/DocSamples.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp3.0 + net8.0 diff --git a/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs b/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs index 052b7b6..2ac2829 100644 --- a/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs +++ b/Source/FunicularSwitch.Generators.Common/RoslynExtensions.cs @@ -77,10 +77,12 @@ public static bool Implements(this INamedTypeSymbol symbol, ITypeSymbol interfac } static readonly SymbolDisplayFormat FullTypeWithNamespaceDisplayFormat = SymbolWrapper.FullTypeWithNamespaceDisplayFormat; + static readonly SymbolDisplayFormat FullTypeWithNamespaceAndGenericsDisplayFormat = SymbolWrapper.FullTypeWithNamespaceAndGenericsDisplayFormat; static readonly SymbolDisplayFormat FullTypeDisplayFormat = new(typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypes); public static string FullTypeNameWithNamespace(this INamedTypeSymbol namedTypeSymbol) => namedTypeSymbol.ToDisplayString(FullTypeWithNamespaceDisplayFormat); + public static string FullTypeNameWithNamespaceAndGenerics(this INamedTypeSymbol namedTypeSymbol) => namedTypeSymbol.ToDisplayString(FullTypeWithNamespaceAndGenericsDisplayFormat); public static string FullTypeName(this INamedTypeSymbol namedTypeSymbol) => namedTypeSymbol.ToDisplayString(FullTypeDisplayFormat); public static string FullNamespace(this INamespaceSymbol namespaceSymbol) => @@ -132,6 +134,16 @@ public static string FormatTypeParameters(EquatableArray typeParameters) return "<" + string.Join(", ", typeParameters) + ">"; } + public static string FormatTypeParameterForFileName(EquatableArray typeParameters) + { + if (typeParameters.Length == 0) + { + return string.Empty; + } + + return "Of" + string.Join("_", typeParameters); + } + public static string Name(this BaseTypeDeclarationSyntax declaration) => declaration.Identifier.ToString(); public static string Name(this MethodDeclarationSyntax declaration) => declaration.Identifier.ToString(); diff --git a/Source/FunicularSwitch.Generators.Common/SymbolWrapper.cs b/Source/FunicularSwitch.Generators.Common/SymbolWrapper.cs index a3bbb69..fb80e06 100644 --- a/Source/FunicularSwitch.Generators.Common/SymbolWrapper.cs +++ b/Source/FunicularSwitch.Generators.Common/SymbolWrapper.cs @@ -5,6 +5,7 @@ namespace FunicularSwitch.Generators.Common; public static class SymbolWrapper { internal static readonly SymbolDisplayFormat FullTypeWithNamespaceDisplayFormat = new(typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces); + internal static readonly SymbolDisplayFormat FullTypeWithNamespaceAndGenericsDisplayFormat = new(typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces, genericsOptions: SymbolDisplayGenericsOptions.IncludeTypeParameters); public static SymbolWrapper Create(T symbol) where T : ISymbol => new(symbol); } diff --git a/Source/FunicularSwitch.Generators.FluentAssertions.Templates/MyAssertionExtensions_UnionType.cs b/Source/FunicularSwitch.Generators.FluentAssertions.Templates/MyAssertionExtensions_UnionType.cs index fe98349..236ebe4 100644 --- a/Source/FunicularSwitch.Generators.FluentAssertions.Templates/MyAssertionExtensions_UnionType.cs +++ b/Source/FunicularSwitch.Generators.FluentAssertions.Templates/MyAssertionExtensions_UnionType.cs @@ -4,6 +4,6 @@ namespace FunicularSwitch.Generators.FluentAssertions.Templates { internal static class MyAssertionExtensions_UnionType { - public static MyAssertions_UnionType Should(this MyUnionType unionType) => new(unionType); + public static MyAssertions_UnionType Should(this MyUnionType unionType) => new(unionType); } } \ No newline at end of file diff --git a/Source/FunicularSwitch.Generators.FluentAssertions.Templates/MyUnionTypeAssertions_DerivedUnionType.cs b/Source/FunicularSwitch.Generators.FluentAssertions.Templates/MyUnionTypeAssertions_DerivedUnionType.cs index fca2189..030d9c6 100644 --- a/Source/FunicularSwitch.Generators.FluentAssertions.Templates/MyUnionTypeAssertions_DerivedUnionType.cs +++ b/Source/FunicularSwitch.Generators.FluentAssertions.Templates/MyUnionTypeAssertions_DerivedUnionType.cs @@ -13,7 +13,7 @@ public AndWhichConstraint BeFriendly Execute.Assertion .ForCondition(this.Subject is MyDerivedUnionType) .BecauseOf(because, becauseArgs) - .FailWith("Expected {context} to be Error with MyDerivedErrorType MyErrorType{reason}, but found {0}", + .FailWith("Expected {context} to be MyDerivedUnionType{reason}, but found {0}", this.Subject); return new(this, (this.Subject as MyDerivedUnionType)!); diff --git a/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/Generator.cs b/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/Generator.cs index 5ae659c..3f6d626 100644 --- a/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/Generator.cs +++ b/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/Generator.cs @@ -1,4 +1,6 @@ -using FunicularSwitch.Generators.Common; +using System.Collections.Immutable; +using CommunityToolkit.Mvvm.SourceGenerators.Helpers; +using FunicularSwitch.Generators.Common; using Microsoft.CodeAnalysis; namespace FunicularSwitch.Generators.FluentAssertions.FluentAssertionMethods; @@ -14,6 +16,7 @@ internal static class Generator private const string TemplateResultAssertionExtensions = "MyAssertionExtensions_Result"; private const string TemplateUnionTypeAssertionsTypeName = "MyAssertions_UnionType"; private const string TemplateUnionTypeAssertionExtensions = "MyAssertionExtensions_UnionType"; + private const string TemplateUnionTypeTypeParameterList = ""; private const string TemplateFriendlyDerivedUnionTypeName = "FriendlyDerivedUnionTypeName"; private const string TemplateAdditionalUsingDirectives = "//additional using directives"; @@ -64,9 +67,13 @@ string Replace(string code, params string[] additionalNamespaces) { var unionTypeFullName = unionTypeSchema.UnionTypeBaseType.FullTypeName().Replace('.', '_'); var unionTypeFullNameWithNamespace = unionTypeSchema.UnionTypeBaseType.FullTypeNameWithNamespace(); + var unionTypeFullNameWithNamespaceAndGenerics = unionTypeSchema.UnionTypeBaseType.FullTypeNameWithNamespaceAndGenerics(); + EquatableArray typeParameters = unionTypeSchema.UnionTypeBaseType.TypeParameters + .Select(t => t.Name).ToImmutableArray(); var unionTypeNamespace = unionTypeSchema.UnionTypeBaseType.GetFullNamespace(); + var typeParametersText = RoslynExtensions.FormatTypeParameters(typeParameters); - var generateFileHint = $"{unionTypeFullNameWithNamespace}"; + var generateFileHint = $"{unionTypeFullNameWithNamespace}{RoslynExtensions.FormatTypeParameterForFileName(typeParameters)}"; //var generatorRuns = RunCount.Increase(unionTypeSchema.UnionTypeBaseType.FullTypeNameWithNamespace()); @@ -74,9 +81,11 @@ string Replace(string code, params string[] additionalNamespaces) { code = code .Replace($"namespace {TemplateNamespace}", $"namespace {unionTypeNamespace}") - .Replace(TemplateUnionTypeName, unionTypeFullNameWithNamespace) - .Replace(TemplateUnionTypeAssertionsTypeName, $"{unionTypeFullName}Assertions") + .Replace(TemplateUnionTypeName, unionTypeFullNameWithNamespaceAndGenerics) + .Replace($"public {TemplateUnionTypeAssertionsTypeName}(", $"public {unionTypeFullName}Assertions(") + .Replace(TemplateUnionTypeAssertionsTypeName, $"{unionTypeFullName}Assertions{typeParametersText}") .Replace(TemplateUnionTypeAssertionExtensions, $"{unionTypeFullName}FluentAssertionExtensions") + .Replace(TemplateUnionTypeTypeParameterList, typeParametersText) .Replace( TemplateAdditionalUsingDirectives, additionalNamespaces @@ -99,7 +108,7 @@ string Replace(string code, params string[] additionalNamespaces) foreach (var derivedType in unionTypeSchema.DerivedTypes) { - var derivedTypeFullNameWithNamespace = derivedType.FullTypeNameWithNamespace(); + var derivedTypeFullNameWithNamespace = derivedType.FullTypeNameWithNamespaceAndGenerics(); yield return ( $"{generateFileHint}_Derived_{derivedType.Name}Assertions.g.cs", Replace(Templates.GenerateFluentAssertionsForTemplates.MyDerivedUnionTypeAssertions) diff --git a/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/ResultTypeSchema.cs b/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/ResultTypeSchema.cs index eef1cd4..74ced03 100644 --- a/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/ResultTypeSchema.cs +++ b/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/ResultTypeSchema.cs @@ -2,16 +2,9 @@ namespace FunicularSwitch.Generators.FluentAssertions.FluentAssertionMethods; -public class ResultTypeSchema +public record ResultTypeSchema( + INamedTypeSymbol ResultType, + INamedTypeSymbol? ErrorType) { - public INamedTypeSymbol ResultType { get; } - public INamedTypeSymbol? ErrorType { get; } - - public ResultTypeSchema(INamedTypeSymbol resultType, INamedTypeSymbol? errorType) - { - ResultType = resultType; - ErrorType = errorType; - } - public override string ToString() => $"{nameof(ResultType)}: {ResultType}, {nameof(ErrorType)}: {ErrorType}"; } \ No newline at end of file diff --git a/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/UnionTypeSchema.cs b/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/UnionTypeSchema.cs index 8b8a410..0e44044 100644 --- a/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/UnionTypeSchema.cs +++ b/Source/FunicularSwitch.Generators.FluentAssertions/FluentAssertionMethods/UnionTypeSchema.cs @@ -2,14 +2,13 @@ namespace FunicularSwitch.Generators.FluentAssertions.FluentAssertionMethods; -public class UnionTypeSchema +public record UnionTypeSchema( + INamedTypeSymbol UnionTypeBaseType, + IEnumerable DerivedTypes) { - public UnionTypeSchema(INamedTypeSymbol unionTypeBaseType, IEnumerable derivedTypes) + public override string ToString() { - UnionTypeBaseType = unionTypeBaseType; - DerivedTypes = derivedTypes.ToList(); + var derivedTypes = string.Join(", ", DerivedTypes.Select(d => d.ToString())); + return $"{nameof(UnionTypeBaseType)}: {UnionTypeBaseType}, {nameof(DerivedTypes)}, {derivedTypes}"; } - - public INamedTypeSymbol UnionTypeBaseType { get; } - public IReadOnlyList DerivedTypes { get; } } \ No newline at end of file diff --git a/Source/FunicularSwitch.Generators/Generation/Indent.cs b/Source/FunicularSwitch.Generators/Generation/Indent.cs index 031b40f..1d53c36 100644 --- a/Source/FunicularSwitch.Generators/Generation/Indent.cs +++ b/Source/FunicularSwitch.Generators/Generation/Indent.cs @@ -243,15 +243,7 @@ public static string TrimBaseTypeName(this string value, string baseTypeName) return value; } - public static string ToMatchExtensionFilename(this string fullTypeName, EquatableArray typeParameters) => $"{fullTypeName.Replace(".", "")}{FormatTypeParameterForFileName(typeParameters)}MatchExtension.g.cs"; + public static string ToMatchExtensionFilename(this string fullTypeName, EquatableArray typeParameters) => $"{fullTypeName.Replace(".", "")}{RoslynExtensions.FormatTypeParameterForFileName(typeParameters)}MatchExtension.g.cs"; - public static string FormatTypeParameterForFileName(EquatableArray typeParameters) - { - if (typeParameters.Length == 0) - { - return string.Empty; - } - - return "Of" + string.Join("_", typeParameters); - } + } \ No newline at end of file diff --git a/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer/UnionTypeMethods.cs b/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer/UnionTypeMethods.cs index b8320dc..172b286 100644 --- a/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer/UnionTypeMethods.cs +++ b/Source/Tests/FunicularSwitch.Generators.FluentAssertions.Consumer/UnionTypeMethods.cs @@ -1,4 +1,5 @@ -using FluentAssertions; +using System.Globalization; +using FluentAssertions; using FunicularSwitch; using FunicularSwitch.Generators.FluentAssertions.Consumer.Dependency; using Xunit.Sdk; @@ -96,16 +97,88 @@ public void NestedUnionType_OtherCase_IsNotFirstCase() } [Fact] - public void GenericUnionType_FirstCase_IsCase() + public void GenericUnionType_Match_WorksCorrectly() { // ARRANGE - var union = GenericUnionType.First(5); + var first = GenericUnionType.First(5); // ASSERT - union.Match( - first => first.Value, - second => -1) + first.Match( + first: f => f.Value, + second: _ => -1) .Should() .Be(5); + + // ARRANGE + var second = GenericUnionType.Second(); + + // ASSERT + second.Match( + first: f => -1, + second: _ => 42) + .Should() + .Be(42); + } + + [Fact] + public void GenericUnionType_Should_Be_X() + { + // ARRANGE + var first = GenericUnionType.First(5); + + // ASSERT + first.Should().BeFirst().Which.Value.Should().Be(5); + var firstShouldBeSecond = () => first.Should().BeSecond(); + firstShouldBeSecond.Should().Throw(); + + // ARRANGE + var second = GenericUnionType.Second(); + second.Should().BeSecond().Which.Should().BeSameAs(second); + var secondShouldBeFirst = () => second.Should().BeFirst(); + secondShouldBeFirst.Should().Throw(); + } + + [Fact] + public void MultiGenericUnionType_Match_WorksCorrectly() + { + // ARRANGE + var one = MultiGenericUnionType.One(5, "Test"); + + // ASSERT + one.Match( + one: o => o.First + o.Second, + two: t => "Two") + .Should() + .Be("5Test"); + + // ARRANGE + var two = MultiGenericUnionType.Two(3.14f); + + // ASSERT + two.Match( + one: o => "One", + two: t => t.Third.ToString("0.00", CultureInfo.InvariantCulture)) + .Should() + .Be("3.14"); + } + + [Fact] + public void MultiGenericUnionType_Should_Be_X() + { + // ARRANGE + var one = MultiGenericUnionType.One(5, "Test"); + + // ASSERT + one.Should().BeOne().Which.First.Should().Be(5); + var oneShouldBeTwo = () => one.Should().BeTwo(); + oneShouldBeTwo.Should().Throw(); + + // ARRANGE + var two = MultiGenericUnionType.Two(3.14f); + + // ASSERT + two.Should().BeTwo().Which.Third.Should().BeApproximately(3.14f, float.Epsilon); + var twoShouldBeOne = () => two.Should().BeOne(); + twoShouldBeOne.Should().Throw(); } } \ No newline at end of file diff --git a/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.00.verified.cs b/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.00.verified.cs new file mode 100644 index 0000000..78b467a --- /dev/null +++ b/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.00.verified.cs @@ -0,0 +1,40 @@ +//HintName: Attributes.g.cs +using System; + +// ReSharper disable once CheckNamespace +namespace FunicularSwitch.Generators +{ + /// + /// Mark an abstract partial type with a single generic argument with the ResultType attribute. + /// This type from now on has Ok | Error semantics with map and bind operations. + /// + [AttributeUsage(AttributeTargets.Class, Inherited = false)] + sealed class ResultTypeAttribute : Attribute + { + public ResultTypeAttribute() => ErrorType = typeof(string); + public ResultTypeAttribute(Type errorType) => ErrorType = errorType; + + public Type ErrorType { get; set; } + } + + /// + /// Mark a static method or a member method or you error type with the MergeErrorAttribute attribute. + /// Static signature: TError -> TError -> TError. Member signature: TError -> TError + /// We are now able to collect errors and methods like Validate, Aggregate, FirstOk that are useful to combine results are generated. + /// + [AttributeUsage(AttributeTargets.Method, Inherited = false)] + sealed class MergeErrorAttribute : Attribute + { + } + + /// + /// Mark a static method with the ExceptionToError attribute. + /// Signature: Exception -> TError + /// This method is always called, when an exception happens in a bind operation. + /// So a call like result.Map(i => i/0) will return an Error produced by the factory method instead of throwing the DivisionByZero exception. + /// + [AttributeUsage(AttributeTargets.Method, Inherited = false)] + sealed class ExceptionToError : Attribute + { + } +} \ No newline at end of file diff --git a/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.01.verified.cs b/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.01.verified.cs new file mode 100644 index 0000000..0541283 --- /dev/null +++ b/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.01.verified.cs @@ -0,0 +1,28 @@ +//HintName: Attributes.g.cs +using System; + +// ReSharper disable once CheckNamespace +namespace FunicularSwitch.Generators +{ + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Interface, Inherited = false)] + sealed class UnionTypeAttribute : Attribute + { + public CaseOrder CaseOrder { get; set; } = CaseOrder.Alphabetic; + public bool StaticFactoryMethods { get; set; } = true; + } + + enum CaseOrder + { + Alphabetic, + AsDeclared, + Explicit + } + + [AttributeUsage(AttributeTargets.Class, Inherited = false)] + sealed class UnionCaseAttribute : Attribute + { + public UnionCaseAttribute(int index) => Index = index; + + public int Index { get; } + } +} \ No newline at end of file diff --git a/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.02.verified.cs b/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.02.verified.cs new file mode 100644 index 0000000..d876a33 --- /dev/null +++ b/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#Attributes.g.02.verified.cs @@ -0,0 +1,61 @@ +//HintName: Attributes.g.cs +using System; + +// ReSharper disable once CheckNamespace +namespace FunicularSwitch.Generators +{ + [AttributeUsage(AttributeTargets.Enum)] + sealed class ExtendedEnumAttribute : Attribute + { + public EnumCaseOrder CaseOrder { get; set; } = EnumCaseOrder.AsDeclared; + public ExtensionAccessibility Accessibility { get; set; } = ExtensionAccessibility.Public; + } + + enum EnumCaseOrder + { + Alphabetic, + AsDeclared + } + + /// + /// Generate match methods for all enums defined in assembly that contains AssemblySpecifier. + /// + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + class ExtendEnumsAttribute : Attribute + { + public Type AssemblySpecifier { get; } + public EnumCaseOrder CaseOrder { get; set; } = EnumCaseOrder.AsDeclared; + public ExtensionAccessibility Accessibility { get; set; } = ExtensionAccessibility.Public; + + public ExtendEnumsAttribute() => AssemblySpecifier = typeof(ExtendEnumsAttribute); + + public ExtendEnumsAttribute(Type assemblySpecifier) + { + AssemblySpecifier = assemblySpecifier; + } + } + + /// + /// Generate match methods for Type. Must be enum. + /// + [AttributeUsage(AttributeTargets.Assembly, AllowMultiple = true)] + class ExtendEnumAttribute : Attribute + { + public Type Type { get; } + + public EnumCaseOrder CaseOrder { get; set; } = EnumCaseOrder.AsDeclared; + + public ExtensionAccessibility Accessibility { get; set; } = ExtensionAccessibility.Public; + + public ExtendEnumAttribute(Type type) + { + Type = type; + } + } + + enum ExtensionAccessibility + { + Internal, + Public + } +} \ No newline at end of file diff --git a/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#FunicularSwitchTestBaseTypeOfTMatchExtension.g.verified.cs b/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#FunicularSwitchTestBaseTypeOfTMatchExtension.g.verified.cs new file mode 100644 index 0000000..221c58b --- /dev/null +++ b/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#FunicularSwitchTestBaseTypeOfTMatchExtension.g.verified.cs @@ -0,0 +1,50 @@ +//HintName: FunicularSwitchTestBaseTypeOfTMatchExtension.g.cs +#pragma warning disable 1591 +namespace FunicularSwitch.Test +{ + public abstract partial record BaseType + { + public TMatchResult Match(global::System.Func.Deriving, TMatchResult> deriving, global::System.Func.Deriving2, TMatchResult> deriving2) => + this switch + { + FunicularSwitch.Test.BaseType.Deriving deriving1 => deriving(deriving1), + FunicularSwitch.Test.BaseType.Deriving2 deriving22 => deriving2(deriving22), + _ => throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {this.GetType().Name}") + }; + + public void Switch(global::System.Action.Deriving> deriving, global::System.Action.Deriving2> deriving2) + { + switch (this) + { + case FunicularSwitch.Test.BaseType.Deriving deriving1: + deriving(deriving1); + break; + case FunicularSwitch.Test.BaseType.Deriving2 deriving22: + deriving2(deriving22); + break; + default: + throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {this.GetType().Name}"); + } + } + + public async global::System.Threading.Tasks.Task Switch(global::System.Func.Deriving, global::System.Threading.Tasks.Task> deriving, global::System.Func.Deriving2, global::System.Threading.Tasks.Task> deriving2) + { + switch (this) + { + case FunicularSwitch.Test.BaseType.Deriving deriving1: + await deriving(deriving1).ConfigureAwait(false); + break; + case FunicularSwitch.Test.BaseType.Deriving2 deriving22: + await deriving2(deriving22).ConfigureAwait(false); + break; + default: + throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {this.GetType().Name}"); + } + } + } + + public abstract partial record BaseType + { + } +} +#pragma warning restore 1591 From 493e08ab32aa7ad93f90e84c1cb2eaafbfb657a5 Mon Sep 17 00:00:00 2001 From: Jasper Park Date: Mon, 27 Jan 2025 13:37:57 +0100 Subject: [PATCH 3/6] build: use net8 in build --- .github/workflows/dotnet-ci.yml | 2 +- Source/FunicularSwitch.sln | 5 +++++ .../FunicularSwitch.Generators.Consumer.Nuget.csproj | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/dotnet-ci.yml b/.github/workflows/dotnet-ci.yml index 70620da..92559bb 100644 --- a/.github/workflows/dotnet-ci.yml +++ b/.github/workflows/dotnet-ci.yml @@ -16,7 +16,7 @@ jobs: - name: Setup .NET uses: actions/setup-dotnet@v3 with: - dotnet-version: 6.0.x + dotnet-version: 8.0.x - name: Restore dependencies run: dotnet restore ./Source - name: Build diff --git a/Source/FunicularSwitch.sln b/Source/FunicularSwitch.sln index 2dbcfad..ee8b987 100644 --- a/Source/FunicularSwitch.sln +++ b/Source/FunicularSwitch.sln @@ -44,6 +44,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "FunicularSwitch.Generators. EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "FunicularSwitch.Generators.Consumer.StandardMinLangVersion", "Tests\FunicularSwitch.Generators.Consumer.StandardMinLangVersion\FunicularSwitch.Generators.Consumer.StandardMinLangVersion.csproj", "{18D4F137-98AF-47CF-88FE-313C4BEA4215}" EndProject +Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Build", "Build", "{1EA7B0C8-17DB-4ED9-A7B8-90CACA398BC3}" + ProjectSection(SolutionItems) = preProject + ..\.github\workflows\dotnet-ci.yml = ..\.github\workflows\dotnet-ci.yml + EndProjectSection +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU diff --git a/Source/Tests/FunicularSwitch.Generators.Consumer.Nuget/FunicularSwitch.Generators.Consumer.Nuget.csproj b/Source/Tests/FunicularSwitch.Generators.Consumer.Nuget/FunicularSwitch.Generators.Consumer.Nuget.csproj index 2eb6894..680f1b5 100644 --- a/Source/Tests/FunicularSwitch.Generators.Consumer.Nuget/FunicularSwitch.Generators.Consumer.Nuget.csproj +++ b/Source/Tests/FunicularSwitch.Generators.Consumer.Nuget/FunicularSwitch.Generators.Consumer.Nuget.csproj @@ -1,7 +1,7 @@  - net6.0 + net8.0 enable false From 249abc204148cabbd934dfbf0b8fb8933bf89289 Mon Sep 17 00:00:00 2001 From: Alexander Wiedemann Date: Tue, 28 Jan 2025 10:22:20 +0100 Subject: [PATCH 4/6] - match methods for generic union types as extensions - DocSamples back to netcoreapp3.0 (trydotnet is sill not working on .netxx) --- Source/DocSamples/DocSamples.csproj | 2 +- .../UnionType/Generator.cs | 425 +++++++-------- ...nericResultOfT_TFailureMatchExtension.g.cs | 71 +++ .../GeneratorSpecs.cs | 515 ++++++++++-------- ...estBaseTypeOfTMatchExtension.g.verified.cs | 54 +- .../UnionTypeGeneratorSpecs.cs | 4 +- 6 files changed, 591 insertions(+), 480 deletions(-) create mode 100644 Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.UnionTypeGenerator/FunicularSwitchGeneratorsConsumerGenericResultOfT_TFailureMatchExtension.g.cs diff --git a/Source/DocSamples/DocSamples.csproj b/Source/DocSamples/DocSamples.csproj index 5c831d4..20a22cb 100644 --- a/Source/DocSamples/DocSamples.csproj +++ b/Source/DocSamples/DocSamples.csproj @@ -2,7 +2,7 @@ Exe - net8.0 + netcoreapp3.0 diff --git a/Source/FunicularSwitch.Generators/UnionType/Generator.cs b/Source/FunicularSwitch.Generators/UnionType/Generator.cs index 9121b2f..c328f9c 100644 --- a/Source/FunicularSwitch.Generators/UnionType/Generator.cs +++ b/Source/FunicularSwitch.Generators/UnionType/Generator.cs @@ -9,30 +9,20 @@ namespace FunicularSwitch.Generators.UnionType; public static class Generator { - const string VoidMatchMethodName = "Switch"; - const string MatchMethodName = "Match"; + const string VoidMatchMethodName = "Switch"; + const string MatchMethodName = "Match"; - public static (string filename, string source) Emit(UnionTypeSchema unionTypeSchema, - Action reportDiagnostic, CancellationToken cancellationToken) - { - var builder = new CSharpBuilder(); - builder.WriteLine("#pragma warning disable 1591"); + public static (string filename, string source) Emit(UnionTypeSchema unionTypeSchema, + Action reportDiagnostic, CancellationToken cancellationToken) + { + var builder = new CSharpBuilder(); + builder.WriteLine("#pragma warning disable 1591"); - //builder.WriteLine($"//Generator runs: {RunCount.Increase(unionTypeSchema.FullTypeName)}"); + //builder.WriteLine($"//Generator runs: {RunCount.Increase(unionTypeSchema.FullTypeName)}"); - using (unionTypeSchema.Namespace != null ? builder.Namespace(unionTypeSchema.Namespace) : null) - { - if (unionTypeSchema.TypeParameters.Length > 0) - { - if (unionTypeSchema.IsPartial) - { - WritePartialWithMatchMethods(unionTypeSchema, builder); - } - } - else - { - WriteMatchExtension(unionTypeSchema, builder); - } + using (unionTypeSchema.Namespace != null ? builder.Namespace(unionTypeSchema.Namespace) : null) + { + WriteMatchExtension(unionTypeSchema, builder); if (unionTypeSchema is { IsPartial: true, StaticFactoryInfo: not null }) { @@ -41,133 +31,112 @@ public static (string filename, string source) Emit(UnionTypeSchema unionTypeSch } } - builder.WriteLine("#pragma warning restore 1591"); - return (unionTypeSchema.FullTypeName.ToMatchExtensionFilename(unionTypeSchema.TypeParameters), builder.ToString()); - } + builder.WriteLine("#pragma warning restore 1591"); + return (unionTypeSchema.FullTypeName.ToMatchExtensionFilename(unionTypeSchema.TypeParameters), builder.ToString()); + } private static void WriteMatchExtension(UnionTypeSchema unionTypeSchema, CSharpBuilder builder) - { - using (builder.StaticPartialClass($"{unionTypeSchema.TypeName.Replace(".", "_")}MatchExtension", - unionTypeSchema.IsInternal ? "internal" : "public")) - { - var thisTaskParameter = ThisParameter(unionTypeSchema, $"global::System.Threading.Tasks.Task<{unionTypeSchema.FullTypeName}>"); - var caseParameters = unionTypeSchema.Cases.Select(c => c.ParameterName).ToSeparatedString(); - - void WriteBodyForTaskExtension(string matchMethodName) => builder.WriteLine( - $"(await {thisTaskParameter.Name}.ConfigureAwait(false)).{matchMethodName}({caseParameters});"); - - void WriteBodyForAsyncTaskExtension(string matchMethodName) => builder.WriteLine( - $"await (await {thisTaskParameter.Name}.ConfigureAwait(false)).{matchMethodName}({caseParameters}).ConfigureAwait(false);"); - - GenerateMatchMethod(builder, unionTypeSchema, "T"); - BlankLine(); - GenerateMatchMethod(builder, unionTypeSchema, "global::System.Threading.Tasks.Task"); - BlankLine(); - - WriteMatchSignature(builder, unionTypeSchema, thisTaskParameter, "global::System.Threading.Tasks.Task", "T", "public static async"); - WriteBodyForTaskExtension(MatchMethodName); - BlankLine(); - WriteMatchSignature(builder, unionTypeSchema, thisTaskParameter, "global::System.Threading.Tasks.Task", handlerReturnType: "global::System.Threading.Tasks.Task", - "public static async"); - WriteBodyForAsyncTaskExtension(MatchMethodName); - BlankLine(); - - GenerateSwitchMethod(builder, unionTypeSchema, false); - BlankLine(); - GenerateSwitchMethod(builder, unionTypeSchema, true); - BlankLine(); - WriteSwitchSignature(builder: builder, unionTypeSchema: unionTypeSchema, thisParameter: thisTaskParameter, - isAsync: false, asyncReturn: true, lambda: true); - WriteBodyForTaskExtension(VoidMatchMethodName); - BlankLine(); - WriteSwitchSignature(builder: builder, unionTypeSchema: unionTypeSchema, thisParameter: thisTaskParameter, - isAsync: true, lambda: true); - WriteBodyForAsyncTaskExtension(VoidMatchMethodName); - } + { + using (builder.StaticPartialClass($"{unionTypeSchema.TypeName.Replace(".", "_")}MatchExtension", + unionTypeSchema.IsInternal ? "internal" : "public")) + { + var thisTaskParameter = ThisParameter(unionTypeSchema, $"global::System.Threading.Tasks.Task<{unionTypeSchema.FullTypeNameWithTypeParameters}>"); + var caseParameters = unionTypeSchema.Cases.Select(c => c.ParameterName).ToSeparatedString(); - return; + void WriteBodyForTaskExtension(string matchMethodName) => builder.WriteLine( + $"(await {thisTaskParameter.Name}.ConfigureAwait(false)).{matchMethodName}({caseParameters});"); - void BlankLine() - { - builder.WriteLine(""); - } - } + void WriteBodyForAsyncTaskExtension(string matchMethodName) => builder.WriteLine( + $"await (await {thisTaskParameter.Name}.ConfigureAwait(false)).{matchMethodName}({caseParameters}).ConfigureAwait(false);"); - private static void WritePartialWithMatchMethods(UnionTypeSchema unionTypeSchema, CSharpBuilder builder) - { - var typeParameters = RoslynExtensions.FormatTypeParameters(unionTypeSchema.TypeParameters); - var unusedTypeParameter = GetUnusedTypeParameter(unionTypeSchema.TypeParameters); - var typeKind = GetTypeKind(unionTypeSchema); - builder.WriteLine($"{(unionTypeSchema.Modifiers.ToSeparatedString(" "))} {typeKind} {unionTypeSchema.TypeName}{typeParameters}"); - using (builder.Scope()) - { - GenerateMatchMethod(builder, unionTypeSchema, returnType: unusedTypeParameter, t: unusedTypeParameter, asExtension: false); + var typeParameter = GetUnusedTypeParameter(unionTypeSchema.TypeParameters); + var taskOfTypeParameter = $"global::System.Threading.Tasks.Task<{typeParameter}>"; + + GenerateMatchMethod(builder, unionTypeSchema, typeParameter, typeParameter); BlankLine(); - GenerateSwitchMethod(builder, unionTypeSchema, isAsync: false, asExtension: false); + GenerateMatchMethod(builder, unionTypeSchema, taskOfTypeParameter, typeParameter); BlankLine(); - GenerateSwitchMethod(builder, unionTypeSchema, isAsync: true, asExtension: false); + + WriteMatchSignature(builder, unionTypeSchema, thisTaskParameter, taskOfTypeParameter, typeParameter, "public static async", typeParameter); + WriteBodyForTaskExtension(MatchMethodName); + BlankLine(); + WriteMatchSignature(builder, unionTypeSchema, thisTaskParameter, taskOfTypeParameter, handlerReturnType: taskOfTypeParameter, "public static async", typeParameter); + WriteBodyForAsyncTaskExtension(MatchMethodName); + BlankLine(); + + GenerateSwitchMethod(builder, unionTypeSchema, false); + BlankLine(); + GenerateSwitchMethod(builder, unionTypeSchema, true); + BlankLine(); + WriteSwitchSignature(builder: builder, unionTypeSchema: unionTypeSchema, thisParameter: thisTaskParameter, + isAsync: false, asyncReturn: true, lambda: true); + WriteBodyForTaskExtension(VoidMatchMethodName); + BlankLine(); + WriteSwitchSignature(builder: builder, unionTypeSchema: unionTypeSchema, thisParameter: thisTaskParameter, + isAsync: true, lambda: true); + WriteBodyForAsyncTaskExtension(VoidMatchMethodName); } return; void BlankLine() { - builder.WriteLine(""); + builder.WriteLine(""); } } - private static void WritePartialWithStaticFactories(UnionTypeSchema unionTypeSchema, CSharpBuilder builder) - { - var info = unionTypeSchema.StaticFactoryInfo!; + static void WritePartialWithStaticFactories(UnionTypeSchema unionTypeSchema, CSharpBuilder builder) + { + var info = unionTypeSchema.StaticFactoryInfo!; var typeParameters = RoslynExtensions.FormatTypeParameters(unionTypeSchema.TypeParameters); var typeKind = GetTypeKind(unionTypeSchema); builder.WriteLine($"{(unionTypeSchema.Modifiers.ToSeparatedString(" "))} {typeKind} {unionTypeSchema.TypeName}{typeParameters}"); - using (builder.Scope()) - { - foreach (var derivedType in unionTypeSchema.Cases) - { - var nameParts = derivedType.FullTypeName.Split('.'); - var derivedTypeName = nameParts[nameParts.Length - 1]; - var methodName = derivedType.StaticFactoryMethodName; - - if ($"{unionTypeSchema.FullTypeNameWithTypeParameters}.{methodName}" == derivedType.FullTypeName) //union case is nested type without underscores, so factory method name would conflict with type name - continue; - - var constructors = derivedType.Constructors; - if (constructors.Length == 0) - constructors = new[] - { - new MemberInfo($"{derivedTypeName}", - ImmutableArray.Empty.Add("public"), - ImmutableArray.Empty) - }.ToImmutableArray(); - - foreach (var constructor in constructors) - { - var isPublic = constructor.Modifiers.HasModifier(SyntaxKind.PublicKeyword); - var isInternal = !isPublic && constructor.Modifiers.HasModifier(SyntaxKind.InternalKeyword); - - if (!isInternal && !isPublic) - continue; //constructor inaccessible - - if (info.ExistingStaticFields.Contains(methodName)) - continue; //name conflict with existing field - - if (info.ExistingStaticMethods.Any(s => - s.Name == methodName && - s.Parameters.Select(p => p.Type) - .SequenceEqual(constructor.Parameters.Select(p => p.Type), SymbolEqualityComparer.Default))) - continue; //static method already exists - - var arguments = constructor.Parameters.ToSeparatedString(); - var constructorInvocation = $"new {derivedType.FullTypeName}({(constructor.Parameters.Select(p => p.Name).ToSeparatedString())})"; - builder.WriteLine($"{(isInternal ? "internal" : "public")} static {unionTypeSchema.FullTypeName}{typeParameters} {methodName}({arguments}) => {constructorInvocation};"); - } - } - } - } + using (builder.Scope()) + { + foreach (var derivedType in unionTypeSchema.Cases) + { + var nameParts = derivedType.FullTypeName.Split('.'); + var derivedTypeName = nameParts[nameParts.Length - 1]; + var methodName = derivedType.StaticFactoryMethodName; + + if ($"{unionTypeSchema.FullTypeNameWithTypeParameters}.{methodName}" == derivedType.FullTypeName) //union case is nested type without underscores, so factory method name would conflict with type name + continue; + + var constructors = derivedType.Constructors; + if (constructors.Length == 0) + constructors = new[] + { + new MemberInfo($"{derivedTypeName}", + ImmutableArray.Empty.Add("public"), + ImmutableArray.Empty) + }.ToImmutableArray(); + + foreach (var constructor in constructors) + { + var isPublic = constructor.Modifiers.HasModifier(SyntaxKind.PublicKeyword); + var isInternal = !isPublic && constructor.Modifiers.HasModifier(SyntaxKind.InternalKeyword); + + if (!isInternal && !isPublic) + continue; //constructor inaccessible + + if (info.ExistingStaticFields.Contains(methodName)) + continue; //name conflict with existing field + + if (info.ExistingStaticMethods.Any(s => + s.Name == methodName && + s.Parameters.Select(p => p.Type) + .SequenceEqual(constructor.Parameters.Select(p => p.Type), SymbolEqualityComparer.Default))) + continue; //static method already exists + + var arguments = constructor.Parameters.ToSeparatedString(); + var constructorInvocation = $"new {derivedType.FullTypeName}({(constructor.Parameters.Select(p => p.Name).ToSeparatedString())})"; + builder.WriteLine($"{(isInternal ? "internal" : "public")} static {unionTypeSchema.FullTypeName}{typeParameters} {methodName}({arguments}) => {constructorInvocation};"); + } + } + } + } private static string GetTypeKind(UnionTypeSchema unionTypeSchema) { @@ -175,77 +144,81 @@ private static string GetTypeKind(UnionTypeSchema unionTypeSchema) return typeKind; } - static void GenerateMatchMethod(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, string returnType, string t = "T", bool asExtension = true) - { - var thisParameterType = unionTypeSchema.FullTypeName; - var thisParameter = ThisParameter(unionTypeSchema, thisParameterType); - var thisParameterName = thisParameter.Name; - var thisStatement = asExtension ? thisParameterName : "this"; - WriteMatchSignature( + static void GenerateMatchMethod(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, string returnType, string t = "T") + { + var thisParameterType = unionTypeSchema.FullTypeNameWithTypeParameters; + var thisParameter = ThisParameter(unionTypeSchema, thisParameterType); + var thisParameterName = thisParameter.Name; + + WriteMatchSignature( builder: builder, unionTypeSchema: unionTypeSchema, - thisParameter: asExtension ? thisParameter : null, + thisParameter: thisParameter, returnType: returnType, - t: t, - modifiers: asExtension ? "public static" : "public"); - builder.WriteLine($"{thisStatement} switch"); - using (builder.ScopeWithSemicolon()) - { - var caseIndex = 0; - foreach (var c in unionTypeSchema.Cases) - { - caseIndex++; + t: t, + modifiers: "public static"); + builder.WriteLine($"{thisParameterName} switch"); + using (builder.ScopeWithSemicolon()) + { + var caseIndex = 0; + foreach (var c in unionTypeSchema.Cases) + { + caseIndex++; var caseVariableName = $"{c.ParameterName}{caseIndex}"; - builder.WriteLine($"{c.FullTypeName} {caseVariableName} => {c.ParameterName}({caseVariableName}),"); - } - - builder.WriteLine( - $"_ => throw new global::System.ArgumentException($\"Unknown type derived from {unionTypeSchema.FullTypeName}: {{{thisStatement}.GetType().Name}}\")"); - } - } - - static void GenerateSwitchMethod(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, bool isAsync, bool asExtension = true) - { - var thisParameterType = unionTypeSchema.FullTypeName; - var thisParameter = ThisParameter(unionTypeSchema, thisParameterType); - var thisParameterName = thisParameter.Name; - var thisStatement = asExtension ? thisParameterName : "this"; - WriteSwitchSignature(builder, unionTypeSchema, asExtension ? thisParameter : null, isAsync, modifiers: asExtension ? "public static" : "public"); - using (builder.Scope()) - { - builder.WriteLine($"switch ({thisStatement})"); - using (builder.Scope()) - { - var caseIndex = 0; - foreach (var c in unionTypeSchema.Cases) - { - caseIndex++; + builder.WriteLine($"{c.FullTypeName} {caseVariableName} => {c.ParameterName}({caseVariableName}),"); + } + + builder.WriteLine( + $"_ => throw new global::System.ArgumentException($\"Unknown type derived from {unionTypeSchema.FullTypeName}: {{{thisParameterName}.GetType().Name}}\")"); + } + } + + static void GenerateSwitchMethod(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, bool isAsync) + { + var thisParameterType = unionTypeSchema.FullTypeNameWithTypeParameters; + var thisParameter = ThisParameter(unionTypeSchema, thisParameterType); + var thisParameterName = thisParameter.Name; + WriteSwitchSignature(builder, unionTypeSchema, thisParameter, isAsync, "public static"); + using (builder.Scope()) + { + builder.WriteLine($"switch ({thisParameterName})"); + using (builder.Scope()) + { + var caseIndex = 0; + foreach (var c in unionTypeSchema.Cases) + { + caseIndex++; var caseVariableName = $"{c.ParameterName}{caseIndex}"; - builder.WriteLine($"case {c.FullTypeName} {caseVariableName}:"); - using (builder.Indent()) - { - var call = $"{c.ParameterName}({caseVariableName})"; - if (isAsync) - call = $"await {call}.ConfigureAwait(false)"; - builder.WriteLine($"{call};"); - builder.WriteLine("break;"); - } - } - - builder.WriteLine("default:"); - using (builder.Indent()) - { - builder.WriteLine($"throw new global::System.ArgumentException($\"Unknown type derived from {unionTypeSchema.FullTypeName}: {{{thisStatement}.GetType().Name}}\");"); - } - } - } + builder.WriteLine($"case {c.FullTypeName} {caseVariableName}:"); + using (builder.Indent()) + { + var call = $"{c.ParameterName}({caseVariableName})"; + if (isAsync) + call = $"await {call}.ConfigureAwait(false)"; + builder.WriteLine($"{call};"); + builder.WriteLine("break;"); + } + } + + builder.WriteLine("default:"); + using (builder.Indent()) + { + builder.WriteLine($"throw new global::System.ArgumentException($\"Unknown type derived from {unionTypeSchema.FullTypeName}: {{{thisParameterName}.GetType().Name}}\");"); + } + } + } } - private static string GetUnusedTypeParameter(EquatableArray typeParameters) + static string GetUnusedTypeParameter(EquatableArray typeParameters) { - return Enumerable.Range(0, 20) - .Select(i => Check(new string('_', i) + "TMatchResult")) - .FirstOrDefault(s => s is not null) ?? "T" + Guid.NewGuid().ToString("N"); + return + new[] { "T" } + .Concat( + Enumerable.Range(0, 20) + .Select(i => new string('_', i) + "TMatchResult") + ) + .Select(Check) + .FirstOrDefault(s => s is not null) ?? "T" + Guid.NewGuid().ToString("N"); string? Check(string typeName) { @@ -260,47 +233,49 @@ private static string GetUnusedTypeParameter(EquatableArray typeParamete static Parameter ThisParameter(UnionTypeSchema unionTypeSchema, string thisParameterType) => new($"this {thisParameterType}", unionTypeSchema.TypeName.ToParameterName()); - static void WriteMatchSignature(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, - Parameter? thisParameter, string returnType, string? handlerReturnType = null, string modifiers = "public static", string t = "T") - { - handlerReturnType ??= returnType; - var handlerParameters = unionTypeSchema.Cases - .Select(c => new Parameter($"global::System.Func<{c.FullTypeName}, {handlerReturnType}>", c.ParameterName)); + static void WriteMatchSignature(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, Parameter thisParameter, string returnType, string? handlerReturnType = null, string modifiers = "public static", string t = "T") + { + handlerReturnType ??= returnType; + var handlerParameters = unionTypeSchema.Cases + .Select(c => new Parameter($"global::System.Func<{c.FullTypeName}, {handlerReturnType}>", c.ParameterName)); + + handlerParameters = handlerParameters.Prepend(thisParameter); + + var typeParameterList = unionTypeSchema.TypeParameters.Concat([t]).ToSeparatedString(); - if (thisParameter is not null) - { - handlerParameters = handlerParameters.Prepend(thisParameter); - } builder.WriteMethodSignature( - modifiers: modifiers, - returnType: returnType, - methodName: "Match<" + t + ">", parameters: handlerParameters, - lambda: true); - } - - static void WriteSwitchSignature(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, - Parameter? thisParameter, bool isAsync, string modifiers = "public static", bool? asyncReturn = null, bool lambda = false) - { - var returnType = asyncReturn ?? isAsync ? "async global::System.Threading.Tasks.Task" : "void"; - var handlerParameters = unionTypeSchema.Cases - .Select(c => - { - var parameterType = isAsync - ? $"global::System.Func<{c.FullTypeName}, global::System.Threading.Tasks.Task>" - : $"global::System.Action<{c.FullTypeName}>"; - return new Parameter( - parameterType, - c.ParameterName); - }); + modifiers: modifiers, + returnType: returnType, + methodName: "Match<" + typeParameterList + ">", parameters: handlerParameters, + lambda: true); + } + + static void WriteSwitchSignature(CSharpBuilder builder, UnionTypeSchema unionTypeSchema, + Parameter? thisParameter, bool isAsync, string modifiers = "public static", bool? asyncReturn = null, bool lambda = false) + { + var returnType = asyncReturn ?? isAsync ? "async global::System.Threading.Tasks.Task" : "void"; + var handlerParameters = unionTypeSchema.Cases + .Select(c => + { + var parameterType = isAsync + ? $"global::System.Func<{c.FullTypeName}, global::System.Threading.Tasks.Task>" + : $"global::System.Action<{c.FullTypeName}>"; + return new Parameter( + parameterType, + c.ParameterName); + }); if (thisParameter is not null) { handlerParameters = handlerParameters.Prepend(thisParameter); } + var typeParameters = RoslynExtensions.FormatTypeParameters(unionTypeSchema.TypeParameters); + builder.WriteMethodSignature( - modifiers: modifiers, - returnType: returnType, - methodName: VoidMatchMethodName, parameters: handlerParameters, - lambda: lambda); - } + modifiers: modifiers, + returnType: returnType, + methodName: VoidMatchMethodName + typeParameters, + parameters: handlerParameters, + lambda: lambda); + } } \ No newline at end of file diff --git a/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.UnionTypeGenerator/FunicularSwitchGeneratorsConsumerGenericResultOfT_TFailureMatchExtension.g.cs b/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.UnionTypeGenerator/FunicularSwitchGeneratorsConsumerGenericResultOfT_TFailureMatchExtension.g.cs new file mode 100644 index 0000000..04d9455 --- /dev/null +++ b/Source/Tests/FunicularSwitch.Generators.Consumer/Generated/FunicularSwitch.Generators/FunicularSwitch.Generators.UnionTypeGenerator/FunicularSwitchGeneratorsConsumerGenericResultOfT_TFailureMatchExtension.g.cs @@ -0,0 +1,71 @@ +#pragma warning disable 1591 +namespace FunicularSwitch.Generators.Consumer +{ + public static partial class GenericResultMatchExtension + { + public static TMatchResult Match(this FunicularSwitch.Generators.Consumer.GenericResult genericResult, global::System.Func.Ok_, TMatchResult> ok, global::System.Func.Error_, TMatchResult> error) => + genericResult switch + { + FunicularSwitch.Generators.Consumer.GenericResult.Ok_ ok1 => ok(ok1), + FunicularSwitch.Generators.Consumer.GenericResult.Error_ error2 => error(error2), + _ => throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Generators.Consumer.GenericResult: {genericResult.GetType().Name}") + }; + + public static global::System.Threading.Tasks.Task Match(this FunicularSwitch.Generators.Consumer.GenericResult genericResult, global::System.Func.Ok_, global::System.Threading.Tasks.Task> ok, global::System.Func.Error_, global::System.Threading.Tasks.Task> error) => + genericResult switch + { + FunicularSwitch.Generators.Consumer.GenericResult.Ok_ ok1 => ok(ok1), + FunicularSwitch.Generators.Consumer.GenericResult.Error_ error2 => error(error2), + _ => throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Generators.Consumer.GenericResult: {genericResult.GetType().Name}") + }; + + public static async global::System.Threading.Tasks.Task Match(this global::System.Threading.Tasks.Task> genericResult, global::System.Func.Ok_, TMatchResult> ok, global::System.Func.Error_, TMatchResult> error) => + (await genericResult.ConfigureAwait(false)).Match(ok, error); + + public static async global::System.Threading.Tasks.Task Match(this global::System.Threading.Tasks.Task> genericResult, global::System.Func.Ok_, global::System.Threading.Tasks.Task> ok, global::System.Func.Error_, global::System.Threading.Tasks.Task> error) => + await (await genericResult.ConfigureAwait(false)).Match(ok, error).ConfigureAwait(false); + + public static void Switch(this FunicularSwitch.Generators.Consumer.GenericResult genericResult, global::System.Action.Ok_> ok, global::System.Action.Error_> error) + { + switch (genericResult) + { + case FunicularSwitch.Generators.Consumer.GenericResult.Ok_ ok1: + ok(ok1); + break; + case FunicularSwitch.Generators.Consumer.GenericResult.Error_ error2: + error(error2); + break; + default: + throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Generators.Consumer.GenericResult: {genericResult.GetType().Name}"); + } + } + + public static async global::System.Threading.Tasks.Task Switch(this FunicularSwitch.Generators.Consumer.GenericResult genericResult, global::System.Func.Ok_, global::System.Threading.Tasks.Task> ok, global::System.Func.Error_, global::System.Threading.Tasks.Task> error) + { + switch (genericResult) + { + case FunicularSwitch.Generators.Consumer.GenericResult.Ok_ ok1: + await ok(ok1).ConfigureAwait(false); + break; + case FunicularSwitch.Generators.Consumer.GenericResult.Error_ error2: + await error(error2).ConfigureAwait(false); + break; + default: + throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Generators.Consumer.GenericResult: {genericResult.GetType().Name}"); + } + } + + public static async global::System.Threading.Tasks.Task Switch(this global::System.Threading.Tasks.Task> genericResult, global::System.Action.Ok_> ok, global::System.Action.Error_> error) => + (await genericResult.ConfigureAwait(false)).Switch(ok, error); + + public static async global::System.Threading.Tasks.Task Switch(this global::System.Threading.Tasks.Task> genericResult, global::System.Func.Ok_, global::System.Threading.Tasks.Task> ok, global::System.Func.Error_, global::System.Threading.Tasks.Task> error) => + await (await genericResult.ConfigureAwait(false)).Switch(ok, error).ConfigureAwait(false); + } + + public abstract partial record GenericResult + { + public static FunicularSwitch.Generators.Consumer.GenericResult Ok(T Value) => new FunicularSwitch.Generators.Consumer.GenericResult.Ok_(Value); + public static FunicularSwitch.Generators.Consumer.GenericResult Error(TFailure Failure) => new FunicularSwitch.Generators.Consumer.GenericResult.Error_(Failure); + } +} +#pragma warning restore 1591 diff --git a/Source/Tests/FunicularSwitch.Generators.Consumer/GeneratorSpecs.cs b/Source/Tests/FunicularSwitch.Generators.Consumer/GeneratorSpecs.cs index abfe9bd..63025aa 100644 --- a/Source/Tests/FunicularSwitch.Generators.Consumer/GeneratorSpecs.cs +++ b/Source/Tests/FunicularSwitch.Generators.Consumer/GeneratorSpecs.cs @@ -17,176 +17,176 @@ namespace FunicularSwitch.Generators.Consumer; [TestClass] public class When_using_generated_result_type { - [TestMethod] - public void Then_it_feels_good() - { - static OperationResult Divide(decimal i, decimal divisor) => divisor == 0 - ? Error(Error.Generic("Division by zero")) - : i / divisor; - - OperationResult result = 42; - - var calc = result - .Bind(i => Divide(i, 0)) - .Map(i => (i * 2).ToString(CultureInfo.InvariantCulture)); - - calc.Should().BeEquivalentTo(OperationResult.Error(Error.Generic("Division by zero"))); - - var combinedError = calc.Aggregate(Error(Error.NotFound())); - var combinedErrorStatic = Aggregate(calc, Error(Error.NotFound()), (_, i) => i); - var combinedOk = Ok(42).Aggregate(Ok(" is the answer")); - var combinedOkStatic = Aggregate(Ok(42), Ok(" is the answer")); - - var transformedToInt = combinedOkStatic.As(() => Error.Generic("Unexpected type")); - - static IEnumerable IsGreaterThanFive(int i) - { - if (i <= 5) - yield return Error.Generic("To small"); - if (i == 3) - yield return Error.Generic("Uuh, it's three..."); - } - - (3.Validate(IsGreaterThanFive) is OperationResult.Error_ - { - Details: Error.Aggregated_ - { - Errors: - { - Count: 2 - } - } - }).Should().BeTrue(); - } - - [TestMethod] - public async Task Void_switches_are_generated() - { - var error = Error.Generic("This is wrong"); - - static void DoNothing(T item) { } - error.Switch( - generic: DoNothing, - notFound: DoNothing, - notAuthorized: DoNothing, - aggregated: DoNothing - ); - - static Task DoNothingAsync(T item) => Task.CompletedTask; - await error.Switch( - generic: DoNothingAsync, - notFound: DoNothingAsync, - notAuthorized: DoNothingAsync, - aggregated: DoNothingAsync - ); - } - - [TestMethod] - public async Task ExceptionsAreTurnedIntoErrors() - { - var ok = Ok(42); - - // ReSharper disable once IntDivisionByZero - var result = ok.Map(i => i / 0); - result.IsError.Should().BeTrue(); - - // ReSharper disable once IntDivisionByZero - result = await ok.Map(async i => - { - await Task.Delay(100); - return i / 0; - }); - result.IsError.Should().BeTrue(); - - 42.Validate(BuggyValidate).IsError.Should().BeTrue(); - - static IEnumerable BuggyValidate(int number) => throw new InvalidOperationException("Boom"); - - } - - [TestMethod] - public void QueryExpressionSelect() - { - Result subject = 42; - var result = - from r in subject - select r; - result.Should().BeEquivalentTo(Result.Ok(42)); - } - - [TestMethod] - public void QueryExpressionSelectMany() - { - Result ok = 42; - var error = Result.Error("fail"); - - ( - from r in ok - from r1 in error - select r1 - ).Should().BeEquivalentTo(error); - - ( - from r in error - from r1 in ok - select r1 - ).Should().BeEquivalentTo(error); - - ( - from r in ok - let x = r * 2 - from r1 in ok - select x - ).Should().BeEquivalentTo(ok.Map(r => r * 2)); - } - - [TestMethod] - public async Task QueryExpressionSelectManyAsync() - { - Task> okAsync = Task.FromResult(Result.Ok(42)); - var errorAsync = Task.FromResult(Result.Error("fail")); - - var ok = Result.Ok(1); - - (await ( - from r in okAsync - from r1 in errorAsync - select r1 - )).Should().BeEquivalentTo(await errorAsync); - - (await ( - from r in errorAsync - from r1 in okAsync - select r1 - )).Should().BeEquivalentTo(await errorAsync); - - (await ( - from r in okAsync - let x = r * 2 - from r1 in okAsync - select x - )).Should().BeEquivalentTo(await okAsync.Map(r => r * 2)); - - (await ( - from r in ok - let x = r * 2 - from r1 in okAsync - select x - )).Should().BeEquivalentTo( ok.Map(r => r * 2)); - - (await ( - from r in okAsync - let x = r * 2 - from r1 in ok - select x - )).Should().BeEquivalentTo(await okAsync.Map(r => r * 2)); - } + [TestMethod] + public void Then_it_feels_good() + { + static OperationResult Divide(decimal i, decimal divisor) => divisor == 0 + ? Error(Error.Generic("Division by zero")) + : i / divisor; + + OperationResult result = 42; + + var calc = result + .Bind(i => Divide(i, 0)) + .Map(i => (i * 2).ToString(CultureInfo.InvariantCulture)); + + calc.Should().BeEquivalentTo(OperationResult.Error(Error.Generic("Division by zero"))); + + var combinedError = calc.Aggregate(Error(Error.NotFound())); + var combinedErrorStatic = Aggregate(calc, Error(Error.NotFound()), (_, i) => i); + var combinedOk = Ok(42).Aggregate(Ok(" is the answer")); + var combinedOkStatic = Aggregate(Ok(42), Ok(" is the answer")); + + var transformedToInt = combinedOkStatic.As(() => Error.Generic("Unexpected type")); + + static IEnumerable IsGreaterThanFive(int i) + { + if (i <= 5) + yield return Error.Generic("To small"); + if (i == 3) + yield return Error.Generic("Uuh, it's three..."); + } + + (3.Validate(IsGreaterThanFive) is OperationResult.Error_ + { + Details: Error.Aggregated_ + { + Errors: + { + Count: 2 + } + } + }).Should().BeTrue(); + } + + [TestMethod] + public async Task Void_switches_are_generated() + { + var error = Error.Generic("This is wrong"); + + static void DoNothing(T item) { } + error.Switch( + generic: DoNothing, + notFound: DoNothing, + notAuthorized: DoNothing, + aggregated: DoNothing + ); + + static Task DoNothingAsync(T item) => Task.CompletedTask; + await error.Switch( + generic: DoNothingAsync, + notFound: DoNothingAsync, + notAuthorized: DoNothingAsync, + aggregated: DoNothingAsync + ); + } + + [TestMethod] + public async Task ExceptionsAreTurnedIntoErrors() + { + var ok = Ok(42); + + // ReSharper disable once IntDivisionByZero + var result = ok.Map(i => i / 0); + result.IsError.Should().BeTrue(); + + // ReSharper disable once IntDivisionByZero + result = await ok.Map(async i => + { + await Task.Delay(100); + return i / 0; + }); + result.IsError.Should().BeTrue(); + + 42.Validate(BuggyValidate).IsError.Should().BeTrue(); + + static IEnumerable BuggyValidate(int number) => throw new InvalidOperationException("Boom"); + + } + + [TestMethod] + public void QueryExpressionSelect() + { + Result subject = 42; + var result = + from r in subject + select r; + result.Should().BeEquivalentTo(Result.Ok(42)); + } + + [TestMethod] + public void QueryExpressionSelectMany() + { + Result ok = 42; + var error = Result.Error("fail"); + + ( + from r in ok + from r1 in error + select r1 + ).Should().BeEquivalentTo(error); + + ( + from r in error + from r1 in ok + select r1 + ).Should().BeEquivalentTo(error); + + ( + from r in ok + let x = r * 2 + from r1 in ok + select x + ).Should().BeEquivalentTo(ok.Map(r => r * 2)); + } + + [TestMethod] + public async Task QueryExpressionSelectManyAsync() + { + Task> okAsync = Task.FromResult(Result.Ok(42)); + var errorAsync = Task.FromResult(Result.Error("fail")); + + var ok = Result.Ok(1); + + (await ( + from r in okAsync + from r1 in errorAsync + select r1 + )).Should().BeEquivalentTo(await errorAsync); + + (await ( + from r in errorAsync + from r1 in okAsync + select r1 + )).Should().BeEquivalentTo(await errorAsync); + + (await ( + from r in okAsync + let x = r * 2 + from r1 in okAsync + select x + )).Should().BeEquivalentTo(await okAsync.Map(r => r * 2)); + + (await ( + from r in ok + let x = r * 2 + from r1 in okAsync + select x + )).Should().BeEquivalentTo(ok.Map(r => r * 2)); + + (await ( + from r in okAsync + let x = r * 2 + from r1 in ok + select x + )).Should().BeEquivalentTo(await okAsync.Map(r => r * 2)); + } [TestMethod] public void TestFactoryMethodsForClassesWithPrimaryConstructor() { var x = WithPrimaryConstructor.Derived("Hallo", 42); - Console.WriteLine($"Created {x.Match(d => $"{d.Message} {d.Test}")}"); + Console.WriteLine($"Created {x.Match(d => $"{d.Message} {d.Test}")}"); } [TestMethod] @@ -208,84 +208,84 @@ abstract partial class OperationResult public static partial class ErrorExtension { - [MergeError] - public static Error MergeErrors(this Error error, Error other) => error.Merge(other); + [MergeError] + public static Error MergeErrors(this Error error, Error other) => error.Merge(other); - [MergeError] - public static string MergeErrors(this string error, string other) => $"{error}{Environment.NewLine}{other}"; + [MergeError] + public static string MergeErrors(this string error, string other) => $"{error}{Environment.NewLine}{other}"; - [ExceptionToError] - public static string UnexpectedToStringError(Exception exception) => $"Unexpected error occurred: {exception}"; + [ExceptionToError] + public static string UnexpectedToStringError(Exception exception) => $"Unexpected error occurred: {exception}"; } [UnionType(CaseOrder = CaseOrder.AsDeclared)] public abstract partial class Error { - [ExceptionToError] - public static Error Generic(Exception exception) => Generic(exception.ToString()); - - public Error Merge(Error other) => this is Aggregated_ a - ? a.Add(other) - : other is Aggregated_ oa - ? oa.Add(this) - : Aggregated(ImmutableList.Create(this, other)); - - public class Generic_ : Error - { - public string Message { get; } - - public Generic_(string message) : base(UnionCases.Generic) - { - Message = message; - } - } - - public class NotFound_ : Error - { - public NotFound_() : base(UnionCases.NotFound) - { - } - } - - public class NotAuthorized_ : Error - { - public NotAuthorized_() : base(UnionCases.NotAuthorized) - { - } - } - - public class Aggregated_ : Error - { - public ImmutableList Errors { get; } - - public Aggregated_(ImmutableList errors) : base(UnionCases.Aggregated) => Errors = errors; - - public Error Add(Error other) => Aggregated(Errors.Add(other)); - } - - internal enum UnionCases - { - Generic, - NotFound, - NotAuthorized, - Aggregated - } - - internal UnionCases UnionCase { get; } - Error(UnionCases unionCase) => UnionCase = unionCase; - - public override string ToString() => Enum.GetName(typeof(UnionCases), UnionCase) ?? UnionCase.ToString(); - bool Equals(Error other) => UnionCase == other.UnionCase; - - public override bool Equals(object? obj) - { - if (ReferenceEquals(null, obj)) return false; - if (ReferenceEquals(this, obj)) return true; - if (obj.GetType() != GetType()) return false; - return Equals((Error)obj); - } - - public override int GetHashCode() => (int)UnionCase; + [ExceptionToError] + public static Error Generic(Exception exception) => Generic(exception.ToString()); + + public Error Merge(Error other) => this is Aggregated_ a + ? a.Add(other) + : other is Aggregated_ oa + ? oa.Add(this) + : Aggregated(ImmutableList.Create(this, other)); + + public class Generic_ : Error + { + public string Message { get; } + + public Generic_(string message) : base(UnionCases.Generic) + { + Message = message; + } + } + + public class NotFound_ : Error + { + public NotFound_() : base(UnionCases.NotFound) + { + } + } + + public class NotAuthorized_ : Error + { + public NotAuthorized_() : base(UnionCases.NotAuthorized) + { + } + } + + public class Aggregated_ : Error + { + public ImmutableList Errors { get; } + + public Aggregated_(ImmutableList errors) : base(UnionCases.Aggregated) => Errors = errors; + + public Error Add(Error other) => Aggregated(Errors.Add(other)); + } + + internal enum UnionCases + { + Generic, + NotFound, + NotAuthorized, + Aggregated + } + + internal UnionCases UnionCase { get; } + Error(UnionCases unionCase) => UnionCase = unionCase; + + public override string ToString() => Enum.GetName(typeof(UnionCases), UnionCase) ?? UnionCase.ToString(); + bool Equals(Error other) => UnionCase == other.UnionCase; + + public override bool Equals(object? obj) + { + if (ReferenceEquals(null, obj)) return false; + if (ReferenceEquals(this, obj)) return true; + if (obj.GetType() != GetType()) return false; + return Equals((Error)obj); + } + + public override int GetHashCode() => (int)UnionCase; } [UnionType] @@ -351,3 +351,46 @@ static string Print(CardType cardType) => ); } } + +[UnionType(CaseOrder = CaseOrder.AsDeclared)] +public abstract partial record GenericResult(bool IsOk) +{ + public record Ok_(T Value) : GenericResult(true); + public record Error_(TFailure Failure) : GenericResult(false); +} + +[TestClass] +public class TestGenericResult +{ + [TestMethod] + public async Task MatchIt() + { + var okResult = GenericResult.Ok(42); + var errorResult = GenericResult.Error("Ups..."); + var taskOkResult = Task.FromResult(okResult); + var taskErrorResult = Task.FromResult(errorResult); + + okResult.Match(ok => ok.Value, _ => 0).Should().Be(42); + (await taskOkResult.Match(ok => ok.Value, _ => 0)).Should().Be(42); + errorResult.Match(ok => ok.Value, _ => 0).Should().Be(0); + (await taskErrorResult.Match(ok => ok.Value, _ => 0)).Should().Be(0); + (await taskErrorResult.Match(ok => Task.FromResult(ok.Value), _ => Task.FromResult(0))).Should().Be(0); + + okResult.Switch(ok => Assert.AreEqual(ok.Value, 42), err => Assert.AreEqual("Ups...", err.Failure)); + await okResult.Switch(ok => AreEqualAsync(ok.Value, 42), err => AreEqualAsync("Ups...", err.Failure)); + + var okOrDefault = await okResult.Match(ok => Task.FromResult(ok.Value), _ => Task.FromResult(0)); + okOrDefault.Should().Be(42); + + var errorOrDefault = await errorResult.Match(ok => Task.FromResult(ok.Value), _ => Task.FromResult(0)); + errorOrDefault.Should().Be(0); + + return; + + static async Task AreEqualAsync(T expected, T actual) + { + await Task.Delay(10); + Assert.AreEqual(expected, actual); + } + } +} \ No newline at end of file diff --git a/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#FunicularSwitchTestBaseTypeOfTMatchExtension.g.verified.cs b/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#FunicularSwitchTestBaseTypeOfTMatchExtension.g.verified.cs index 221c58b..a80ec43 100644 --- a/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#FunicularSwitchTestBaseTypeOfTMatchExtension.g.verified.cs +++ b/Source/Tests/FunicularSwitch.Generators.Test/Snapshots/Run_union_type_generator.For_union_type_with_generic_base_class#FunicularSwitchTestBaseTypeOfTMatchExtension.g.verified.cs @@ -2,49 +2,71 @@ #pragma warning disable 1591 namespace FunicularSwitch.Test { - public abstract partial record BaseType + public static partial class BaseTypeMatchExtension { - public TMatchResult Match(global::System.Func.Deriving, TMatchResult> deriving, global::System.Func.Deriving2, TMatchResult> deriving2) => - this switch + public static TMatchResult Match(this FunicularSwitch.Test.BaseType baseType, global::System.Func.Deriving_, TMatchResult> deriving, global::System.Func.Deriving2_, TMatchResult> deriving2) => + baseType switch + { + FunicularSwitch.Test.BaseType.Deriving_ deriving1 => deriving(deriving1), + FunicularSwitch.Test.BaseType.Deriving2_ deriving22 => deriving2(deriving22), + _ => throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {baseType.GetType().Name}") + }; + + public static global::System.Threading.Tasks.Task Match(this FunicularSwitch.Test.BaseType baseType, global::System.Func.Deriving_, global::System.Threading.Tasks.Task> deriving, global::System.Func.Deriving2_, global::System.Threading.Tasks.Task> deriving2) => + baseType switch { - FunicularSwitch.Test.BaseType.Deriving deriving1 => deriving(deriving1), - FunicularSwitch.Test.BaseType.Deriving2 deriving22 => deriving2(deriving22), - _ => throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {this.GetType().Name}") + FunicularSwitch.Test.BaseType.Deriving_ deriving1 => deriving(deriving1), + FunicularSwitch.Test.BaseType.Deriving2_ deriving22 => deriving2(deriving22), + _ => throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {baseType.GetType().Name}") }; - public void Switch(global::System.Action.Deriving> deriving, global::System.Action.Deriving2> deriving2) + public static async global::System.Threading.Tasks.Task Match(this global::System.Threading.Tasks.Task> baseType, global::System.Func.Deriving_, TMatchResult> deriving, global::System.Func.Deriving2_, TMatchResult> deriving2) => + (await baseType.ConfigureAwait(false)).Match(deriving, deriving2); + + public static async global::System.Threading.Tasks.Task Match(this global::System.Threading.Tasks.Task> baseType, global::System.Func.Deriving_, global::System.Threading.Tasks.Task> deriving, global::System.Func.Deriving2_, global::System.Threading.Tasks.Task> deriving2) => + await (await baseType.ConfigureAwait(false)).Match(deriving, deriving2).ConfigureAwait(false); + + public static void Switch(this FunicularSwitch.Test.BaseType baseType, global::System.Action.Deriving_> deriving, global::System.Action.Deriving2_> deriving2) { - switch (this) + switch (baseType) { - case FunicularSwitch.Test.BaseType.Deriving deriving1: + case FunicularSwitch.Test.BaseType.Deriving_ deriving1: deriving(deriving1); break; - case FunicularSwitch.Test.BaseType.Deriving2 deriving22: + case FunicularSwitch.Test.BaseType.Deriving2_ deriving22: deriving2(deriving22); break; default: - throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {this.GetType().Name}"); + throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {baseType.GetType().Name}"); } } - public async global::System.Threading.Tasks.Task Switch(global::System.Func.Deriving, global::System.Threading.Tasks.Task> deriving, global::System.Func.Deriving2, global::System.Threading.Tasks.Task> deriving2) + public static async global::System.Threading.Tasks.Task Switch(this FunicularSwitch.Test.BaseType baseType, global::System.Func.Deriving_, global::System.Threading.Tasks.Task> deriving, global::System.Func.Deriving2_, global::System.Threading.Tasks.Task> deriving2) { - switch (this) + switch (baseType) { - case FunicularSwitch.Test.BaseType.Deriving deriving1: + case FunicularSwitch.Test.BaseType.Deriving_ deriving1: await deriving(deriving1).ConfigureAwait(false); break; - case FunicularSwitch.Test.BaseType.Deriving2 deriving22: + case FunicularSwitch.Test.BaseType.Deriving2_ deriving22: await deriving2(deriving22).ConfigureAwait(false); break; default: - throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {this.GetType().Name}"); + throw new global::System.ArgumentException($"Unknown type derived from FunicularSwitch.Test.BaseType: {baseType.GetType().Name}"); } } + + public static async global::System.Threading.Tasks.Task Switch(this global::System.Threading.Tasks.Task> baseType, global::System.Action.Deriving_> deriving, global::System.Action.Deriving2_> deriving2) => + (await baseType.ConfigureAwait(false)).Switch(deriving, deriving2); + + public static async global::System.Threading.Tasks.Task Switch(this global::System.Threading.Tasks.Task> baseType, global::System.Func.Deriving_, global::System.Threading.Tasks.Task> deriving, global::System.Func.Deriving2_, global::System.Threading.Tasks.Task> deriving2) => + await (await baseType.ConfigureAwait(false)).Switch(deriving, deriving2).ConfigureAwait(false); } public abstract partial record BaseType { + public static FunicularSwitch.Test.BaseType Deriving(string Value, T Other) => new FunicularSwitch.Test.BaseType.Deriving_(Value, Other); + public static FunicularSwitch.Test.BaseType Deriving2(string Value) => new FunicularSwitch.Test.BaseType.Deriving2_(Value); } } #pragma warning restore 1591 diff --git a/Source/Tests/FunicularSwitch.Generators.Test/UnionTypeGeneratorSpecs.cs b/Source/Tests/FunicularSwitch.Generators.Test/UnionTypeGeneratorSpecs.cs index 6197a7e..5580d6e 100644 --- a/Source/Tests/FunicularSwitch.Generators.Test/UnionTypeGeneratorSpecs.cs +++ b/Source/Tests/FunicularSwitch.Generators.Test/UnionTypeGeneratorSpecs.cs @@ -434,9 +434,9 @@ namespace FunicularSwitch.Test; [UnionType(CaseOrder = CaseOrder.AsDeclared)] public abstract partial record BaseType(string Value) { - public sealed record Deriving(string Value, T Other) : BaseType(Value); + public sealed record Deriving_(string Value, T Other) : BaseType(Value); - public sealed record Deriving2(string Value) : BaseType(Value); + public sealed record Deriving2_(string Value) : BaseType(Value); } """; From 73042755e8b28b58725111ad6513278bc7d360eb Mon Sep 17 00:00:00 2001 From: Alexander Wiedemann Date: Tue, 28 Jan 2025 10:31:47 +0100 Subject: [PATCH 5/6] - doc for generic union types --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 9561a94..6f75e8b 100644 --- a/README.md +++ b/README.md @@ -361,6 +361,8 @@ class ExampleConsumer } ``` +Base types of unions may also be generic types with arbitrary number of type parameters. Case types with generic arguments are not yet supported. + If you like union types but don't like excessive typing in C# try the [Switchyard](https://github.com/bluehands/Switchyard) Visual Studio extension, which generates the boilerplate code for you. It plays nicely with the FunicularSwitch.Generators package. ## ExtendedEnum attribute From 71bae0027d9464d16c4bb6a572ae02a10685ef70 Mon Sep 17 00:00:00 2001 From: Alexander Wiedemann Date: Tue, 28 Jan 2025 10:37:02 +0100 Subject: [PATCH 6/6] - increase feature version --- .../FunicularSwitch.Generators.FluentAssertions.csproj | 2 +- .../FunicularSwitch.Generators.csproj | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Source/FunicularSwitch.Generators.FluentAssertions/FunicularSwitch.Generators.FluentAssertions.csproj b/Source/FunicularSwitch.Generators.FluentAssertions/FunicularSwitch.Generators.FluentAssertions.csproj index 9a2e467..b2514a9 100644 --- a/Source/FunicularSwitch.Generators.FluentAssertions/FunicularSwitch.Generators.FluentAssertions.csproj +++ b/Source/FunicularSwitch.Generators.FluentAssertions/FunicularSwitch.Generators.FluentAssertions.csproj @@ -23,7 +23,7 @@ 1 - 2.2 + 3.0 $(MajorVersion).0.0 diff --git a/Source/FunicularSwitch.Generators/FunicularSwitch.Generators.csproj b/Source/FunicularSwitch.Generators/FunicularSwitch.Generators.csproj index 7e05a80..808ac38 100644 --- a/Source/FunicularSwitch.Generators/FunicularSwitch.Generators.csproj +++ b/Source/FunicularSwitch.Generators/FunicularSwitch.Generators.csproj @@ -23,7 +23,7 @@ 4 - 1.3 + 2.0 $(MajorVersion).0.0