diff --git a/Bonsai.Sgen/CSharpClassCodeArtifact.cs b/Bonsai.Sgen/CSharpClassCodeArtifact.cs new file mode 100644 index 0000000..3d77eef --- /dev/null +++ b/Bonsai.Sgen/CSharpClassCodeArtifact.cs @@ -0,0 +1,20 @@ +using NJsonSchema.CodeGeneration; + +namespace Bonsai.Sgen +{ + internal class CSharpClassCodeArtifact : CodeArtifact + { + public CSharpClassCodeArtifact(CSharpClassTemplateModel model, ITemplate template) : base( + model.ClassName, + model.BaseClassName, + CodeArtifactType.Class, + CodeArtifactLanguage.CSharp, + CodeArtifactCategory.Contract, + template) + { + Model = model; + } + + public CSharpClassTemplateModel Model { get; } + } +} diff --git a/Bonsai.Sgen/CSharpCodeDomGenerator.cs b/Bonsai.Sgen/CSharpCodeDomGenerator.cs index 67b7cd2..e279c61 100644 --- a/Bonsai.Sgen/CSharpCodeDomGenerator.cs +++ b/Bonsai.Sgen/CSharpCodeDomGenerator.cs @@ -52,7 +52,7 @@ private CodeArtifact GenerateClass(JsonSchema schema, string typeName) { var model = new CSharpClassTemplateModel(typeName, Settings, _resolver, schema, RootObject); var template = new CSharpClassTemplate(model, _provider, _options, Settings); - return new CodeArtifact(typeName, model.BaseClassName, CodeArtifactType.Class, CodeArtifactLanguage.CSharp, CodeArtifactCategory.Contract, template); + return new CSharpClassCodeArtifact(model, template); } private CodeArtifact GenerateClass(CSharpCodeDomTemplate template) @@ -83,10 +83,16 @@ public override IEnumerable GenerateTypes() var types = base.GenerateTypes(); var extraTypes = new List(); var schema = (JsonSchema)RootObject; - var classTypes = types - .Where(type => type.Type == CodeArtifactType.Class) - .ExceptBy(new[] { nameof(JsonInheritanceAttribute), nameof(JsonInheritanceConverter) }, r => r.TypeName) - .ToList(); + var classTypes = (from type in types + let classType = type as CSharpClassCodeArtifact + where classType != null + select classType).ToList(); + var discriminatorTypes = classTypes.Where(modelType => modelType.Model.HasDiscriminator).ToList(); + foreach (var type in discriminatorTypes) + { + var matchTemplate = new CSharpTypeMatchTemplate(type, _provider, _options, Settings); + extraTypes.Add(GenerateClass(matchTemplate)); + } if (Settings.SerializerLibraries.HasFlag(SerializerLibraries.NewtonsoftJson)) { var serializer = new CSharpJsonSerializerTemplate(classTypes, _provider, _options, Settings); @@ -96,7 +102,6 @@ public override IEnumerable GenerateTypes() } if (Settings.SerializerLibraries.HasFlag(SerializerLibraries.YamlDotNet)) { - var discriminatorTypes = classTypes.Where(modelType => modelType.Code.Contains("YamlDiscriminator")).ToList(); if (discriminatorTypes.Count > 0) { var discriminator = new CSharpYamlDiscriminatorTemplate(_provider, _options, Settings); diff --git a/Bonsai.Sgen/CSharpTypeMatchTemplate.cs b/Bonsai.Sgen/CSharpTypeMatchTemplate.cs new file mode 100644 index 0000000..1f1cc48 --- /dev/null +++ b/Bonsai.Sgen/CSharpTypeMatchTemplate.cs @@ -0,0 +1,92 @@ +using System.CodeDom; +using System.CodeDom.Compiler; +using System.ComponentModel; +using System.Xml.Serialization; + +namespace Bonsai.Sgen +{ + internal class CSharpTypeMatchTemplate : CSharpCodeDomTemplate + { + public CSharpTypeMatchTemplate( + CSharpClassCodeArtifact modelType, + CodeDomProvider provider, + CodeGeneratorOptions options, + CSharpCodeDomGeneratorSettings settings) + : base(provider, options, settings) + { + ModelType = modelType; + } + + public CSharpClassCodeArtifact ModelType { get; } + + public override string TypeName => $"Match{ModelType.TypeName}"; + + public override void BuildType(CodeTypeDeclaration type) + { + type.BaseTypes.Add(new CodeTypeReference("Bonsai.Expressions.SingleArgumentExpressionBuilder")); + type.CustomAttributes.Add(new CodeAttributeDeclaration( + new CodeTypeReference(typeof(DefaultPropertyAttribute)), + new CodeAttributeArgument(new CodePrimitiveExpression("Type")))); + type.CustomAttributes.Add(new CodeAttributeDeclaration( + new CodeTypeReference("Bonsai.WorkflowElementCategoryAttribute"), + new CodeAttributeArgument(new CodeFieldReferenceExpression( + new CodeTypeReferenceExpression("Bonsai.ElementCategory"), + "Combinator")))); + foreach (var modelType in ModelType.Model.DerivedClasses) + { + type.CustomAttributes.Add(new CodeAttributeDeclaration( + new CodeTypeReference(typeof(XmlIncludeAttribute)), + new CodeAttributeArgument(new CodeTypeOfExpression( + new CodeTypeReference( + "Bonsai.Expressions.TypeMapping", + new CodeTypeReference(modelType.ClassName)))))); + } + + type.Members.Add(new CodeSnippetTypeMember( +@$" public Bonsai.Expressions.TypeMapping Type {{ get; set; }} + + public override System.Linq.Expressions.Expression Build(System.Collections.Generic.IEnumerable arguments) + {{ + var typeMapping = Type; + var returnType = typeMapping != null ? typeMapping.GetType().GetGenericArguments()[0] : typeof({ModelType.TypeName}); + return System.Linq.Expressions.Expression.Call( + typeof({TypeName}), + ""Process"", + new System.Type[] {{ returnType }}, + System.Linq.Enumerable.Single(arguments)); + }} +")); + var sourceTypeReference = new CodeTypeReference(ModelType.TypeName); + var genericTypeParameter = new CodeTypeParameter("TResult") { Constraints = { sourceTypeReference } }; + var sourceParameter = new CodeParameterDeclarationExpression( + new CodeTypeReference(typeof(IObservable<>)) { TypeArguments = { sourceTypeReference } }, "source"); + type.Members.Add(new CodeMemberMethod + { + Name = "Process", + Attributes = MemberAttributes.Private | MemberAttributes.Static, + TypeParameters = { genericTypeParameter }, + Parameters = { sourceParameter }, + ReturnType = new CodeTypeReference(typeof(IObservable<>)) + { + TypeArguments = { new CodeTypeReference(genericTypeParameter) } + }, + Statements = + { + new CodeExpressionStatement(new CodeSnippetExpression( +@$"return System.Reactive.Linq.Observable.Create<{genericTypeParameter.Name}>(observer => + {{ + var sourceObserver = System.Reactive.Observer.Create<{ModelType.TypeName}>( + value => + {{ + var match = value as {genericTypeParameter.Name}; + if (match != null) observer.OnNext(match); + }}, + observer.OnError, + observer.OnCompleted); + return System.ObservableExtensions.SubscribeSafe(source, sourceObserver); + }})")) + } + }); + } + } +}