Skip to content

Commit c295291

Browse files
authored
Merge pull request #925 from jbogard/fixing-behavior-registration
Fixing registration problem; moving exception behaviors first so that…
2 parents 7fe73da + f28cdc3 commit c295291

File tree

4 files changed

+135
-40
lines changed

4 files changed

+135
-40
lines changed

src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using MediatR;
66
using MediatR.NotificationPublishers;
77
using MediatR.Pipeline;
8+
using MediatR.Registration;
89

910
namespace Microsoft.Extensions.DependencyInjection;
1011

@@ -133,15 +134,14 @@ public MediatRServiceConfiguration AddBehavior<TImplementationType>(ServiceLifet
133134
/// <returns>This</returns>
134135
public MediatRServiceConfiguration AddBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
135136
{
136-
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
137-
var implementedBehaviorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IPipelineBehavior<,>)));
137+
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IPipelineBehavior<,>)).ToList();
138138

139-
if (implementedBehaviorTypes.Count == 0)
139+
if (implementedGenericInterfaces.Count == 0)
140140
{
141141
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IPipelineBehavior<,>).FullName}");
142142
}
143143

144-
foreach (var implementedBehaviorType in implementedBehaviorTypes)
144+
foreach (var implementedBehaviorType in implementedGenericInterfaces)
145145
{
146146
BehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime));
147147
}
@@ -233,15 +233,14 @@ public MediatRServiceConfiguration AddStreamBehavior<TImplementationType>(Servic
233233
/// <returns>This</returns>
234234
public MediatRServiceConfiguration AddStreamBehavior(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
235235
{
236-
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
237-
var implementedBehaviorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IStreamPipelineBehavior<,>)));
236+
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IStreamPipelineBehavior<,>)).ToList();
238237

239-
if (implementedBehaviorTypes.Count == 0)
238+
if (implementedGenericInterfaces.Count == 0)
240239
{
241240
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IStreamPipelineBehavior<,>).FullName}");
242241
}
243242

244-
foreach (var implementedBehaviorType in implementedBehaviorTypes)
243+
foreach (var implementedBehaviorType in implementedGenericInterfaces)
245244
{
246245
StreamBehaviorsToRegister.Add(new ServiceDescriptor(implementedBehaviorType, implementationType, serviceLifetime));
247246
}
@@ -320,15 +319,14 @@ public MediatRServiceConfiguration AddRequestPreProcessor<TImplementationType>(
320319
/// <returns>This</returns>
321320
public MediatRServiceConfiguration AddRequestPreProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
322321
{
323-
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
324-
var implementedPreProcessorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPreProcessor<>)));
322+
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPreProcessor<>)).ToList();
325323

326-
if (implementedPreProcessorTypes.Count == 0)
324+
if (implementedGenericInterfaces.Count == 0)
327325
{
328326
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPreProcessor<>).FullName}");
329327
}
330328

