Skip to content

Commit 6a301ce

Browse files
committed
Implement async steps.
1 parent d45c543 commit 6a301ce

File tree

2 files changed

+668
-0
lines changed

2 files changed

+668
-0
lines changed

src/pytest_bdd/steps.py

+82
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ def _(article):
3636
"""
3737
from __future__ import annotations
3838

39+
import asyncio
3940
import enum
41+
import functools
42+
import inspect
43+
from contextlib import contextmanager
4044
from dataclasses import dataclass, field
4145
from itertools import count
4246
from typing import Any, Callable, Iterable, TypeVar
@@ -66,6 +70,7 @@ class StepFunctionContext:
6670
parser: StepParser
6771
converters: dict[str, Callable[..., Any]] = field(default_factory=dict)
6872
target_fixture: str | None = None
73+
is_async: bool = False
6974

7075

7176
def get_step_fixture_name(step: Step) -> str:
@@ -86,6 +91,7 @@ def given(
8691
{<param_name>: <converter function>}.
8792
:param target_fixture: Target fixture name to replace by steps definition function.
8893
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
94+
:param is_async: True if the step is asynchronous. (Default: False)
8995
9096
:return: Decorator function for the step.
9197
"""
@@ -105,6 +111,7 @@ def when(
105111
{<param_name>: <converter function>}.
106112
:param target_fixture: Target fixture name to replace by steps definition function.
107113
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
114+
:param is_async: True if the step is asynchronous. (Default: False)
108115
109116
:return: Decorator function for the step.
110117
"""
@@ -124,6 +131,7 @@ def then(
124131
{<param_name>: <converter function>}.
125132
:param target_fixture: Target fixture name to replace by steps definition function.
126133
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
134+
:param is_async: True if the step is asynchronous. (Default: False)
127135
128136
:return: Decorator function for the step.
129137
"""
@@ -144,6 +152,7 @@ def step(
144152
:param converters: Optional step arguments converters mapping.
145153
:param target_fixture: Optional fixture name to replace by step definition.
146154
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
155+
:param is_async: True if the step is asynchronous. (Default: False)
147156
148157
:return: Decorator function for the step.
149158
@@ -159,6 +168,11 @@ def step(
159168
def decorator(func: TCallable) -> TCallable:
160169
parser = get_parser(name)
161170

171+
if inspect.isasyncgenfunction(func):
172+
func = wrap_asyncgen(func)
173+
elif inspect.iscoroutinefunction(func):
174+
func = wrap_coroutine(func)
175+
162176
context = StepFunctionContext(
163177
type=type_,
164178
step_func=func,
@@ -177,11 +191,79 @@ def step_function_marker() -> StepFunctionContext:
177191
f"{StepNamePrefix.step_def.value}_{type_ or '*'}_{parser.name}", seen=caller_locals.keys()
178192
)
179193
caller_locals[fixture_step_name] = pytest.fixture(name=fixture_step_name)(step_function_marker)
194+
180195
return func
181196

182197
return decorator
183198

184199

200+
def _synchronize(func: Callable, async_wrapper: Callable) -> Callable:
201+
"""Provide a synchronous wrapper for an async function or generator.
202+
203+
:param func: The async function / generator to wrap.
204+
:param async_wrapper: A function taking an event loop and either a
205+
coroutine or an async_generator (the result of calling func)
206+
and returning the result.
207+
208+
:returns: The wrapped async function.
209+
"""
210+
211+
@functools.wraps(func)
212+
def _wrapper(*args, **kwargs):
213+
try:
214+
loop, created = asyncio.get_running_loop(), False
215+
except RuntimeError:
216+
loop, created = asyncio.get_event_loop_policy().new_event_loop(), True
217+
218+
try:
219+
yield async_wrapper(loop, func(*args, **kwargs))
220+
except:
221+
raise
222+
finally:
223+
if created:
224+
loop.close()
225+
226+
return _wrapper
227+
228+
229+
def wrap_asyncgen(func: Callable) -> Callable:
230+
"""Wrapper for an async_generator function.
231+
232+
:param func: The function to wrap.
233+
234+
:returns: The wrapped function. The wrapped function will raise ValueError
235+
if the generator yields more than once.
236+
"""
237+
238+
def _wrapper(loop: asyncio.events.AbstractEventLoop, async_obj):
239+
result = loop.run_until_complete(async_obj.__anext__())
240+
try:
241+
loop.run_until_complete(async_obj.__anext__())
242+
except StopAsyncIteration:
243+
pass
244+
else:
245+
msg = "Async genetator should yield only once."
246+
raise ValueError(msg)
247+
248+
return result
249+
250+
return _synchronize(func, _wrapper)
251+
252+
253+
def wrap_coroutine(func: Callable) -> Callable:
254+
"""Wrapper for a coroutine function.
255+
256+
:param func: The function to wrap.
257+
258+
:returns: The wrapped function.
259+
"""
260+
261+
def _wrapper(loop: asyncio.events.AbstractEventLoop, async_obj):
262+
return loop.run_until_complete(async_obj)
263+
264+
return _synchronize(func, _wrapper)
265+
266+
185267
def find_unique_name(name: str, seen: Iterable[str]) -> str:
186268
"""Find unique name among a set of strings.
187269

0 commit comments

Comments
 (0)