Skip to content

Commit 4508b98

Browse files
committed
fix(typing): improve decorator type hinting
The type hinting for the most commonly used decorators were incomplete, resulting in decorated functions being obscured. This makes use of the special type variable `ParamSpec` which allows the type hinting a view into the parameters of a function. As ``ParamSpec` was introduced in Python 3.10, `ParamSpec` is imported from the `typing_extensions` module instead of the standard library. I have also taken the opportunity to fix other instances of `Callable` type hints missing their arguments. Signed-off-by: JP-Ellis <[email protected]>
1 parent 5707669 commit 4508b98

File tree

5 files changed

+48
-29
lines changed

5 files changed

+48
-29
lines changed

src/pytest_bdd/plugin.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
"""Pytest plugin entry point. Used for any fixtures needed."""
22
from __future__ import annotations
33

4-
from typing import TYPE_CHECKING, Callable, cast
4+
from typing import TYPE_CHECKING, Any, Callable, Generator, TypeVar, cast
55

66
import pytest
7+
from typing_extensions import ParamSpec
78

89
from . import cucumber_json, generation, gherkin_terminal_reporter, given, reporting, then, when
910
from .utils import CONFIG_STACK
1011

1112
if TYPE_CHECKING:
12-
from typing import Any, Generator
13-
1413
from _pytest.config import Config, PytestPluginManager
1514
from _pytest.config.argparsing import Parser
1615
from _pytest.fixtures import FixtureRequest
@@ -21,6 +20,10 @@
2120
from .parser import Feature, Scenario, Step
2221

2322

23+
P = ParamSpec("P")
24+
T = TypeVar("T")
25+
26+
2427
def pytest_addhooks(pluginmanager: PytestPluginManager) -> None:
2528
"""Register plugin hooks."""
2629
from pytest_bdd import hooks
@@ -93,7 +96,7 @@ def pytest_bdd_step_error(
9396
feature: Feature,
9497
scenario: Scenario,
9598
step: Step,
96-
step_func: Callable,
99+
step_func: Callable[..., Any],
97100
step_func_args: dict,
98101
exception: Exception,
99102
) -> None:
@@ -102,7 +105,11 @@ def pytest_bdd_step_error(
102105

103106
@pytest.hookimpl(tryfirst=True)
104107
def pytest_bdd_before_step(
105-
request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable
108+
request: FixtureRequest,
109+
feature: Feature,
110+
scenario: Scenario,
111+
step: Step,
112+
step_func: Callable[..., Any],
106113
) -> None:
107114
reporting.before_step(request, feature, scenario, step, step_func)
108115

@@ -113,7 +120,7 @@ def pytest_bdd_after_step(
113120
feature: Feature,
114121
scenario: Scenario,
115122
step: Step,
116-
step_func: Callable,
123+
step_func: Callable[..., Any],
117124
step_func_args: dict[str, Any],
118125
) -> None:
119126
reporting.after_step(request, feature, scenario, step, step_func, step_func_args)
@@ -123,7 +130,7 @@ def pytest_cmdline_main(config: Config) -> int | None:
123130
return generation.cmdline_main(config)
124131

125132

126-
def pytest_bdd_apply_tag(tag: str, function: Callable) -> Callable:
133+
def pytest_bdd_apply_tag(tag: str, function: Callable[P, T]) -> Callable[P, T]:
127134
mark = getattr(pytest.mark, tag)
128135
marked = mark(function)
129-
return cast(Callable, marked)
136+
return cast(Callable[P, T], marked)

src/pytest_bdd/reporting.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,15 +155,21 @@ def step_error(
155155
feature: Feature,
156156
scenario: Scenario,
157157
step: Step,
158-
step_func: Callable,
158+
step_func: Callable[..., Any],
159159
step_func_args: dict,
160160
exception: Exception,
161161
) -> None:
162162
"""Finalize the step report as failed."""
163163
request.node.__scenario_report__.fail()
164164

165165

166-
def before_step(request: FixtureRequest, feature: Feature, scenario: Scenario, step: Step, step_func: Callable) -> None:
166+
def before_step(
167+
request: FixtureRequest,
168+
feature: Feature,
169+
scenario: Scenario,
170+
step: Step,
171+
step_func: Callable[..., Any],
172+
) -> None:
167173
"""Store step start time."""
168174
request.node.__scenario_report__.add_step_report(StepReport(step=step))
169175

src/pytest_bdd/scenario.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,25 @@
1616
import logging
1717
import os
1818
import re
19-
from typing import TYPE_CHECKING, Callable, Iterator, cast
19+
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, cast
2020

2121
import pytest
2222
from _pytest.fixtures import FixtureDef, FixtureManager, FixtureRequest, call_fixture_func
2323
from _pytest.nodes import iterparentnodeids
24+
from typing_extensions import ParamSpec
2425

2526
from . import exceptions
2627
from .feature import get_feature, get_features
2728
from .steps import StepFunctionContext, get_step_fixture_name, inject_fixture
2829
from .utils import CONFIG_STACK, get_args, get_caller_module_locals, get_caller_module_path
2930

3031
if TYPE_CHECKING:
31-
from typing import Any, Iterable
32-
3332
from _pytest.mark.structures import ParameterSet
3433

3534
from .parser import Feature, Scenario, ScenarioTemplate, Step
3635

36+
P = ParamSpec("P")
37+
T = TypeVar("T")
3738

3839
logger = logging.getLogger(__name__)
3940

@@ -197,14 +198,14 @@ def _execute_scenario(feature: Feature, scenario: Scenario, request: FixtureRequ
197198

198199
def _get_scenario_decorator(
199200
feature: Feature, feature_name: str, templated_scenario: ScenarioTemplate, scenario_name: str
200-
) -> Callable[[Callable], Callable]:
201+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
201202
# HACK: Ideally we would use `def decorator(fn)`, but we want to return a custom exception
202203
# when the decorator is misused.
203204
# Pytest inspect the signature to determine the required fixtures, and in that case it would look
204205
# for a fixture called "fn" that doesn't exist (if it exists then it's even worse).
205206
# It will error with a "fixture 'fn' not found" message instead.
206207
# We can avoid this hack by using a pytest hook and check for misuse instead.
207-
def decorator(*args: Callable) -> Callable:
208+
def decorator(*args: Callable[P, T]) -> Callable[P, T]:
208209
if not args:
209210
raise exceptions.ScenarioIsDecoratorOnly(
210211
"scenario function can only be used as a decorator. Refer to the documentation."
@@ -236,7 +237,7 @@ def scenario_wrapper(request: FixtureRequest, _pytest_bdd_example: dict[str, str
236237

237238
scenario_wrapper.__doc__ = f"{feature_name}: {scenario_name}"
238239
scenario_wrapper.__scenario__ = templated_scenario
239-
return cast(Callable, scenario_wrapper)
240+
return cast(Callable[P, T], scenario_wrapper)
240241

241242
return decorator
242243

@@ -254,8 +255,11 @@ def collect_example_parametrizations(
254255

255256

256257
def scenario(
257-
feature_name: str, scenario_name: str, encoding: str = "utf-8", features_base_dir=None
258-
) -> Callable[[Callable], Callable]:
258+
feature_name: str,
259+
scenario_name: str,
260+
encoding: str = "utf-8",
261+
features_base_dir: str | None = None,
262+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
259263
"""Scenario decorator.
260264
261265
:param str feature_name: Feature file name. Absolute or relative to the configured feature base path.

src/pytest_bdd/steps.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,15 @@ def _(article):
4343

4444
import pytest
4545
from _pytest.fixtures import FixtureDef, FixtureRequest
46+
from typing_extensions import ParamSpec
4647

4748
from .parser import Step
4849
from .parsers import StepParser, get_parser
4950
from .types import GIVEN, THEN, WHEN
5051
from .utils import get_caller_module_locals
5152

52-
TCallable = TypeVar("TCallable", bound=Callable[..., Any])
53+
P = ParamSpec("P")
54+
T = TypeVar("T")
5355

5456

5557
@enum.unique
@@ -74,10 +76,10 @@ def get_step_fixture_name(step: Step) -> str:
7476

7577
def given(
7678
name: str | StepParser,
77-
converters: dict[str, Callable] | None = None,
79+
converters: dict[str, Callable[[Any], Any]] | None = None,
7880
target_fixture: str | None = None,
7981
stacklevel: int = 1,
80-
) -> Callable:
82+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
8183
"""Given step decorator.
8284
8385
:param name: Step name or a parser object.
@@ -93,10 +95,10 @@ def given(
9395

9496
def when(
9597
name: str | StepParser,
96-
converters: dict[str, Callable] | None = None,
98+
converters: dict[str, Callable[[Any], Any]] | None = None,
9799
target_fixture: str | None = None,
98100
stacklevel: int = 1,
99-
) -> Callable:
101+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
100102
"""When step decorator.
101103
102104
:param name: Step name or a parser object.
@@ -112,10 +114,10 @@ def when(
112114

113115
def then(
114116
name: str | StepParser,
115-
converters: dict[str, Callable] | None = None,
117+
converters: dict[str, Callable[[Any], Any]] | None = None,
116118
target_fixture: str | None = None,
117119
stacklevel: int = 1,
118-
) -> Callable:
120+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
119121
"""Then step decorator.
120122
121123
:param name: Step name or a parser object.
@@ -132,10 +134,10 @@ def then(
132134
def step(
133135
name: str | StepParser,
134136
type_: Literal["given", "when", "then"] | None = None,
135-
converters: dict[str, Callable] | None = None,
137+
converters: dict[str, Callable[[Any], Any]] | None = None,
136138
target_fixture: str | None = None,
137139
stacklevel: int = 1,
138-
) -> Callable[[TCallable], TCallable]:
140+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
139141
"""Generic step decorator.
140142
141143
:param name: Step name as in the feature file.
@@ -155,7 +157,7 @@ def step(
155157
if converters is None:
156158
converters = {}
157159

158-
def decorator(func: TCallable) -> TCallable:
160+
def decorator(func: Callable[P, T]) -> Callable[P, T]:
159161
parser = get_parser(name)
160162

161163
context = StepFunctionContext(

src/pytest_bdd/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
CONFIG_STACK: list[Config] = []
2020

2121

22-
def get_args(func: Callable) -> list[str]:
22+
def get_args(func: Callable[..., Any]) -> list[str]:
2323
"""Get a list of argument names for a function.
2424
2525
:param func: The function to inspect.

0 commit comments

Comments
 (0)