331-
foreach (var implementedPreProcessorType in implementedPreProcessorTypes)
329+
foreach (var implementedPreProcessorType in implementedGenericInterfaces)
332330
{
333331
RequestPreProcessorsToRegister.Add(new ServiceDescriptor(implementedPreProcessorType, implementationType, serviceLifetime));
334332
}
@@ -406,15 +404,14 @@ public MediatRServiceConfiguration AddRequestPostProcessor<TImplementationType>(
406404
/// <returns>This</returns>
407405
public MediatRServiceConfiguration AddRequestPostProcessor(Type implementationType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
408406
{
409-
var implementedGenericInterfaces = implementationType.GetInterfaces().Where(i => i.IsGenericType).Select(i => i.GetGenericTypeDefinition());
410-
var implementedPostProcessorTypes = new HashSet<Type>(implementedGenericInterfaces.Where(i => i == typeof(IRequestPostProcessor<,>)));
407+
var implementedGenericInterfaces = implementationType.FindInterfacesThatClose(typeof(IRequestPostProcessor<,>)).ToList();
411408

412-
if (implementedPostProcessorTypes.Count == 0)
409+
if (implementedGenericInterfaces.Count == 0)
413410
{
414411
throw new InvalidOperationException($"{implementationType.Name} must implement {typeof(IRequestPostProcessor<,>).FullName}");
415412
}
416413

417-
foreach (var implementedPostProcessorType in implementedPostProcessorTypes)
414+
foreach (var implementedPostProcessorType in implementedGenericInterfaces)
418415
{
419416
RequestPostProcessorsToRegister.Add(new ServiceDescriptor(implementedPostProcessorType, implementationType, serviceLifetime));
420417
}

src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,26 @@ public static IServiceCollection AddMediatR(this IServiceCollection services,
3030

3131
configuration.Invoke(serviceConfig);
3232

33-
if (!serviceConfig.AssembliesToRegister.Any())
33+
return services.AddMediatR(serviceConfig);
34+
}
35+
36+
/// <summary>
37+
/// Registers handlers and mediator types from the specified assemblies
38+
/// </summary>
39+
/// <param name="services">Service collection</param>
40+
/// <param name="configuration">Configuration options</param>
41+
/// <returns>Service collection</returns>
42+
public static IServiceCollection AddMediatR(this IServiceCollection services,
43+
MediatRServiceConfiguration configuration)
44+
{
45+
if (!configuration.AssembliesToRegister.Any())
3446
{
3547
throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers.");
3648
}
3749

38-
ServiceRegistrar.AddMediatRClasses(services, serviceConfig);
50+
ServiceRegistrar.AddMediatRClasses(services, configuration);
3951

40-
ServiceRegistrar.AddRequiredServices(services, serviceConfig);
52+
ServiceRegistrar.AddRequiredServices(services, configuration);
4153

4254
return services;
4355
}

src/MediatR/Registration/ServiceRegistrar.cs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List<Type>
138138
}
139139
}
140140

141-
private static bool CouldCloseTo(this Type openConcretion, Type closedInterface)
141+
internal static bool CouldCloseTo(this Type openConcretion, Type closedInterface)
142142
{
143143
var openInterface = closedInterface.GetGenericTypeDefinition();
144144
var arguments = closedInterface.GenericTypeArguments;
@@ -161,7 +161,7 @@ private static bool IsOpenGeneric(this Type type)
161161
return type.IsGenericTypeDefinition || type.ContainsGenericParameters;
162162
}
163163

164-
private static IEnumerable<Type> FindInterfacesThatClose(this Type pluggedType, Type templateType)
164+
internal static IEnumerable<Type> FindInterfacesThatClose(this Type pluggedType, Type templateType)
165165
{
166166
return FindInterfacesThatClosesCore(pluggedType, templateType).Distinct();
167167
}
@@ -221,6 +221,17 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
221221
services.TryAdd(notificationPublisherServiceDescriptor);
222222

223223
// Register pre processors, then post processors, then behaviors
224+
if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
225+
{
226+
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
227+
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
228+
}
229+
else
230+
{
231+
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
232+
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
233+
}
234+
224235
if (serviceConfiguration.RequestPreProcessorsToRegister.Any())
225236
{
226237
services.TryAddEnumerable(new ServiceDescriptor(typeof(IPipelineBehavior<,>), typeof(RequestPreProcessorBehavior<,>), ServiceLifetime.Transient));
@@ -242,17 +253,6 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
242253
{
243254
services.TryAddEnumerable(serviceDescriptor);
244255
}
245-
246-
if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
247-
{
248-
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
249-
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
250-
}
251-
else
252-
{
253-
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>), typeof(IRequestExceptionHandler<,,>));
254-
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>), typeof(IRequestExceptionAction<,>));
255-
}
256256
}
257257

258258
private static void RegisterBehaviorIfImplementationsExist(IServiceCollection services, Type behaviorType, Type subBehaviorType)

test/MediatR.Tests/MicrosoftExtensionsDI/PipelineTests.cs

Lines changed: 94 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,11 @@ public class NotAnOpenBehavior : IPipelineBehavior<Ping, Pong>
361361
public Task<Pong> Handle(Ping request, RequestHandlerDelegate<Pong> next, CancellationToken cancellationToken) => next();
362362
}
363363

364+
public class ThrowingBehavior : IPipelineBehavior<Ping, Pong>
365+
{
366+
public Task<Pong> Handle(Ping request, RequestHandlerDelegate<Pong> next, CancellationToken cancellationToken) => throw new Exception(request.Message);
367+
}
368+
364369
public class NotAnOpenStreamBehavior : IStreamPipelineBehavior<Ping, Pong>
365370
{
366371
public IAsyncEnumerable<Pong> Handle(Ping request, StreamHandlerDelegate<Pong> next, CancellationToken cancellationToken) => next();
@@ -524,6 +529,27 @@ public void Should_pick_up_base_exception_behaviors()
524529
output.Messages.ShouldContain("Logging generic exception");
525530
}
526531

