diff --git a/stdlib/asyncio/coroutines.pyi b/stdlib/asyncio/coroutines.pyi index 38f2bffd5c2b..14fb627ae6fe 100644 --- a/stdlib/asyncio/coroutines.pyi +++ b/stdlib/asyncio/coroutines.pyi @@ -1,21 +1,28 @@ import sys -from collections.abc import Coroutine -from typing import Any -from typing_extensions import TypeGuard +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any, TypeVar, overload +from typing_extensions import ParamSpec, TypeGuard if sys.version_info >= (3, 11): __all__ = ("iscoroutinefunction", "iscoroutine") else: __all__ = ("coroutine", "iscoroutinefunction", "iscoroutine") -if sys.version_info < (3, 11): - from collections.abc import Callable - from typing import TypeVar +_T = TypeVar("_T") +_FunctionT = TypeVar("_FunctionT", bound=Callable[..., Any]) +_P = ParamSpec("_P") - _F = TypeVar("_F", bound=Callable[..., Any]) - def coroutine(func: _F) -> _F: ... +if sys.version_info < (3, 11): + def coroutine(func: _FunctionT) -> _FunctionT: ... -def iscoroutinefunction(func: object) -> bool: ... +@overload +def iscoroutinefunction(func: Callable[..., Coroutine[Any, Any, Any]]) -> bool: ... +@overload +def iscoroutinefunction(func: Callable[_P, Awaitable[_T]]) -> TypeGuard[Callable[_P, Coroutine[Any, Any, _T]]]: ... +@overload +def iscoroutinefunction(func: Callable[_P, object]) -> TypeGuard[Callable[_P, Coroutine[Any, Any, Any]]]: ... +@overload +def iscoroutinefunction(func: object) -> TypeGuard[Callable[..., Coroutine[Any, Any, Any]]]: ... # Can actually be a generator-style coroutine on Python 3.7 def iscoroutine(obj: object) -> TypeGuard[Coroutine[Any, Any, Any]]: ... diff --git a/stdlib/inspect.pyi b/stdlib/inspect.pyi index 74080af9389b..7f9667c6ebed 100644 --- a/stdlib/inspect.pyi +++ b/stdlib/inspect.pyi @@ -4,7 +4,7 @@ import sys import types from _typeshed import Self from collections import OrderedDict -from collections.abc import Awaitable, Callable, Coroutine, Generator, Mapping, Sequence, Set as AbstractSet +from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator, Mapping, Sequence, Set as AbstractSet from types import ( AsyncGeneratorType, BuiltinFunctionType, @@ -25,7 +25,7 @@ from types import ( TracebackType, WrapperDescriptorType, ) -from typing import Any, ClassVar, NamedTuple, Protocol, TypeVar, Union +from typing import Any, ClassVar, NamedTuple, Protocol, TypeVar, Union, overload from typing_extensions import Literal, ParamSpec, TypeAlias, TypeGuard if sys.version_info >= (3, 11): @@ -129,6 +129,7 @@ if sys.version_info >= (3, 11): ] _P = ParamSpec("_P") +_T = TypeVar("_T") _T_cont = TypeVar("_T_cont", contravariant=True) _V_cont = TypeVar("_V_cont", contravariant=True) @@ -176,22 +177,56 @@ def ismethod(object: object) -> TypeGuard[MethodType]: ... def isfunction(object: object) -> TypeGuard[FunctionType]: ... if sys.version_info >= (3, 8): - def isgeneratorfunction(obj: object) -> bool: ... - def iscoroutinefunction(obj: object) -> bool: ... + @overload + def isgeneratorfunction(obj: Callable[..., Generator[Any, Any, Any]]) -> bool: ... + @overload + def isgeneratorfunction(obj: Callable[_P, Any]) -> TypeGuard[Callable[_P, GeneratorType[Any, Any, Any]]]: ... + @overload + def isgeneratorfunction(obj: object) -> TypeGuard[Callable[..., GeneratorType[Any, Any, Any]]]: ... + @overload + def iscoroutinefunction(obj: Callable[..., Coroutine[Any, Any, Any]]) -> bool: ... + @overload + def iscoroutinefunction(obj: Callable[_P, Awaitable[_T]]) -> TypeGuard[Callable[_P, CoroutineType[Any, Any, _T]]]: ... + @overload + def iscoroutinefunction(obj: Callable[_P, object]) -> TypeGuard[Callable[_P, CoroutineType[Any, Any, Any]]]: ... + @overload + def iscoroutinefunction(obj: object) -> TypeGuard[Callable[..., CoroutineType[Any, Any, Any]]]: ... else: - def isgeneratorfunction(object: object) -> bool: ... - def iscoroutinefunction(object: object) -> bool: ... + @overload + def isgeneratorfunction(object: Callable[..., Generator[Any, Any, Any]]) -> bool: ... + @overload + def isgeneratorfunction(object: Callable[_P, Any]) -> TypeGuard[Callable[_P, GeneratorType[Any, Any, Any]]]: ... + @overload + def isgeneratorfunction(object: object) -> TypeGuard[Callable[..., GeneratorType[Any, Any, Any]]]: ... + @overload + def iscoroutinefunction(object: Callable[..., Coroutine[Any, Any, Any]]) -> bool: ... + @overload + def iscoroutinefunction(object: Callable[_P, Awaitable[_T]]) -> TypeGuard[Callable[_P, CoroutineType[Any, Any, _T]]]: ... + @overload + def iscoroutinefunction(object: Callable[_P, Any]) -> TypeGuard[Callable[_P, CoroutineType[Any, Any, Any]]]: ... + @overload + def iscoroutinefunction(object: object) -> TypeGuard[Callable[..., CoroutineType[Any, Any, Any]]]: ... def isgenerator(object: object) -> TypeGuard[GeneratorType[Any, Any, Any]]: ... def iscoroutine(object: object) -> TypeGuard[CoroutineType[Any, Any, Any]]: ... def isawaitable(object: object) -> TypeGuard[Awaitable[Any]]: ... if sys.version_info >= (3, 8): - def isasyncgenfunction(obj: object) -> bool: ... + @overload + def isasyncgenfunction(obj: Callable[..., AsyncGenerator[Any, Any]]) -> bool: ... + @overload + def isasyncgenfunction(obj: Callable[_P, Any]) -> TypeGuard[Callable[_P, AsyncGeneratorType[Any, Any]]]: ... + @overload + def isasyncgenfunction(obj: object) -> TypeGuard[Callable[..., AsyncGeneratorType[Any, Any]]]: ... else: - def isasyncgenfunction(object: object) -> bool: ... + @overload + def isasyncgenfunction(object: Callable[..., AsyncGenerator[Any, Any]]) -> bool: ... + @overload + def isasyncgenfunction(object: Callable[_P, Any]) -> TypeGuard[Callable[_P, AsyncGeneratorType[Any, Any]]]: ... + @overload + def isasyncgenfunction(object: object) -> TypeGuard[Callable[..., AsyncGeneratorType[Any, Any]]]: ... class _SupportsSet(Protocol[_T_cont, _V_cont]): def __set__(self, __instance: _T_cont, __value: _V_cont) -> None: ... diff --git a/test_cases/stdlib/asyncio/test_coroutines.py b/test_cases/stdlib/asyncio/test_coroutines.py new file mode 100644 index 000000000000..e2c3d19e6fb6 --- /dev/null +++ b/test_cases/stdlib/asyncio/test_coroutines.py @@ -0,0 +1,24 @@ +from asyncio import iscoroutinefunction +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any, Union +from typing_extensions import assert_type + + +def test_iscoroutinefunction( + x: Callable[[str, int], Coroutine[str, int, bytes]], + y: Callable[[str, int], Awaitable[bytes]], + z: Callable[[str, int], Union[str, Awaitable[bytes]]], + xx: object, +) -> None: + + if iscoroutinefunction(x): + assert_type(x, Callable[[str, int], Coroutine[str, int, bytes]]) + + if iscoroutinefunction(y): + assert_type(y, Callable[[str, int], Coroutine[Any, Any, bytes]]) + + if iscoroutinefunction(z): + assert_type(z, Callable[[str, int], Coroutine[Any, Any, Any]]) + + if iscoroutinefunction(xx): + assert_type(xx, Callable[..., Coroutine[Any, Any, Any]])