Skip to content

Commit e156c63

Browse files
authored
inspect, asyncio: Use more TypeGuards (#8057)
1 parent 4b504c7 commit e156c63

File tree

3 files changed

+83
-17
lines changed

3 files changed

+83
-17
lines changed

stdlib/asyncio/coroutines.pyi

+16-9
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,28 @@
11
import sys
2-
from collections.abc import Coroutine
3-
from typing import Any
4-
from typing_extensions import TypeGuard
2+
from collections.abc import Awaitable, Callable, Coroutine
3+
from typing import Any, TypeVar, overload
4+
from typing_extensions import ParamSpec, TypeGuard
55

66
if sys.version_info >= (3, 11):
77
__all__ = ("iscoroutinefunction", "iscoroutine")
88
else:
99
__all__ = ("coroutine", "iscoroutinefunction", "iscoroutine")
1010

11-
if sys.version_info < (3, 11):
12-
from collections.abc import Callable
13-
from typing import TypeVar
11+
_T = TypeVar("_T")
12+
_FunctionT = TypeVar("_FunctionT", bound=Callable[..., Any])
13+
_P = ParamSpec("_P")
1414

15-
_F = TypeVar("_F", bound=Callable[..., Any])
16-
def coroutine(func: _F) -> _F: ...
15+
if sys.version_info < (3, 11):
16+
def coroutine(func: _FunctionT) -> _FunctionT: ...
1717

18-
def iscoroutinefunction(func: object) -> bool: ...
18+
@overload
19+
def iscoroutinefunction(func: Callable[..., Coroutine[Any, Any, Any]]) -> bool: ...
20+
@overload
21+
def iscoroutinefunction(func: Callable[_P, Awaitable[_T]]) -> TypeGuard[Callable[_P, Coroutine[Any, Any, _T]]]: ...
22+
@overload
23+
def iscoroutinefunction(func: Callable[_P, object]) -> TypeGuard[Callable[_P, Coroutine[Any, Any, Any]]]: ...
24+
@overload
25+
def iscoroutinefunction(func: object) -> TypeGuard[Callable[..., Coroutine[Any, Any, Any]]]: ...
1926

2027
# Can actually be a generator-style coroutine on Python 3.7
2128
def iscoroutine(obj: object) -> TypeGuard[Coroutine[Any, Any, Any]]: ...

stdlib/inspect.pyi

+43-8
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import sys
44
import types
55
from _typeshed import Self
66
from collections import OrderedDict
7-
from collections.abc import Awaitable, Callable, Coroutine, Generator, Mapping, Sequence, Set as AbstractSet
7+
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator, Mapping, Sequence, Set as AbstractSet
88
from types import (
99
AsyncGeneratorType,
1010
BuiltinFunctionType,
@@ -25,7 +25,7 @@ from types import (
2525
TracebackType,
2626
WrapperDescriptorType,
2727
)
28-
from typing import Any, ClassVar, NamedTuple, Protocol, TypeVar, Union
28+
from typing import Any, ClassVar, NamedTuple, Protocol, TypeVar, Union, overload
2929
from typing_extensions import Literal, ParamSpec, TypeAlias, TypeGuard
3030

3131
if sys.version_info >= (3, 11):
@@ -129,6 +129,7 @@ if sys.version_info >= (3, 11):
129129
]
130130

131131
_P = ParamSpec("_P")
132+
_T = TypeVar("_T")
132133
_T_cont = TypeVar("_T_cont", contravariant=True)
133134
_V_cont = TypeVar("_V_cont", contravariant=True)
134135

@@ -176,22 +177,56 @@ def ismethod(object: object) -> TypeGuard[MethodType]: ...
176177
def isfunction(object: object) -> TypeGuard[FunctionType]: ...
177178

178179
if sys.version_info >= (3, 8):
179-
def isgeneratorfunction(obj: object) -> bool: ...
180-
def iscoroutinefunction(obj: object) -> bool: ...
180+
@overload
181+
def isgeneratorfunction(obj: Callable[..., Generator[Any, Any, Any]]) -> bool: ...
182+
@overload
183+
def isgeneratorfunction(obj: Callable[_P, Any]) -> TypeGuard[Callable[_P, GeneratorType[Any, Any, Any]]]: ...
184+
@overload
185+
def isgeneratorfunction(obj: object) -> TypeGuard[Callable[..., GeneratorType[Any, Any, Any]]]: ...
186+
@overload
187+
def iscoroutinefunction(obj: Callable[..., Coroutine[Any, Any, Any]]) -> bool: ...
188+
@overload
189+
def iscoroutinefunction(obj: Callable[_P, Awaitable[_T]]) -> TypeGuard[Callable[_P, CoroutineType[Any, Any, _T]]]: ...
190+
@overload
191+
def iscoroutinefunction(obj: Callable[_P, object]) -> TypeGuard[Callable[_P, CoroutineType[Any, Any, Any]]]: ...
192+
@overload
193+
def iscoroutinefunction(obj: object) -> TypeGuard[Callable[..., CoroutineType[Any, Any, Any]]]: ...
181194

182195
else:
183-
def isgeneratorfunction(object: object) -> bool: ...
184-
def iscoroutinefunction(object: object) -> bool: ...
196+
@overload
197+
def isgeneratorfunction(object: Callable[..., Generator[Any, Any, Any]]) -> bool: ...
198+
@overload
199+
def isgeneratorfunction(object: Callable[_P, Any]) -> TypeGuard[Callable[_P, GeneratorType[Any, Any, Any]]]: ...
200+
@overload
201+
def isgeneratorfunction(object: object) -> TypeGuard[Callable[..., GeneratorType[Any, Any, Any]]]: ...
202+
@overload
203+
def iscoroutinefunction(object: Callable[..., Coroutine[Any, Any, Any]]) -> bool: ...
204+
@overload
205+
def iscoroutinefunction(object: Callable[_P, Awaitable[_T]]) -> TypeGuard[Callable[_P, CoroutineType[Any, Any, _T]]]: ...
206+
@overload
207+
def iscoroutinefunction(object: Callable[_P, Any]) -> TypeGuard[Callable[_P, CoroutineType[Any, Any, Any]]]: ...
208+
@overload
209+
def iscoroutinefunction(object: object) -> TypeGuard[Callable[..., CoroutineType[Any, Any, Any]]]: ...
185210

186211
def isgenerator(object: object) -> TypeGuard[GeneratorType[Any, Any, Any]]: ...
187212
def iscoroutine(object: object) -> TypeGuard[CoroutineType[Any, Any, Any]]: ...
188213
def isawaitable(object: object) -> TypeGuard[Awaitable[Any]]: ...
189214

190215
if sys.version_info >= (3, 8):
191-
def isasyncgenfunction(obj: object) -> bool: ...
216+
@overload
217+
def isasyncgenfunction(obj: Callable[..., AsyncGenerator[Any, Any]]) -> bool: ...
218+
@overload
219+
def isasyncgenfunction(obj: Callable[_P, Any]) -> TypeGuard[Callable[_P, AsyncGeneratorType[Any, Any]]]: ...
220+
@overload
221+
def isasyncgenfunction(obj: object) -> TypeGuard[Callable[..., AsyncGeneratorType[Any, Any]]]: ...
192222

193223
else:
194-
def isasyncgenfunction(object: object) -> bool: ...
224+
@overload
225+
def isasyncgenfunction(object: Callable[..., AsyncGenerator[Any, Any]]) -> bool: ...
226+
@overload
227+
def isasyncgenfunction(object: Callable[_P, Any]) -> TypeGuard[Callable[_P, AsyncGeneratorType[Any, Any]]]: ...
228+
@overload
229+
def isasyncgenfunction(object: object) -> TypeGuard[Callable[..., AsyncGeneratorType[Any, Any]]]: ...
195230

196231
class _SupportsSet(Protocol[_T_cont, _V_cont]):
197232
def __set__(self, __instance: _T_cont, __value: _V_cont) -> None: ...
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from asyncio import iscoroutinefunction
2+
from collections.abc import Awaitable, Callable, Coroutine
3+
from typing import Any, Union
4+
from typing_extensions import assert_type
5+
6+
7+
def test_iscoroutinefunction(
8+
x: Callable[[str, int], Coroutine[str, int, bytes]],
9+
y: Callable[[str, int], Awaitable[bytes]],
10+
z: Callable[[str, int], Union[str, Awaitable[bytes]]],
11+
xx: object,
12+
) -> None:
13+
14+
if iscoroutinefunction(x):
15+
assert_type(x, Callable[[str, int], Coroutine[str, int, bytes]])
16+
17+
if iscoroutinefunction(y):
18+
assert_type(y, Callable[[str, int], Coroutine[Any, Any, bytes]])
19+
20+
if iscoroutinefunction(z):
21+
assert_type(z, Callable[[str, int], Coroutine[Any, Any, Any]])
22+
23+
if iscoroutinefunction(xx):
24+
assert_type(xx, Callable[..., Coroutine[Any, Any, Any]])

0 commit comments

Comments
 (0)