532+
[Fact]
533+
public void Should_handle_exceptions_from_behaviors()
534+
{
535+
var output = new Logger();
536+
IServiceCollection services = new ServiceCollection();
537+
services.AddSingleton(output);
538+
services.AddMediatR(cfg =>
539+
{
540+
cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly);
541+
cfg.AddBehavior<ThrowingBehavior>();
542+
});
543+
var provider = services.BuildServiceProvider();
544+
545+
var mediator = provider.GetRequiredService<IMediator>();
546+
547+
Should.Throw<Exception>(async () => await mediator.Send(new Ping {Message = "Ping"}));
548+
549+
output.Messages.ShouldContain("Ping Logged by Generic Type");
550+
output.Messages.ShouldContain("Logging generic exception");
551+
}
552+
527553
[Fact]
528554
public void Should_pick_up_exception_actions()
529555
{
@@ -648,6 +674,16 @@ public void Should_handle_open_behavior_registration()
648674
cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
649675
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
650676
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
677+
678+
var services = new ServiceCollection();
679+
680+
cfg.RegisterServicesFromAssemblyContaining<Ping>();
681+
682+
Should.NotThrow(() =>
683+
{
684+
services.AddMediatR(cfg);
685+
services.BuildServiceProvider();
686+
});
651687
}
652688

653689
[Fact]
@@ -659,16 +695,26 @@ public void Should_handle_inferred_behavior_registration()
659695

660696
cfg.BehaviorsToRegister.Count.ShouldBe(2);
661697

662-
cfg.BehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IPipelineBehavior<,>));
698+
cfg.BehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IPipelineBehavior<Ping, Pong>));
663699
cfg.BehaviorsToRegister[0].ImplementationType.ShouldBe(typeof(InnerBehavior));
664700
cfg.BehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
665701
cfg.BehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
666702
cfg.BehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
667-
cfg.BehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IPipelineBehavior<,>));
703+
cfg.BehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IPipelineBehavior<Ping, Pong>));
668704
cfg.BehaviorsToRegister[1].ImplementationType.ShouldBe(typeof(OuterBehavior));
669705
cfg.BehaviorsToRegister[1].ImplementationFactory.ShouldBeNull();
670706
cfg.BehaviorsToRegister[1].ImplementationInstance.ShouldBeNull();
671707
cfg.BehaviorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);
708+
709+
var services = new ServiceCollection();
710+
711+
cfg.RegisterServicesFromAssemblyContaining<Ping>();
712+
713+
Should.NotThrow(() =>
714+
{
715+
services.AddMediatR(cfg);
716+
services.BuildServiceProvider();
717+
});
672718
}
673719

674720

@@ -681,16 +727,26 @@ public void Should_handle_inferred_stream_behavior_registration()
681727

682728
cfg.StreamBehaviorsToRegister.Count.ShouldBe(2);
683729

684-
cfg.StreamBehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<,>));
730+
cfg.StreamBehaviorsToRegister[0].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<Ping, Pong>));
685731
cfg.StreamBehaviorsToRegister[0].ImplementationType.ShouldBe(typeof(InnerStreamBehavior));
686732
cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
687733
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
688734
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
689-
cfg.StreamBehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<,>));
735+
cfg.StreamBehaviorsToRegister[1].ServiceType.ShouldBe(typeof(IStreamPipelineBehavior<Ping, Pong>));
690736
cfg.StreamBehaviorsToRegister[1].ImplementationType.ShouldBe(typeof(OuterStreamBehavior));
691737
cfg.StreamBehaviorsToRegister[1].ImplementationFactory.ShouldBeNull();
692738
cfg.StreamBehaviorsToRegister[1].ImplementationInstance.ShouldBeNull();
693739
cfg.StreamBehaviorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);
740+
741+
var services = new ServiceCollection();
742+
743+
cfg.RegisterServicesFromAssemblyContaining<Ping>();
744+
745+
Should.NotThrow(() =>
746+
{
747+
services.AddMediatR(cfg);
748+
services.BuildServiceProvider();
749+
});
694750
}
695751

