22using System . Collections . Generic ;
33using System . Linq ;
44using System . Reflection ;
5+ using System . Threading ;
56using MediatR . Pipeline ;
67using Microsoft . Extensions . DependencyInjection ;
78using Microsoft . Extensions . DependencyInjection . Extensions ;
89
910namespace MediatR . Registration ;
1011
1112public static class ServiceRegistrar
12- {
13- public static void AddMediatRClasses ( IServiceCollection services , MediatRServiceConfiguration configuration )
14- {
13+ {
14+ private static int MaxGenericTypeParameters ;
15+ private static int MaxTypesClosing ;
16+ private static int MaxGenericTypeRegistrations ;
17+ private static int RegistrationTimeout ;
18+
19+ public static void SetGenericRequestHandlerRegistrationLimitations ( MediatRServiceConfiguration configuration )
20+ {
21+ MaxGenericTypeParameters = configuration . MaxGenericTypeParameters ;
22+ MaxTypesClosing = configuration . MaxTypesClosing ;
23+ MaxGenericTypeRegistrations = configuration . MaxGenericTypeRegistrations ;
24+ RegistrationTimeout = configuration . RegistrationTimeout ;
25+ }
26+
27+ public static void AddMediatRClassesWithTimeout ( IServiceCollection services , MediatRServiceConfiguration configuration )
28+ {
29+ using ( var cts = new CancellationTokenSource ( RegistrationTimeout ) )
30+ {
31+ try
32+ {
33+ AddMediatRClasses ( services , configuration , cts . Token ) ;
34+ }
35+ catch ( OperationCanceledException )
36+ {
37+ throw new TimeoutException ( "The generic handler registration process timed out." ) ;
38+ }
39+ }
40+ }
41+
42+ public static void AddMediatRClasses ( IServiceCollection services , MediatRServiceConfiguration configuration , CancellationToken cancellationToken = default )
43+ {
44+
1545 var assembliesToScan = configuration . AssembliesToRegister . Distinct ( ) . ToArray ( ) ;
1646
17- ConnectImplementationsToTypesClosing ( typeof ( IRequestHandler < , > ) , services , assembliesToScan , false , configuration ) ;
18- ConnectImplementationsToTypesClosing ( typeof ( IRequestHandler < > ) , services , assembliesToScan , false , configuration ) ;
47+ ConnectImplementationsToTypesClosing ( typeof ( IRequestHandler < , > ) , services , assembliesToScan , false , configuration , cancellationToken ) ;
48+ ConnectImplementationsToTypesClosing ( typeof ( IRequestHandler < > ) , services , assembliesToScan , false , configuration , cancellationToken ) ;
1949 ConnectImplementationsToTypesClosing ( typeof ( INotificationHandler < > ) , services , assembliesToScan , true , configuration ) ;
2050 ConnectImplementationsToTypesClosing ( typeof ( IStreamRequestHandler < , > ) , services , assembliesToScan , false , configuration ) ;
2151 ConnectImplementationsToTypesClosing ( typeof ( IRequestExceptionHandler < , , > ) , services , assembliesToScan , true , configuration ) ;
@@ -63,7 +93,8 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
6393 IServiceCollection services ,
6494 IEnumerable < Assembly > assembliesToScan ,
6595 bool addIfAlreadyExists ,
66- MediatRServiceConfiguration configuration )
96+ MediatRServiceConfiguration configuration ,
97+ CancellationToken cancellationToken = default )
6798 {
6899 var concretions = new List < Type > ( ) ;
69100 var interfaces = new List < Type > ( ) ;
@@ -72,9 +103,10 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
72103
73104 var types = assembliesToScan
74105 . SelectMany ( a => a . DefinedTypes )
106+ . Where ( t => ! t . ContainsGenericParameters || configuration . RegisterGenericHandlers )
75107 . Where ( t => t . IsConcrete ( ) && t . FindInterfacesThatClose ( openRequestInterface ) . Any ( ) )
76108 . Where ( configuration . TypeEvaluator )
77- . ToList ( ) ;
109+ . ToList ( ) ;
78110
79111 foreach ( var type in types )
80112 {
@@ -131,7 +163,7 @@ private static void ConnectImplementationsToTypesClosing(Type openRequestInterfa
131163 foreach ( var @interface in genericInterfaces )
132164 {
133165 var exactMatches = genericConcretions . Where ( x => x . CanBeCastTo ( @interface ) ) . ToList ( ) ;
134- AddAllConcretionsThatClose ( @interface , exactMatches , services , assembliesToScan ) ;
166+ AddAllConcretionsThatClose ( @interface , exactMatches , services , assembliesToScan , cancellationToken ) ;
135167 }
136168 }
137169
@@ -174,7 +206,7 @@ private static void AddConcretionsThatCouldBeClosed(Type @interface, List<Type>
174206
175207 private static ( Type Service , Type Implementation ) GetConcreteRegistrationTypes ( Type openRequestHandlerInterface , Type concreteGenericTRequest , Type openRequestHandlerImplementation )
176208 {
177- var closingType = concreteGenericTRequest . GetGenericArguments ( ) . First ( ) ;
209+ var closingTypes = concreteGenericTRequest . GetGenericArguments ( ) ;
178210
179211 var concreteTResponse = concreteGenericTRequest . GetInterfaces ( )
180212 . FirstOrDefault ( x => x . IsGenericType && x . GetGenericTypeDefinition ( ) == typeof ( IRequest < > ) )
@@ -187,33 +219,90 @@ private static (Type Service, Type Implementation) GetConcreteRegistrationTypes(
187219 typeDefinition . MakeGenericType ( concreteGenericTRequest , concreteTResponse ) :
188220 typeDefinition . MakeGenericType ( concreteGenericTRequest ) ;
189221
190- return ( serviceType , openRequestHandlerImplementation . MakeGenericType ( closingType ) ) ;
222+ return ( serviceType , openRequestHandlerImplementation . MakeGenericType ( closingTypes ) ) ;
191223 }
192224
193- private static List < Type > ? GetConcreteRequestTypes ( Type openRequestHandlerInterface , Type openRequestHandlerImplementation , IEnumerable < Assembly > assembliesToScan )
225+ private static List < Type > ? GetConcreteRequestTypes ( Type openRequestHandlerInterface , Type openRequestHandlerImplementation , IEnumerable < Assembly > assembliesToScan , CancellationToken cancellationToken )
194226 {
195- var constraints = openRequestHandlerImplementation . GetGenericArguments ( ) . First ( ) . GetGenericParameterConstraints ( ) ;
196-
197- var typesThatCanClose = assembliesToScan
198- . SelectMany ( assembly => assembly . GetTypes ( ) )
199- . Where ( type => type . IsClass && ! type . IsAbstract && constraints . All ( constraint => constraint . IsAssignableFrom ( type ) ) )
200- . ToList ( ) ;
227+ //request generic type constraints
228+ var constraintsForEachParameter = openRequestHandlerImplementation
229+ . GetGenericArguments ( )
230+ . Select ( x => x . GetGenericParameterConstraints ( ) )
231+ . ToList ( ) ;
232+
233+ if ( constraintsForEachParameter . Count > 2 && constraintsForEachParameter . Any ( constraints => ! constraints . Where ( x => x . IsInterface || x . IsClass ) . Any ( ) ) )
234+ throw new ArgumentException ( $ "Error registering the generic handler type: { openRequestHandlerImplementation . FullName } . When registering generic requests with more than two type parameters, each type parameter must have at least one constraint of type interface or class.") ;
235+
236+ var typesThatCanCloseForEachParameter = constraintsForEachParameter
237+ . Select ( constraints => assembliesToScan
238+ . SelectMany ( assembly => assembly . GetTypes ( ) )
239+ . Where ( type => type . IsClass && ! type . IsAbstract && constraints . All ( constraint => constraint . IsAssignableFrom ( type ) ) ) . ToList ( )
240+ ) . ToList ( ) ;
201241
202242 var requestType = openRequestHandlerInterface . GenericTypeArguments . First ( ) ;
203243
204244 if ( requestType . IsGenericParameter )
205245 return null ;
206246
207247 var requestGenericTypeDefinition = requestType . GetGenericTypeDefinition ( ) ;
248+
249+ var combinations = GenerateCombinations ( requestType , typesThatCanCloseForEachParameter , 0 , cancellationToken ) ;
250+
251+ return combinations . Select ( types => requestGenericTypeDefinition . MakeGenericType ( types . ToArray ( ) ) ) . ToList ( ) ;
252+ }
253+
254+ // Method to generate combinations recursively
255+ public static List < List < Type > > GenerateCombinations ( Type requestType , List < List < Type > > lists , int depth = 0 , CancellationToken cancellationToken = default )
256+ {
257+ if ( depth == 0 )
258+ {
259+ // Initial checks
260+ if ( MaxGenericTypeParameters > 0 && lists . Count > MaxGenericTypeParameters )
261+ throw new ArgumentException ( $ "Error registering the generic type: { requestType . FullName } . The number of generic type parameters exceeds the maximum allowed ({ MaxGenericTypeParameters } ).") ;
262+
263+ foreach ( var list in lists )
264+ {
265+ if ( MaxTypesClosing > 0 && list . Count > MaxTypesClosing )
266+ throw new ArgumentException ( $ "Error registering the generic type: { requestType . FullName } . One of the generic type parameter's count of types that can close exceeds the maximum length allowed ({ MaxTypesClosing } ).") ;
267+ }
268+
269+ // Calculate the total number of combinations
270+ long totalCombinations = 1 ;
271+ foreach ( var list in lists )
272+ {
273+ totalCombinations *= list . Count ;
274+ if ( MaxGenericTypeParameters > 0 && totalCombinations > MaxGenericTypeRegistrations )
275+ throw new ArgumentException ( $ "Error registering the generic type: { requestType . FullName } . The total number of generic type registrations exceeds the maximum allowed ({ MaxGenericTypeRegistrations } ).") ;
276+ }
277+ }
278+
279+ if ( depth >= lists . Count )
280+ return new List < List < Type > > { new List < Type > ( ) } ;
281+
282+ cancellationToken . ThrowIfCancellationRequested ( ) ;
208283
209- return typesThatCanClose . Select ( type => requestGenericTypeDefinition . MakeGenericType ( type ) ) . ToList ( ) ;
284+ var currentList = lists [ depth ] ;
285+ var childCombinations = GenerateCombinations ( requestType , lists , depth + 1 , cancellationToken ) ;
286+ var combinations = new List < List < Type > > ( ) ;
287+
288+ foreach ( var item in currentList )
289+ {
290+ foreach ( var childCombination in childCombinations )
291+ {
292+ var currentCombination = new List < Type > { item } ;
293+ currentCombination . AddRange ( childCombination ) ;
294+ combinations . Add ( currentCombination ) ;
295+ }
296+ }
297+
298+ return combinations ;
210299 }
211300
212- private static void AddAllConcretionsThatClose ( Type openRequestInterface , List < Type > concretions , IServiceCollection services , IEnumerable < Assembly > assembliesToScan )
301+ private static void AddAllConcretionsThatClose ( Type openRequestInterface , List < Type > concretions , IServiceCollection services , IEnumerable < Assembly > assembliesToScan , CancellationToken cancellationToken )
213302 {
214303 foreach ( var concretion in concretions )
215- {
216- var concreteRequests = GetConcreteRequestTypes ( openRequestInterface , concretion , assembliesToScan ) ;
304+ {
305+ var concreteRequests = GetConcreteRequestTypes ( openRequestInterface , concretion , assembliesToScan , cancellationToken ) ;
217306
218307 if ( concreteRequests is null )
219308 continue ;
@@ -223,6 +312,7 @@ private static void AddAllConcretionsThatClose(Type openRequestInterface, List<T
223312
224313 foreach ( var ( Service , Implementation ) in registrationTypes )
225314 {
315+ cancellationToken . ThrowIfCancellationRequested ( ) ;
226316 services . AddTransient ( Service , Implementation ) ;
227317 }
228318 }
0 commit comments