@@ -36,7 +36,11 @@ def _(article):
36
36
"""
37
37
from __future__ import annotations
38
38
39
+ import asyncio
39
40
import enum
41
+ import functools
42
+ import inspect
43
+ from contextlib import contextmanager
40
44
from dataclasses import dataclass , field
41
45
from itertools import count
42
46
from typing import Any , Callable , Iterable , TypeVar
@@ -66,6 +70,7 @@ class StepFunctionContext:
66
70
parser : StepParser
67
71
converters : dict [str , Callable [..., Any ]] = field (default_factory = dict )
68
72
target_fixture : str | None = None
73
+ is_async : bool = False
69
74
70
75
71
76
def get_step_fixture_name (step : Step ) -> str :
@@ -86,6 +91,7 @@ def given(
86
91
{<param_name>: <converter function>}.
87
92
:param target_fixture: Target fixture name to replace by steps definition function.
88
93
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
94
+ :param is_async: True if the step is asynchronous. (Default: False)
89
95
90
96
:return: Decorator function for the step.
91
97
"""
@@ -105,6 +111,7 @@ def when(
105
111
{<param_name>: <converter function>}.
106
112
:param target_fixture: Target fixture name to replace by steps definition function.
107
113
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
114
+ :param is_async: True if the step is asynchronous. (Default: False)
108
115
109
116
:return: Decorator function for the step.
110
117
"""
@@ -124,6 +131,7 @@ def then(
124
131
{<param_name>: <converter function>}.
125
132
:param target_fixture: Target fixture name to replace by steps definition function.
126
133
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
134
+ :param is_async: True if the step is asynchronous. (Default: False)
127
135
128
136
:return: Decorator function for the step.
129
137
"""
@@ -144,6 +152,7 @@ def step(
144
152
:param converters: Optional step arguments converters mapping.
145
153
:param target_fixture: Optional fixture name to replace by step definition.
146
154
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
155
+ :param is_async: True if the step is asynchronous. (Default: False)
147
156
148
157
:return: Decorator function for the step.
149
158
@@ -159,6 +168,11 @@ def step(
159
168
def decorator (func : TCallable ) -> TCallable :
160
169
parser = get_parser (name )
161
170
171
+ if inspect .isasyncgenfunction (func ):
172
+ func = wrap_asyncgen (func )
173
+ elif inspect .iscoroutinefunction (func ):
174
+ func = wrap_coroutine (func )
175
+
162
176
context = StepFunctionContext (
163
177
type = type_ ,
164
178
step_func = func ,
@@ -177,11 +191,79 @@ def step_function_marker() -> StepFunctionContext:
177
191
f"{ StepNamePrefix .step_def .value } _{ type_ or '*' } _{ parser .name } " , seen = caller_locals .keys ()
178
192
)
179
193
caller_locals [fixture_step_name ] = pytest .fixture (name = fixture_step_name )(step_function_marker )
194
+
180
195
return func
181
196
182
197
return decorator
183
198
184
199
200
+ def _synchronize (func : Callable , async_wrapper : Callable ) -> Callable :
201
+ """Provide a synchronous wrapper for an async function or generator.
202
+
203
+ :param func: The async function / generator to wrap.
204
+ :param async_wrapper: A function taking an event loop and either a
205
+ coroutine or an async_generator (the result of calling func)
206
+ and returning the result.
207
+
208
+ :returns: The wrapped async function.
209
+ """
210
+
211
+ @functools .wraps (func )
212
+ def _wrapper (* args , ** kwargs ):
213
+ try :
214
+ loop , created = asyncio .get_running_loop (), False
215
+ except RuntimeError :
216
+ loop , created = asyncio .get_event_loop_policy ().new_event_loop (), True
217
+
218
+ try :
219
+ yield async_wrapper (loop , func (* args , ** kwargs ))
220
+ except :
221
+ raise
222
+ finally :
223
+ if created :
224
+ loop .close ()
225
+
226
+ return _wrapper
227
+
228
+
229
+ def wrap_asyncgen (func : Callable ) -> Callable :
230
+ """Wrapper for an async_generator function.
231
+
232
+ :param func: The function to wrap.
233
+
234
+ :returns: The wrapped function. The wrapped function will raise ValueError
235
+ if the generator yields more than once.
236
+ """
237
+
238
+ def _wrapper (loop : asyncio .events .AbstractEventLoop , async_obj ):
239
+ result = loop .run_until_complete (async_obj .__anext__ ())
240
+ try :
241
+ loop .run_until_complete (async_obj .__anext__ ())
242
+ except StopAsyncIteration :
243
+ pass
244
+ else :
245
+ msg = "Async genetator should yield only once."
246
+ raise ValueError (msg )
247
+
248
+ return result
249
+
250
+ return _synchronize (func , _wrapper )
251
+
252
+
253
+ def wrap_coroutine (func : Callable ) -> Callable :
254
+ """Wrapper for a coroutine function.
255
+
256
+ :param func: The function to wrap.
257
+
258
+ :returns: The wrapped function.
259
+ """
260
+
261
+ def _wrapper (loop : asyncio .events .AbstractEventLoop , async_obj ):
262
+ return loop .run_until_complete (async_obj )
263
+
264
+ return _synchronize (func , _wrapper )
265
+
266
+
185
267
def find_unique_name (name : str , seen : Iterable [str ]) -> str :
186
268
"""Find unique name among a set of strings.
187
269
0 commit comments