696752
[Fact]
@@ -702,16 +758,26 @@ public void Should_handle_inferred_pre_processor_registration()
702758

703759
cfg.RequestPreProcessorsToRegister.Count.ShouldBe(2);
704760

705-
cfg.RequestPreProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPreProcessor<>));
761+
cfg.RequestPreProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPreProcessor<Ping>));
706762
cfg.RequestPreProcessorsToRegister[0].ImplementationType.ShouldBe(typeof(FirstConcretePreProcessor));
707763
cfg.RequestPreProcessorsToRegister[0].ImplementationFactory.ShouldBeNull();
708764
cfg.RequestPreProcessorsToRegister[0].ImplementationInstance.ShouldBeNull();
709765
cfg.RequestPreProcessorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
710-
cfg.RequestPreProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPreProcessor<>));
766+
cfg.RequestPreProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPreProcessor<Ping>));
711767
cfg.RequestPreProcessorsToRegister[1].ImplementationType.ShouldBe(typeof(NextConcretePreProcessor));
712768
cfg.RequestPreProcessorsToRegister[1].ImplementationFactory.ShouldBeNull();
713769
cfg.RequestPreProcessorsToRegister[1].ImplementationInstance.ShouldBeNull();
714770
cfg.RequestPreProcessorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);
771+
772+
var services = new ServiceCollection();
773+
774+
cfg.RegisterServicesFromAssemblyContaining<Ping>();
775+
776+
Should.NotThrow(() =>
777+
{
778+
services.AddMediatR(cfg);
779+
services.BuildServiceProvider();
780+
});
715781
}
716782

717783
[Fact]
@@ -723,16 +789,26 @@ public void Should_handle_inferred_post_processor_registration()
723789

724790
cfg.RequestPostProcessorsToRegister.Count.ShouldBe(2);
725791

726-
cfg.RequestPostProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPostProcessor<,>));
792+
cfg.RequestPostProcessorsToRegister[0].ServiceType.ShouldBe(typeof(IRequestPostProcessor<Ping, Pong>));
727793
cfg.RequestPostProcessorsToRegister[0].ImplementationType.ShouldBe(typeof(FirstConcretePostProcessor));
728794
cfg.RequestPostProcessorsToRegister[0].ImplementationFactory.ShouldBeNull();
729795
cfg.RequestPostProcessorsToRegister[0].ImplementationInstance.ShouldBeNull();
730796
cfg.RequestPostProcessorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Transient);
731-
cfg.RequestPostProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPostProcessor<,>));
797+
cfg.RequestPostProcessorsToRegister[1].ServiceType.ShouldBe(typeof(IRequestPostProcessor<Ping, Pong>));
732798
cfg.RequestPostProcessorsToRegister[1].ImplementationType.ShouldBe(typeof(NextConcretePostProcessor));
733799
cfg.RequestPostProcessorsToRegister[1].ImplementationFactory.ShouldBeNull();
734800
cfg.RequestPostProcessorsToRegister[1].ImplementationInstance.ShouldBeNull();
735801
cfg.RequestPostProcessorsToRegister[1].Lifetime.ShouldBe(ServiceLifetime.Transient);
802+
803+
var services = new ServiceCollection();
804+
805+
cfg.RegisterServicesFromAssemblyContaining<Ping>();
806+
807+
Should.NotThrow(() =>
808+
{
809+
services.AddMediatR(cfg);
810+
services.BuildServiceProvider();
811+
});
736812
}
737813

738814
[Fact]
@@ -756,5 +832,15 @@ public void Should_handle_open_behaviors_registration_from_a_single_type()
756832
cfg.StreamBehaviorsToRegister[0].ImplementationFactory.ShouldBeNull();
757833
cfg.StreamBehaviorsToRegister[0].ImplementationInstance.ShouldBeNull();
758834
cfg.StreamBehaviorsToRegister[0].Lifetime.ShouldBe(ServiceLifetime.Singleton);
835+
836+
var services = new ServiceCollection();
837+
838+
cfg.RegisterServicesFromAssemblyContaining<Ping>();
839+
840+
Should.NotThrow(() =>
841+
{
842+
services.AddMediatR(cfg);
843+
services.BuildServiceProvider();
844+
});
759845
}
760846
}

0 commit comments

Comments
 (0)