@@ -9,23 +9,27 @@ import types
9
9
from . import providers
10
10
from .wiring import _Marker
11
11
12
+ from .providers cimport Provider
13
+
12
14
13
15
def _get_sync_patched (fn ):
14
16
@ functools.wraps (fn)
15
17
def _patched (*args , **kwargs ):
16
18
cdef object result
17
19
cdef dict to_inject
20
+ cdef object arg_key
21
+ cdef Provider provider
18
22
19
23
to_inject = kwargs.copy()
20
- for injection , provider in _patched.__injections__.items():
21
- if injection not in kwargs or isinstance (kwargs[injection ], _Marker):
22
- to_inject[injection ] = provider()
24
+ for arg_key , provider in _patched.__injections__.items():
25
+ if arg_key not in kwargs or isinstance (kwargs[arg_key ], _Marker):
26
+ to_inject[arg_key ] = provider()
23
27
24
28
result = fn(* args, ** to_inject)
25
29
26
30
if _patched.__closing__:
27
- for injection , provider in _patched.__closing__.items():
28
- if injection in kwargs and not isinstance (kwargs[injection ], _Marker):
31
+ for arg_key , provider in _patched.__closing__.items():
32
+ if arg_key in kwargs and not isinstance (kwargs[arg_key ], _Marker):
29
33
continue
30
34
if not isinstance (provider, providers.Resource):
31
35
continue
@@ -35,49 +39,45 @@ def _get_sync_patched(fn):
35
39
return _patched
36
40
37
41
38
- def _get_async_patched (fn ):
39
- @ functools.wraps (fn)
40
- async def _patched(* args, ** kwargs):
41
- cdef object result
42
- cdef dict to_inject
43
- cdef list to_inject_await = []
44
- cdef list to_close_await = []
45
-
46
- to_inject = kwargs.copy()
47
- for injection, provider in _patched.__injections__.items():
48
- if injection not in kwargs or isinstance (kwargs[injection], _Marker):
49
- provide = provider()
50
- if _isawaitable(provide):
51
- to_inject_await.append((injection, provide))
52
- else :
53
- to_inject[injection] = provide
54
-
55
- if to_inject_await:
56
- async_to_inject = await asyncio.gather(* (provide for _, provide in to_inject_await))
57
- for provide, (injection, _) in zip (async_to_inject, to_inject_await):
58
- to_inject[injection] = provide
59
-
60
- result = await fn(* args, ** to_inject)
61
-
62
- if _patched.__closing__:
63
- for injection, provider in _patched.__closing__.items():
64
- if injection in kwargs \
65
- and isinstance (kwargs[injection], _Marker):
66
- continue
67
- if not isinstance (provider, providers.Resource):
68
- continue
69
- shutdown = provider.shutdown()
70
- if _isawaitable(shutdown):
71
- to_close_await.append(shutdown)
72
-
73
- await asyncio.gather(* to_close_await)
74
-
75
- return result
76
-
77
- # Hotfix for iscoroutinefunction() for Cython < 3.0.0; can be removed after migration to Cython 3.0.0+
78
- _patched._is_coroutine = asyncio.coroutines._is_coroutine
79
-
80
- return _patched
42
+ async def _async_inject(object fn, tuple args, dict kwargs, dict injections, dict closings):
43
+ cdef object result
44
+ cdef dict to_inject
45
+ cdef list to_inject_await = []
46
+ cdef list to_close_await = []
47
+ cdef object arg_key
48
+ cdef Provider provider
49
+
50
+ to_inject = kwargs.copy()
51
+ for arg_key, provider in injections.items():
52
+ if arg_key not in kwargs or isinstance (kwargs[arg_key], _Marker):
53
+ provide = provider()
54
+ if provider.is_async_mode_enabled():
55
+ to_inject_await.append((arg_key, provide))
56
+ elif _isawaitable(provide):
57
+ to_inject_await.append((arg_key, provide))
58
+ else :
59
+ to_inject[arg_key] = provide
60
+
61
+ if to_inject_await:
62
+ async_to_inject = await asyncio.gather(* (provide for _, provide in to_inject_await))
63
+ for provide, (injection, _) in zip (async_to_inject, to_inject_await):
64
+ to_inject[injection] = provide
65
+
66
+ result = await fn(* args, ** to_inject)
67
+
68
+ if closings:
69
+ for arg_key, provider in closings.items():
70
+ if arg_key in kwargs and isinstance (kwargs[arg_key], _Marker):
71
+ continue
72
+ if not isinstance (provider, providers.Resource):
73
+ continue
74
+ shutdown = provider.shutdown()
75
+ if _isawaitable(shutdown):
76
+ to_close_await.append(shutdown)
77
+
78
+ await asyncio.gather(* to_close_await)
79
+
80
+ return result
81
81
82
82
83
83
cdef bint _isawaitable(object instance):
0 commit comments