diff --git a/src/MediatR.Contracts/MediatR.Contracts.csproj b/src/MediatR.Contracts/MediatR.Contracts.csproj index a311f618..38183c67 100644 --- a/src/MediatR.Contracts/MediatR.Contracts.csproj +++ b/src/MediatR.Contracts/MediatR.Contracts.csproj @@ -6,7 +6,6 @@ Contracts package for requests, responses, and notifications Copyright Jimmy Bogard netstandard2.0 - enable strict mediator;request;response;queries;commands;notifications true diff --git a/src/MediatR/Entities/OpenBehavior.cs b/src/MediatR/Entities/OpenBehavior.cs new file mode 100644 index 00000000..e7201da3 --- /dev/null +++ b/src/MediatR/Entities/OpenBehavior.cs @@ -0,0 +1,20 @@ +using System; +using Microsoft.Extensions.DependencyInjection; + +namespace MediatR.Entities; +/// +/// Creates open behavior entity. +/// +public class OpenBehavior +{ + public OpenBehavior(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + OpenBehaviorType = openBehaviorType; + ServiceLifetime = serviceLifetime; + } + + public Type? OpenBehaviorType { get; } + public ServiceLifetime ServiceLifetime { get; } + + +} \ No newline at end of file diff --git a/src/MediatR/IPipelineBehavior.cs b/src/MediatR/IPipelineBehavior.cs index be6b1cb3..142ec707 100644 --- a/src/MediatR/IPipelineBehavior.cs +++ b/src/MediatR/IPipelineBehavior.cs @@ -9,7 +9,7 @@ namespace MediatR; /// /// Response type /// Awaitable task returning a -public delegate Task RequestHandlerDelegate(); +public delegate Task RequestHandlerDelegate(CancellationToken t = default); /// /// Pipeline behavior to surround the inner handler. diff --git a/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs b/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs index bc785f91..f4bbf090 100644 --- a/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs +++ b/src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs @@ -3,6 +3,7 @@ using System.Linq; using System.Reflection; using MediatR; +using MediatR.Entities; using MediatR.NotificationPublishers; using MediatR.Pipeline; using MediatR.Registration; @@ -15,7 +16,7 @@ public class MediatRServiceConfiguration /// Optional filter for types to register. Default value is a function returning true. /// public Func TypeEvaluator { get; set; } = t => true; - + /// /// Mediator implementation type to register. Default is /// @@ -69,31 +70,6 @@ public class MediatRServiceConfiguration /// public bool AutoRegisterRequestProcessors { get; set; } - /// - /// Configure the maximum number of type parameters that a generic request handler can have. To Disable this constraint, set the value to 0. - /// - public int MaxGenericTypeParameters { get; set; } = 10; - - /// - /// Configure the maximum number of types that can close a generic request type parameter constraint. To Disable this constraint, set the value to 0. - /// - public int MaxTypesClosing { get; set; } = 100; - - /// - /// Configure the Maximum Amount of Generic RequestHandler Types MediatR will try to register. To Disable this constraint, set the value to 0. - /// - public int MaxGenericTypeRegistrations { get; set; } = 125000; - - /// - /// Configure the Timeout in Milliseconds that the GenericHandler Registration Process will exit with error. To Disable this constraint, set the value to 0. - /// - public int RegistrationTimeout { get; set; } = 15000; - - /// - /// Flag that controlls whether MediatR will attempt to register handlers that containg generic type parameters. - /// - public bool RegisterGenericHandlers { get; set; } = false; - /// /// Register various handlers from assembly containing given type /// @@ -222,6 +198,37 @@ public MediatRServiceConfiguration AddOpenBehavior(Type openBehaviorType, Servic return this; } + /// + /// Registers multiple open behavior types against the open generic interface type + /// + /// An open generic behavior type list includes multiple open generic behavior types. + /// Optional service lifetime, defaults to . + /// This + public MediatRServiceConfiguration AddOpenBehaviors(IEnumerable openBehaviorTypes, ServiceLifetime serviceLifetime = ServiceLifetime.Transient) + { + foreach (var openBehaviorType in openBehaviorTypes) + { + AddOpenBehavior(openBehaviorType, serviceLifetime); + } + + return this; + } + + /// + /// Registers open behaviors against the open generic interface type + /// + /// An open generic behavior list includes multiple open generic behaviors. + /// This + public MediatRServiceConfiguration AddOpenBehaviors(IEnumerable openBehaviors) + { + foreach (var openBehavior in openBehaviors) + { + AddOpenBehavior(openBehavior.OpenBehaviorType!, openBehavior.ServiceLifetime); + } + + return this; + } + /// /// Register a closed stream behavior type /// @@ -231,7 +238,7 @@ public MediatRServiceConfiguration AddOpenBehavior(Type openBehaviorType, Servic /// This public MediatRServiceConfiguration AddStreamBehavior(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) => AddStreamBehavior(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); - + /// /// Register a closed stream behavior type /// @@ -245,7 +252,7 @@ public MediatRServiceConfiguration AddStreamBehavior(Type serviceType, Type impl return this; } - + /// /// Register a closed stream behavior type against all implementations /// @@ -254,7 +261,7 @@ public MediatRServiceConfiguration AddStreamBehavior(Type serviceType, Type impl /// This public MediatRServiceConfiguration AddStreamBehavior(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) => AddStreamBehavior(typeof(TImplementationType), serviceLifetime); - + /// /// Register a closed stream behavior type against all implementations /// @@ -277,7 +284,7 @@ public MediatRServiceConfiguration AddStreamBehavior(Type implementationType, Se return this; } - + /// /// Registers an open stream behavior type against the open generic interface type /// @@ -316,7 +323,7 @@ public MediatRServiceConfiguration AddOpenStreamBehavior(Type openBehaviorType, /// This public MediatRServiceConfiguration AddRequestPreProcessor(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) => AddRequestPreProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); - + /// /// Register a closed request pre processor type /// @@ -360,10 +367,10 @@ public MediatRServiceConfiguration AddRequestPreProcessor(Type implementationTyp { RequestPreProcessorsToRegister.Add(new ServiceDescriptor(implementedPreProcessorType, implementationType, serviceLifetime)); } - + return this; } - + /// /// Registers an open request pre processor type against the open generic interface type /// @@ -392,7 +399,7 @@ public MediatRServiceConfiguration AddOpenRequestPreProcessor(Type openBehaviorT return this; } - + /// /// Register a closed request post processor type /// @@ -402,7 +409,7 @@ public MediatRServiceConfiguration AddOpenRequestPreProcessor(Type openBehaviorT /// This public MediatRServiceConfiguration AddRequestPostProcessor(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) => AddRequestPostProcessor(typeof(TServiceType), typeof(TImplementationType), serviceLifetime); - + /// /// Register a closed request post processor type /// @@ -416,7 +423,7 @@ public MediatRServiceConfiguration AddRequestPostProcessor(Type serviceType, Typ return this; } - + /// /// Register a closed request post processor type against all implementations /// @@ -425,7 +432,7 @@ public MediatRServiceConfiguration AddRequestPostProcessor(Type serviceType, Typ /// This public MediatRServiceConfiguration AddRequestPostProcessor(ServiceLifetime serviceLifetime = ServiceLifetime.Transient) => AddRequestPostProcessor(typeof(TImplementationType), serviceLifetime); - + /// /// Register a closed request post processor type against all implementations /// @@ -447,7 +454,7 @@ public MediatRServiceConfiguration AddRequestPostProcessor(Type implementationTy } return this; } - + /// /// Registers an open request post processor type against the open generic interface type /// diff --git a/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs b/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs index 6e211b27..328cf949 100644 --- a/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs +++ b/src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs @@ -23,7 +23,7 @@ public static class ServiceCollectionExtensions /// Service collection /// The action used to configure the options /// Service collection - public static IServiceCollection AddMediatR(this IServiceCollection services, + public static IServiceCollection AddMediatR(this IServiceCollection services, Action configuration) { var serviceConfig = new MediatRServiceConfiguration(); @@ -31,15 +31,15 @@ public static IServiceCollection AddMediatR(this IServiceCollection services, configuration.Invoke(serviceConfig); return services.AddMediatR(serviceConfig); - } - + } + /// /// Registers handlers and mediator types from the specified assemblies /// /// Service collection /// Configuration options /// Service collection - public static IServiceCollection AddMediatR(this IServiceCollection services, + public static IServiceCollection AddMediatR(this IServiceCollection services, MediatRServiceConfiguration configuration) { if (!configuration.AssembliesToRegister.Any()) @@ -47,9 +47,7 @@ public static IServiceCollection AddMediatR(this IServiceCollection services, throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers."); } - ServiceRegistrar.SetGenericRequestHandlerRegistrationLimitations(configuration); - - ServiceRegistrar.AddMediatRClassesWithTimeout(services, configuration); + ServiceRegistrar.AddMediatRClasses(services, configuration); ServiceRegistrar.AddRequiredServices(services, configuration); diff --git a/src/MediatR/Registration/ServiceRegistrar.cs b/src/MediatR/Registration/ServiceRegistrar.cs index 5d22cc8f..48106dc0 100644 --- a/src/MediatR/Registration/ServiceRegistrar.cs +++ b/src/MediatR/Registration/ServiceRegistrar.cs @@ -2,7 +2,6 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; -using System.Threading; using MediatR.Pipeline; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; @@ -10,42 +9,13 @@ namespace MediatR.Registration; public static class ServiceRegistrar -{ - private static int MaxGenericTypeParameters; - private static int MaxTypesClosing; - private static int MaxGenericTypeRegistrations; - private static int RegistrationTimeout; - - public static void SetGenericRequestHandlerRegistrationLimitations(MediatRServiceConfiguration configuration) - { - MaxGenericTypeParameters = configuration.MaxGenericTypeParameters; - MaxTypesClosing = configuration.MaxTypesClosing; - MaxGenericTypeRegistrations = configuration.MaxGenericTypeRegistrations; - RegistrationTimeout = configuration.RegistrationTimeout; - } - - public static void AddMediatRClassesWithTimeout(IServiceCollection services, MediatRServiceConfiguration configuration) - { - using(var cts = new CancellationTokenSource(RegistrationTimeout)) - { - try - { - AddMediatRClasses(services, configuration, cts.Token); - } - catch (OperationCanceledException) - { - throw new TimeoutException("The generic handler registration process timed out."); - } - } - } - - public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration, CancellationToken cancellationToken = default) - { - +{ + public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration) + { var assembliesToScan = configuration.AssembliesToRegister.Distinct().ToArray(); - ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration, cancellationToken); - ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration, cancellationToken); + ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration); + ConnectImplementationsToTypesClosing(typeof(IRequestHandler<>), services, assembliesToScan, false, configuration); ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration); ConnectImplementationsToTypesClosing(typeof(IStreamRequestHandler<,>), services, assembliesToScan, false, configuration); ConnectImplementationsToTypesClosing(typeof(IRequestExceptionHandler<,,>), services, assembliesToScan, true, configuration); @@ -93,41 +63,23 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa IServiceCollection services, IEnumerable assembliesToScan, bool addIfAlreadyExists, - MediatRServiceConfiguration configuration, - CancellationToken cancellationToken = default) + MediatRServiceConfiguration configuration) { - var concretions = new List(); + var concretions = new List(); var interfaces = new List(); - var genericConcretions = new List(); - var genericInterfaces = new List(); - - var types = assembliesToScan - .SelectMany(a => a.DefinedTypes) - .Where(t => !t.ContainsGenericParameters || configuration.RegisterGenericHandlers) - .Where(t => t.IsConcrete() && t.FindInterfacesThatClose(openRequestInterface).Any()) - .Where(configuration.TypeEvaluator) - .ToList(); - - foreach (var type in types) + foreach (var type in assembliesToScan.SelectMany(a => a.DefinedTypes).Where(t => !t.IsOpenGeneric()).Where(configuration.TypeEvaluator)) { var interfaceTypes = type.FindInterfacesThatClose(openRequestInterface).ToArray(); + if (!interfaceTypes.Any()) continue; - if (!type.IsOpenGeneric()) + if (type.IsConcrete()) { concretions.Add(type); - - foreach (var interfaceType in interfaceTypes) - { - interfaces.Fill(interfaceType); - } } - else + + foreach (var interfaceType in interfaceTypes) { - genericConcretions.Add(type); - foreach (var interfaceType in interfaceTypes) - { - genericInterfaces.Fill(interfaceType); - } + interfaces.Fill(interfaceType); } } @@ -159,12 +111,6 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa AddConcretionsThatCouldBeClosed(@interface, concretions, services); } } - - foreach (var @interface in genericInterfaces) - { - var exactMatches = genericConcretions.Where(x => x.CanBeCastTo(@interface)).ToList(); - AddAllConcretionsThatClose(@interface, exactMatches, services, assembliesToScan, cancellationToken); - } } private static bool IsMatchingWithInterface(Type? handlerType, Type handlerInterface) @@ -204,117 +150,6 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List } } - private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(Type openRequestHandlerInterface, Type concreteGenericTRequest, Type openRequestHandlerImplementation) - { - var closingTypes = concreteGenericTRequest.GetGenericArguments(); - - var concreteTResponse = concreteGenericTRequest.GetInterfaces() - .FirstOrDefault(x => x.IsGenericType && x.GetGenericTypeDefinition() == typeof(IRequest<>)) - ?.GetGenericArguments() - .FirstOrDefault(); - - var typeDefinition = openRequestHandlerInterface.GetGenericTypeDefinition(); - - var serviceType = concreteTResponse != null ? - typeDefinition.MakeGenericType(concreteGenericTRequest, concreteTResponse) : - typeDefinition.MakeGenericType(concreteGenericTRequest); - - return (serviceType, openRequestHandlerImplementation.MakeGenericType(closingTypes)); - } - - private static List? GetConcreteRequestTypes(Type openRequestHandlerInterface, Type openRequestHandlerImplementation, IEnumerable assembliesToScan, CancellationToken cancellationToken) - { - //request generic type constraints - var constraintsForEachParameter = openRequestHandlerImplementation - .GetGenericArguments() - .Select(x => x.GetGenericParameterConstraints()) - .ToList(); - - var typesThatCanCloseForEachParameter = constraintsForEachParameter - .Select(constraints => assembliesToScan - .SelectMany(assembly => assembly.GetTypes()) - .Where(type => type.IsClass && !type.IsAbstract && constraints.All(constraint => constraint.IsAssignableFrom(type))).ToList() - ).ToList(); - - var requestType = openRequestHandlerInterface.GenericTypeArguments.First(); - - if (requestType.IsGenericParameter) - return null; - - var requestGenericTypeDefinition = requestType.GetGenericTypeDefinition(); - - var combinations = GenerateCombinations(requestType, typesThatCanCloseForEachParameter, 0, cancellationToken); - - return combinations.Select(types => requestGenericTypeDefinition.MakeGenericType(types.ToArray())).ToList(); - } - - // Method to generate combinations recursively - public static List> GenerateCombinations(Type requestType, List> lists, int depth = 0, CancellationToken cancellationToken = default) - { - if (depth == 0) - { - // Initial checks - if (MaxGenericTypeParameters > 0 && lists.Count > MaxGenericTypeParameters) - throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. The number of generic type parameters exceeds the maximum allowed ({MaxGenericTypeParameters})."); - - foreach (var list in lists) - { - if (MaxTypesClosing > 0 && list.Count > MaxTypesClosing) - throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. One of the generic type parameter's count of types that can close exceeds the maximum length allowed ({MaxTypesClosing})."); - } - - // Calculate the total number of combinations - long totalCombinations = 1; - foreach (var list in lists) - { - totalCombinations *= list.Count; - if (MaxGenericTypeParameters > 0 && totalCombinations > MaxGenericTypeRegistrations) - throw new ArgumentException($"Error registering the generic type: {requestType.FullName}. The total number of generic type registrations exceeds the maximum allowed ({MaxGenericTypeRegistrations})."); - } - } - - if (depth >= lists.Count) - return new List> { new List() }; - - cancellationToken.ThrowIfCancellationRequested(); - - var currentList = lists[depth]; - var childCombinations = GenerateCombinations(requestType, lists, depth + 1, cancellationToken); - var combinations = new List>(); - - foreach (var item in currentList) - { - foreach (var childCombination in childCombinations) - { - var currentCombination = new List { item }; - currentCombination.AddRange(childCombination); - combinations.Add(currentCombination); - } - } - - return combinations; - } - - private static void AddAllConcretionsThatClose(Type openRequestInterface, List concretions, IServiceCollection services, IEnumerable assembliesToScan, CancellationToken cancellationToken) - { - foreach (var concretion in concretions) - { - var concreteRequests = GetConcreteRequestTypes(openRequestInterface, concretion, assembliesToScan, cancellationToken); - - if (concreteRequests is null) - continue; - - var registrationTypes = concreteRequests - .Select(concreteRequest => GetConcreteRegistrationTypes(openRequestInterface, concreteRequest, concretion)); - - foreach (var (Service, Implementation) in registrationTypes) - { - cancellationToken.ThrowIfCancellationRequested(); - services.AddTransient(Service, Implementation); - } - } - } - internal static bool CouldCloseTo(this Type openConcretion, Type closedInterface) { var openInterface = closedInterface.GetGenericTypeDefinition(); @@ -424,8 +259,8 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi foreach (var serviceDescriptor in serviceConfiguration.BehaviorsToRegister) { services.TryAddEnumerable(serviceDescriptor); - } - + } + foreach (var serviceDescriptor in serviceConfiguration.StreamBehaviorsToRegister) { services.TryAddEnumerable(serviceDescriptor); @@ -435,7 +270,7 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi private static void RegisterBehaviorIfImplementationsExist(IServiceCollection services, Type behaviorType, Type subBehaviorType) { var hasAnyRegistrationsOfSubBehaviorType = services - .Where(service => !service.IsKeyedService) + .Where(service => !service.IsKeyedService) .Select(service => service.ImplementationType) .OfType() .SelectMany(type => type.GetInterfaces()) @@ -448,4 +283,4 @@ private static void RegisterBehaviorIfImplementationsExist(IServiceCollection se services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), behaviorType, ServiceLifetime.Transient)); } } -} +} \ No newline at end of file diff --git a/src/MediatR/Wrappers/RequestHandlerWrapper.cs b/src/MediatR/Wrappers/RequestHandlerWrapper.cs index 1550ecf8..dc213560 100644 --- a/src/MediatR/Wrappers/RequestHandlerWrapper.cs +++ b/src/MediatR/Wrappers/RequestHandlerWrapper.cs @@ -34,14 +34,14 @@ public class RequestHandlerWrapperImpl : RequestHandlerWrap public override Task Handle(IRequest request, IServiceProvider serviceProvider, CancellationToken cancellationToken) { - Task Handler() => serviceProvider.GetRequiredService>() - .Handle((TRequest) request, cancellationToken); + Task Handler(CancellationToken t = default) => serviceProvider.GetRequiredService>() + .Handle((TRequest) request, t == default ? cancellationToken : t); return serviceProvider .GetServices>() .Reverse() .Aggregate((RequestHandlerDelegate) Handler, - (next, pipeline) => () => pipeline.Handle((TRequest) request, next, cancellationToken))(); + (next, pipeline) => (t) => pipeline.Handle((TRequest) request, next, t == default ? cancellationToken : t))(); } } @@ -55,10 +55,10 @@ public class RequestHandlerWrapperImpl : RequestHandlerWrapper public override Task Handle(IRequest request, IServiceProvider serviceProvider, CancellationToken cancellationToken) { - async Task Handler() + async Task Handler(CancellationToken t = default) { await serviceProvider.GetRequiredService>() - .Handle((TRequest) request, cancellationToken); + .Handle((TRequest) request, t == default ? cancellationToken : t); return Unit.Value; } @@ -67,6 +67,6 @@ await serviceProvider.GetRequiredService>() .GetServices>() .Reverse() .Aggregate((RequestHandlerDelegate) Handler, - (next, pipeline) => () => pipeline.Handle((TRequest) request, next, cancellationToken))(); + (next, pipeline) => (t) => pipeline.Handle((TRequest) request, next, t == default ? cancellationToken : t))(); } } \ No newline at end of file diff --git a/test/MediatR.Tests/GenericRequestHandlerTests.cs b/test/MediatR.Tests/GenericRequestHandlerTests.cs index e6944e3e..304076ff 100644 --- a/test/MediatR.Tests/GenericRequestHandlerTests.cs +++ b/test/MediatR.Tests/GenericRequestHandlerTests.cs @@ -32,7 +32,7 @@ public void ShouldResolveAllCombinationsOfGenericHandler(int numberOfClasses, in services.AddMediatR(cfg => { cfg.RegisterServicesFromAssemblies(dynamicAssembly); - cfg.RegisterGenericHandlers = true; + // cfg.RegisterGenericHandlers = true; }); var provider = services.BuildServiceProvider(); @@ -107,7 +107,7 @@ public void ShouldThrowExceptionWhenTypesClosingExceedsMaximum() services.AddMediatR(cfg => { cfg.RegisterServicesFromAssembly(assembly); - cfg.RegisterGenericHandlers = true; + //cfg.RegisterGenericHandlers = true; }); }) .Message.ShouldContain("One of the generic type parameter's count of types that can close exceeds the maximum length allowed"); @@ -126,7 +126,7 @@ public void ShouldThrowExceptionWhenGenericHandlerRegistrationsExceedsMaximum() services.AddMediatR(cfg => { cfg.RegisterServicesFromAssembly(assembly); - cfg.RegisterGenericHandlers = true; + //cfg.RegisterGenericHandlers = true; }); }) .Message.ShouldContain("The total number of generic type registrations exceeds the maximum allowed"); @@ -145,7 +145,7 @@ public void ShouldThrowExceptionWhenGenericTypeParametersExceedsMaximum() services.AddMediatR(cfg => { cfg.RegisterServicesFromAssembly(assembly); - cfg.RegisterGenericHandlers = true; + // cfg.RegisterGenericHandlers = true; }); }) .Message.ShouldContain("The number of generic type parameters exceeds the maximum allowed"); @@ -163,11 +163,11 @@ public void ShouldThrowExceptionWhenTimeoutOccurs() { services.AddMediatR(cfg => { - cfg.MaxGenericTypeParameters = 0; - cfg.MaxGenericTypeRegistrations = 0; - cfg.MaxTypesClosing = 0; - cfg.RegistrationTimeout = 1000; - cfg.RegisterGenericHandlers = true; + //cfg.MaxGenericTypeParameters = 0; + //cfg.MaxGenericTypeRegistrations = 0; + //cfg.MaxTypesClosing = 0; + //cfg.RegistrationTimeout = 1000; + //cfg.RegisterGenericHandlers = true; cfg.RegisterServicesFromAssembly(assembly); }); }) @@ -184,7 +184,7 @@ public void ShouldNotRegisterGenericHandlersWhenOptingOut() services.AddMediatR(cfg => { //opt out flag set - cfg.RegisterGenericHandlers = false; + //cfg.RegisterGenericHandlers = false; cfg.RegisterServicesFromAssembly(assembly); }); diff --git a/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs b/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs index 6e27757e..d0e8659d 100644 --- a/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs +++ b/test/MediatR.Tests/MicrosoftExtensionsDI/AssemblyResolutionTests.cs @@ -18,7 +18,7 @@ public AssemblyResolutionTests() services.AddMediatR(cfg => { cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly); - cfg.RegisterGenericHandlers = true; + //cfg.RegisterGenericHandlers = true; }); _provider = services.BuildServiceProvider(); } diff --git a/test/MediatR.Tests/SendTests.cs b/test/MediatR.Tests/SendTests.cs index abbf7717..a59d7215 100644 --- a/test/MediatR.Tests/SendTests.cs +++ b/test/MediatR.Tests/SendTests.cs @@ -1,17 +1,18 @@ -using System.Threading; - -using System; -using System.Threading.Tasks; -using Shouldly; -using Xunit; -using Microsoft.Extensions.DependencyInjection; +using System.Threading; + +using System; +using System.Threading.Tasks; +using Shouldly; +using Xunit; +using Microsoft.Extensions.DependencyInjection; using System.Reflection; - -namespace MediatR.Tests; -public class SendTests -{ +using MediatR.Pipeline; + +namespace MediatR.Tests; +public class SendTests +{ private readonly IServiceProvider _serviceProvider; - private Dependency _dependency; + private Dependency _dependency; private readonly IMediator _mediator; public SendTests() @@ -21,92 +22,93 @@ public SendTests() services.AddMediatR(cfg => { cfg.RegisterServicesFromAssemblies(typeof(Ping).Assembly); - cfg.RegisterGenericHandlers = true; + //cfg.RegisterGenericHandlers = true; + cfg.AddOpenBehavior(typeof(TimeoutBehavior<,>), ServiceLifetime.Transient); }); services.AddSingleton(_dependency); _serviceProvider = services.BuildServiceProvider(); _mediator = _serviceProvider.GetService()!; - } - - public class Ping : IRequest - { - public string? Message { get; set; } - } - - public class VoidPing : IRequest - { - } - - public class Pong - { - public string? Message { get; set; } - } - - public class PingHandler : IRequestHandler - { - public Task Handle(Ping request, CancellationToken cancellationToken) - { - return Task.FromResult(new Pong { Message = request.Message + " Pong" }); - } - } - - public class Dependency - { - public bool Called { get; set; } - public bool CalledSpecific { get; set; } - } - - public class VoidPingHandler : IRequestHandler - { - private readonly Dependency _dependency; - - public VoidPingHandler(Dependency dependency) => _dependency = dependency; - - public Task Handle(VoidPing request, CancellationToken cancellationToken) - { - _dependency.Called = true; - - return Task.CompletedTask; - } - } - - public class GenericPing : IRequest - where T : Pong - { - public T? Pong { get; set; } - } - - public class GenericPingHandler : IRequestHandler, T> - where T : Pong - { - private readonly Dependency _dependency; - - public GenericPingHandler(Dependency dependency) => _dependency = dependency; - - public Task Handle(GenericPing request, CancellationToken cancellationToken) - { - _dependency.Called = true; - request.Pong!.Message += " Pong"; - return Task.FromResult(request.Pong!); - } - } - - public class VoidGenericPing : IRequest - where T : Pong - { } - - public class VoidGenericPingHandler : IRequestHandler> - where T : Pong - { - private readonly Dependency _dependency; - public VoidGenericPingHandler(Dependency dependency) => _dependency = dependency; - - public Task Handle(VoidGenericPing request, CancellationToken cancellationToken) - { - _dependency.Called = true; - - return Task.CompletedTask; - } + } + + public class Ping : IRequest + { + public string? Message { get; set; } + } + + public class VoidPing : IRequest + { + } + + public class Pong + { + public string? Message { get; set; } + } + + public class PingHandler : IRequestHandler + { + public Task Handle(Ping request, CancellationToken cancellationToken) + { + return Task.FromResult(new Pong { Message = request.Message + " Pong" }); + } + } + + public class Dependency + { + public bool Called { get; set; } + public bool CalledSpecific { get; set; } + } + + public class VoidPingHandler : IRequestHandler + { + private readonly Dependency _dependency; + + public VoidPingHandler(Dependency dependency) => _dependency = dependency; + + public Task Handle(VoidPing request, CancellationToken cancellationToken) + { + _dependency.Called = true; + + return Task.CompletedTask; + } + } + + public class GenericPing : IRequest + where T : Pong + { + public T? Pong { get; set; } + } + + public class GenericPingHandler : IRequestHandler, T> + where T : Pong + { + private readonly Dependency _dependency; + + public GenericPingHandler(Dependency dependency) => _dependency = dependency; + + public Task Handle(GenericPing request, CancellationToken cancellationToken) + { + _dependency.Called = true; + request.Pong!.Message += " Pong"; + return Task.FromResult(request.Pong!); + } + } + + public class VoidGenericPing : IRequest + where T : Pong + { } + + public class VoidGenericPingHandler : IRequestHandler> + where T : Pong + { + private readonly Dependency _dependency; + public VoidGenericPingHandler(Dependency dependency) => _dependency = dependency; + + public Task Handle(VoidGenericPing request, CancellationToken cancellationToken) + { + _dependency.Called = true; + + return Task.CompletedTask; + } } public class PongExtension : Pong @@ -127,22 +129,22 @@ public Task Handle(VoidGenericPing request, CancellationToken can } } - public interface ITestInterface1 { } - public interface ITestInterface2 { } + public interface ITestInterface1 { } + public interface ITestInterface2 { } public interface ITestInterface3 { } public class TestClass1 : ITestInterface1 { } public class TestClass2 : ITestInterface2 { } public class TestClass3 : ITestInterface3 { } - public class MultipleGenericTypeParameterRequest : IRequest - where T1 : ITestInterface1 - where T2 : ITestInterface2 - where T3 : ITestInterface3 - { - public int Foo { get; set; } - } - + public class MultipleGenericTypeParameterRequest : IRequest + where T1 : ITestInterface1 + where T2 : ITestInterface2 + where T3 : ITestInterface3 + { + public int Foo { get; set; } + } + public class MultipleGenericTypeParameterRequestHandler : IRequestHandler, int> where T1 : ITestInterface1 where T2 : ITestInterface2 @@ -152,92 +154,141 @@ public class MultipleGenericTypeParameterRequestHandler : IRequestHa public MultipleGenericTypeParameterRequestHandler(Dependency dependency) => _dependency = dependency; - public Task Handle(MultipleGenericTypeParameterRequest request, CancellationToken cancellationToken) - { - _dependency.Called = true; - return Task.FromResult(1); - } - } - - [Fact] - public async Task Should_resolve_main_handler() - { - var response = await _mediator.Send(new Ping { Message = "Ping" }); - - response.Message.ShouldBe("Ping Pong"); - } - - [Fact] - public async Task Should_resolve_main_void_handler() - { - await _mediator.Send(new VoidPing()); - - _dependency.Called.ShouldBeTrue(); - } - - [Fact] - public async Task Should_resolve_main_handler_via_dynamic_dispatch() - { - object request = new Ping { Message = "Ping" }; - var response = await _mediator.Send(request); - - var pong = response.ShouldBeOfType(); - pong.Message.ShouldBe("Ping Pong"); - } - - [Fact] - public async Task Should_resolve_main_void_handler_via_dynamic_dispatch() - { - object request = new VoidPing(); - var response = await _mediator.Send(request); - - response.ShouldBeOfType(); - - _dependency.Called.ShouldBeTrue(); - } - - [Fact] - public async Task Should_resolve_main_handler_by_specific_interface() - { - var response = await _mediator.Send(new Ping { Message = "Ping" }); - - response.Message.ShouldBe("Ping Pong"); - } - - [Fact] - public async Task Should_resolve_main_handler_by_given_interface() - { - // wrap requests in an array, so this test won't break on a 'replace with var' refactoring - var requests = new IRequest[] { new VoidPing() }; - await _mediator.Send(requests[0]); - - _dependency.Called.ShouldBeTrue(); - } - - [Fact] - public Task Should_raise_execption_on_null_request() => Should.ThrowAsync(async () => await _mediator.Send(default!)); - - [Fact] - public async Task Should_resolve_generic_handler() + public Task Handle(MultipleGenericTypeParameterRequest request, CancellationToken cancellationToken) + { + _dependency.Called = true; + return Task.FromResult(1); + } + } + + public class TimeoutBehavior : IPipelineBehavior + where TRequest : notnull + { + public async Task Handle(TRequest request, RequestHandlerDelegate next, CancellationToken cancellationToken) + { + using (var cts = new CancellationTokenSource(500)) + { + return await next(cts.Token); + } + } + } + + public class TimeoutRequest : IRequest + { + } + + public class TimeoutRequest2 : IRequest + { + } + + public class TimeoutRequestHandler : IRequestHandler + { + private readonly Dependency _dependency; + + public TimeoutRequestHandler(Dependency dependency) => _dependency = dependency; + + public async Task Handle(TimeoutRequest request, CancellationToken cancellationToken) + { + await Task.Delay(2000, cancellationToken); + + _dependency.Called = true; + } + } + + public class TimeoutRequest2Handler : IRequestHandler + { + private readonly Dependency _dependency; + + public TimeoutRequest2Handler(Dependency dependency) => _dependency = dependency; + + public async Task Handle(TimeoutRequest2 request, CancellationToken cancellationToken) + { + await Task.Delay(2000, cancellationToken); + + _dependency.Called = true; + return 1; + } + } + + [Fact] + public async Task Should_resolve_main_handler() + { + var response = await _mediator.Send(new Ping { Message = "Ping" }); + + response.Message.ShouldBe("Ping Pong"); + } + + [Fact] + public async Task Should_resolve_main_void_handler() + { + await _mediator.Send(new VoidPing()); + + _dependency.Called.ShouldBeTrue(); + } + + [Fact] + public async Task Should_resolve_main_handler_via_dynamic_dispatch() + { + object request = new Ping { Message = "Ping" }; + var response = await _mediator.Send(request); + + var pong = response.ShouldBeOfType(); + pong.Message.ShouldBe("Ping Pong"); + } + + [Fact] + public async Task Should_resolve_main_void_handler_via_dynamic_dispatch() + { + object request = new VoidPing(); + var response = await _mediator.Send(request); + + response.ShouldBeOfType(); + + _dependency.Called.ShouldBeTrue(); + } + + [Fact] + public async Task Should_resolve_main_handler_by_specific_interface() + { + var response = await _mediator.Send(new Ping { Message = "Ping" }); + + response.Message.ShouldBe("Ping Pong"); + } + + [Fact] + public async Task Should_resolve_main_handler_by_given_interface() + { + // wrap requests in an array, so this test won't break on a 'replace with var' refactoring + var requests = new IRequest[] { new VoidPing() }; + await _mediator.Send(requests[0]); + + _dependency.Called.ShouldBeTrue(); + } + + [Fact] + public Task Should_raise_execption_on_null_request() => Should.ThrowAsync(async () => await _mediator.Send(default!)); + + [Fact] + public async Task Should_resolve_generic_handler() { var request = new GenericPing { Pong = new Pong { Message = "Ping" } }; - var result = await _mediator.Send(request); - - var pong = result.ShouldBeOfType(); - pong.Message.ShouldBe("Ping Pong"); - - _dependency.Called.ShouldBeTrue(); - } - - [Fact] - public async Task Should_resolve_generic_void_handler() - { - var request = new VoidGenericPing(); - await _mediator.Send(request); - - _dependency.Called.ShouldBeTrue(); - } - + var result = await _mediator.Send(request); + + var pong = result.ShouldBeOfType(); + pong.Message.ShouldBe("Ping Pong"); + + _dependency.Called.ShouldBeTrue(); + } + + [Fact] + public async Task Should_resolve_generic_void_handler() + { + var request = new VoidGenericPing(); + await _mediator.Send(request); + + _dependency.Called.ShouldBeTrue(); + } + [Fact] public async Task Should_resolve_multiple_type_parameter_generic_handler() { @@ -245,9 +296,9 @@ public async Task Should_resolve_multiple_type_parameter_generic_handler() await _mediator.Send(request); _dependency.Called.ShouldBeTrue(); - } - - [Fact] + } + + [Fact] public async Task Should_resolve_closed_handler_if_defined() { var dependency = new Dependency(); @@ -256,7 +307,7 @@ public async Task Should_resolve_closed_handler_if_defined() services.AddMediatR(cfg => { cfg.RegisterServicesFromAssemblies(Assembly.GetExecutingAssembly()); - cfg.RegisterGenericHandlers = true; + //cfg.RegisterGenericHandlers = true; }); services.AddTransient>,TestClass1PingRequestHandler>(); @@ -268,9 +319,9 @@ public async Task Should_resolve_closed_handler_if_defined() dependency.Called.ShouldBeFalse(); dependency.CalledSpecific.ShouldBeTrue(); - } - - [Fact] + } + + [Fact] public async Task Should_resolve_open_handler_if_not_defined() { var dependency = new Dependency(); @@ -279,7 +330,7 @@ public async Task Should_resolve_open_handler_if_not_defined() services.AddMediatR(cfg => { cfg.RegisterServicesFromAssemblies(Assembly.GetExecutingAssembly()); - cfg.RegisterGenericHandlers = true; + //cfg.RegisterGenericHandlers = true; }); services.AddTransient>, TestClass1PingRequestHandler>(); var serviceProvider = services.BuildServiceProvider(); @@ -290,5 +341,31 @@ public async Task Should_resolve_open_handler_if_not_defined() dependency.Called.ShouldBeTrue(); dependency.CalledSpecific.ShouldBeFalse(); - } + } + + [Fact] + public async Task TimeoutBehavior_Void_Should_Cancel_Long_Running_Task_And_Throw_Exception() + { + var request = new TimeoutRequest(); + + var exception = await Should.ThrowAsync(() => _mediator.Send(request)); + + exception.ShouldNotBeNull(); + exception.ShouldBeAssignableTo(); + _dependency.Called.ShouldBeFalse(); + } + + [Fact] + public async Task TimeoutBehavior_NonVoid_Should_Cancel_Long_Running_Task_And_Throw_Exception() + { + var request = new TimeoutRequest2(); + int result = 0; + + var exception = await Should.ThrowAsync(async () => { result = await _mediator.Send(request); }); + + exception.ShouldNotBeNull(); + exception.ShouldBeAssignableTo(); + _dependency.Called.ShouldBeFalse(); + result.ShouldBe(0); + } } \ No newline at end of file