Skip to content

Commit b7d1d09

Browse files
authored
add TypeGuard to coroutines.iscoroutine (#6105)
make CoroutineType extend Coroutine
1 parent e018ad6 commit b7d1d09

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

stdlib/asyncio/coroutines.pyi

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1-
from typing import Any, Callable, TypeVar
1+
import sys
2+
import types
3+
from collections.abc import Callable, Coroutine
4+
from typing import Any, TypeVar
5+
from typing_extensions import TypeGuard
26

37
_F = TypeVar("_F", bound=Callable[..., Any])
48

59
def coroutine(func: _F) -> _F: ...
610
def iscoroutinefunction(func: object) -> bool: ...
7-
def iscoroutine(obj: object) -> bool: ...
11+
12+
if sys.version_info < (3, 8):
13+
def iscoroutine(obj: object) -> TypeGuard[types.GeneratorType[Any, Any, Any] | Coroutine[Any, Any, Any]]: ...
14+
15+
else:
16+
def iscoroutine(obj: object) -> TypeGuard[Coroutine[Any, Any, Any]]: ...

stdlib/inspect.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ else:
6262
def iscoroutinefunction(object: object) -> bool: ...
6363

6464
def isgenerator(object: object) -> TypeGuard[GeneratorType[Any, Any, Any]]: ...
65-
def iscoroutine(object: object) -> TypeGuard[CoroutineType]: ...
65+
def iscoroutine(object: object) -> TypeGuard[CoroutineType[Any, Any, Any]]: ...
6666
def isawaitable(object: object) -> TypeGuard[Awaitable[Any]]: ...
6767

6868
if sys.version_info >= (3, 8):

stdlib/types.pyi

+9-6
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ from typing import (
66
AsyncGenerator,
77
Awaitable,
88
Callable,
9+
Coroutine,
910
Generator,
1011
Generic,
1112
ItemsView,
@@ -211,20 +212,22 @@ class AsyncGeneratorType(AsyncGenerator[_T_co, _T_contra]):
211212
def aclose(self) -> Awaitable[None]: ...
212213

213214
@final
214-
class CoroutineType:
215+
class CoroutineType(Coroutine[_T_co, _T_contra, _V_co]):
215216
__name__: str
216217
__qualname__: str
217218
cr_await: Any | None
218219
cr_code: CodeType
219220
cr_frame: FrameType
220221
cr_running: bool
221222
def close(self) -> None: ...
222-
def __await__(self) -> Generator[Any, None, Any]: ...
223-
def send(self, __arg: Any) -> Any: ...
223+
def __await__(self) -> Generator[Any, None, _V_co]: ...
224+
def send(self, __arg: _T_contra) -> _T_co: ...
224225
@overload
225-
def throw(self, __typ: Type[BaseException], __val: BaseException | object = ..., __tb: TracebackType | None = ...) -> Any: ...
226+
def throw(
227+
self, __typ: Type[BaseException], __val: BaseException | object = ..., __tb: TracebackType | None = ...
228+
) -> _T_co: ...
226229
@overload
227-
def throw(self, __typ: BaseException, __val: None = ..., __tb: TracebackType | None = ...) -> Any: ...
230+
def throw(self, __typ: BaseException, __val: None = ..., __tb: TracebackType | None = ...) -> _T_co: ...
228231

229232
class _StaticFunctionType:
230233
"""Fictional type to correct the type of MethodType.__func__.
@@ -365,7 +368,7 @@ def prepare_class(
365368
# Actually a different type, but `property` is special and we want that too.
366369
DynamicClassAttribute = property
367370

368-
def coroutine(func: Callable[..., Any]) -> CoroutineType: ...
371+
def coroutine(func: Callable[..., Any]) -> CoroutineType[Any, Any, Any]: ...
369372

370373
if sys.version_info >= (3, 8):
371374
CellType = _Cell

0 commit comments

Comments
 (0)