Skip to content

Commit 5313497

Browse files
authored
stubgen: add support for PEPs 695 and 696 syntax (#18054)
closes #17997
1 parent 7f09f0c commit 5313497

File tree

4 files changed

+139
-5
lines changed

4 files changed

+139
-5
lines changed

mypy/stubdoc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class FunctionSig(NamedTuple):
7676
name: str
7777
args: list[ArgSig]
7878
ret_type: str | None
79+
type_args: str = "" # TODO implement in stubgenc and remove the default
7980

8081
def is_special_method(self) -> bool:
8182
return bool(
@@ -141,9 +142,7 @@ def format_sig(
141142
retfield = " -> " + ret_type
142143

143144
prefix = "async " if is_async else ""
144-
sig = "{indent}{prefix}def {name}({args}){ret}:".format(
145-
indent=indent, prefix=prefix, name=self.name, args=", ".join(args), ret=retfield
146-
)
145+
sig = f"{indent}{prefix}def {self.name}{self.type_args}({', '.join(args)}){retfield}:"
147146
if docstring:
148147
suffix = f"\n{indent} {mypy.util.quote_docstring(docstring)}"
149148
else:

mypy/stubgen.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
StrExpr,
107107
TempNode,
108108
TupleExpr,
109+
TypeAliasStmt,
109110
TypeInfo,
110111
UnaryExpr,
111112
Var,
@@ -398,6 +399,9 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
398399
for name in get_assigned_names(o.lvalues):
399400
self.names.add(name)
400401

402+
def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
403+
self.names.add(o.name.name)
404+
401405

402406
def find_referenced_names(file: MypyFile) -> set[str]:
403407
finder = ReferenceFinder()
@@ -507,7 +511,8 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
507511
def get_default_function_sig(self, func_def: FuncDef, ctx: FunctionContext) -> FunctionSig:
508512
args = self._get_func_args(func_def, ctx)
509513
retname = self._get_func_return(func_def, ctx)
510-
return FunctionSig(func_def.name, args, retname)
514+
type_args = self.format_type_args(func_def)
515+
return FunctionSig(func_def.name, args, retname, type_args)
511516

512517
def _get_func_args(self, o: FuncDef, ctx: FunctionContext) -> list[ArgSig]:
513518
args: list[ArgSig] = []
@@ -765,7 +770,8 @@ def visit_class_def(self, o: ClassDef) -> None:
765770
self.import_tracker.add_import("abc")
766771
self.import_tracker.require_name("abc")
767772
bases = f"({', '.join(base_types)})" if base_types else ""
768-
self.add(f"{self._indent}class {o.name}{bases}:\n")
773+
type_args = self.format_type_args(o)
774+
self.add(f"{self._indent}class {o.name}{type_args}{bases}:\n")
769775
self.indent()
770776
if self._include_docstrings and o.docstring:
771777
docstring = mypy.util.quote_docstring(o.docstring)
@@ -1101,6 +1107,16 @@ def process_typealias(self, lvalue: NameExpr, rvalue: Expression) -> None:
11011107
self.record_name(lvalue.name)
11021108
self._vars[-1].append(lvalue.name)
11031109

1110+
def visit_type_alias_stmt(self, o: TypeAliasStmt) -> None:
1111+
"""Type aliases defined with the `type` keyword (PEP 695)."""
1112+
p = AliasPrinter(self)
1113+
name = o.name.name
1114+
rvalue = o.value.expr()
1115+
type_args = self.format_type_args(o)
1116+
self.add(f"{self._indent}type {name}{type_args} = {rvalue.accept(p)}\n")
1117+
self.record_name(name)
1118+
self._vars[-1].append(name)
1119+
11041120
def visit_if_stmt(self, o: IfStmt) -> None:
11051121
# Ignore if __name__ == '__main__'.
11061122
expr = o.expr[0]

mypy/stubutil.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import mypy.options
1717
from mypy.modulefinder import ModuleNotFoundReason
1818
from mypy.moduleinspect import InspectError, ModuleInspect
19+
from mypy.nodes import PARAM_SPEC_KIND, TYPE_VAR_TUPLE_KIND, ClassDef, FuncDef, TypeAliasStmt
1920
from mypy.stubdoc import ArgSig, FunctionSig
2021
from mypy.types import AnyType, NoneType, Type, TypeList, TypeStrVisitor, UnboundType, UnionType
2122

@@ -777,6 +778,31 @@ def format_func_def(
777778
)
778779
return lines
779780

781+
def format_type_args(self, o: TypeAliasStmt | FuncDef | ClassDef) -> str:
782+
if not o.type_args:
783+
return ""
784+
p = AnnotationPrinter(self)
785+
type_args_list: list[str] = []
786+
for type_arg in o.type_args:
787+
if type_arg.kind == PARAM_SPEC_KIND:
788+
prefix = "**"
789+
elif type_arg.kind == TYPE_VAR_TUPLE_KIND:
790+
prefix = "*"
791+
else:
792+
prefix = ""
793+
if type_arg.upper_bound:
794+
bound_or_values = f": {type_arg.upper_bound.accept(p)}"
795+
elif type_arg.values:
796+
bound_or_values = f": ({', '.join(v.accept(p) for v in type_arg.values)})"
797+
else:
798+
bound_or_values = ""
799+
if type_arg.default:
800+
default = f" = {type_arg.default.accept(p)}"
801+
else:
802+
default = ""
803+
type_args_list.append(f"{prefix}{type_arg.name}{bound_or_values}{default}")
804+
return "[" + ", ".join(type_args_list) + "]"
805+
780806
def print_annotation(
781807
self,
782808
t: Type,

test-data/unit/stubgen.test

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4415,3 +4415,96 @@ class Test(Whatever, a=1, b='b', c=True, d=1.5, e=None, f=1j, g=b'123'): ...
44154415
class Test(Whatever, keyword=SomeName * 2, attr=SomeName.attr): ...
44164416
[out]
44174417
class Test(Whatever, keyword=SomeName * 2, attr=SomeName.attr): ...
4418+
4419+
[case testPEP695GenericClass]
4420+
# flags: --python-version=3.12
4421+
4422+
class C[T]: ...
4423+
class C1[T1](int): ...
4424+
class C2[T2: int]: ...
4425+
class C3[T3: str | bytes]: ...
4426+
class C4[T4: (str, bytes)]: ...
4427+
4428+
class Outer:
4429+
class Inner[T]: ...
4430+
4431+
[out]
4432+
class C[T]: ...
4433+
class C1[T1](int): ...
4434+
class C2[T2: int]: ...
4435+
class C3[T3: str | bytes]: ...
4436+
class C4[T4: (str, bytes)]: ...
4437+
4438+
class Outer:
4439+
class Inner[T]: ...
4440+
4441+
[case testPEP695GenericFunction]
4442+
# flags: --python-version=3.12
4443+
4444+
def f1[T1](): ...
4445+
def f2[T2: int](): ...
4446+
def f3[T3: str | bytes](): ...
4447+
def f4[T4: (str, bytes)](): ...
4448+
4449+
class C:
4450+
def m[T](self, x: T) -> T: ...
4451+
4452+
[out]
4453+
def f1[T1]() -> None: ...
4454+
def f2[T2: int]() -> None: ...
4455+
def f3[T3: str | bytes]() -> None: ...
4456+
def f4[T4: (str, bytes)]() -> None: ...
4457+
4458+
class C:
4459+
def m[T](self, x: T) -> T: ...
4460+
4461+
[case testPEP695TypeAlias]
4462+
# flags: --python-version=3.12
4463+
4464+
type Alias = int | str
4465+
type Alias1[T1] = list[T1] | set[T1]
4466+
type Alias2[T2: int] = list[T2] | set[T2]
4467+
type Alias3[T3: str | bytes] = list[T3] | set[T3]
4468+
type Alias4[T4: (str, bytes)] = list[T4] | set[T4]
4469+
4470+
class C:
4471+
type IndentedAlias[T] = list[T]
4472+
4473+
[out]
4474+
type Alias = int | str
4475+
type Alias1[T1] = list[T1] | set[T1]
4476+
type Alias2[T2: int] = list[T2] | set[T2]
4477+
type Alias3[T3: str | bytes] = list[T3] | set[T3]
4478+
type Alias4[T4: (str, bytes)] = list[T4] | set[T4]
4479+
class C:
4480+
type IndentedAlias[T] = list[T]
4481+
4482+
[case testPEP695Syntax_semanal]
4483+
# flags: --python-version=3.12
4484+
4485+
class C[T]: ...
4486+
def f[S](): ...
4487+
type A[R] = list[R]
4488+
4489+
[out]
4490+
class C[T]: ...
4491+
4492+
def f[S]() -> None: ...
4493+
type A[R] = list[R]
4494+
4495+
[case testPEP696Syntax]
4496+
# flags: --python-version=3.13
4497+
4498+
type Alias1[T1 = int] = list[T1] | set[T1]
4499+
type Alias2[T2: int | float = int] = list[T2] | set[T2]
4500+
class C3[T3 = int]: ...
4501+
class C4[T4: int | float = int](list[T4]): ...
4502+
def f5[T5 = int](): ...
4503+
4504+
[out]
4505+
type Alias1[T1 = int] = list[T1] | set[T1]
4506+
type Alias2[T2: int | float = int] = list[T2] | set[T2]
4507+
class C3[T3 = int]: ...
4508+
class C4[T4: int | float = int](list[T4]): ...
4509+
4510+
def f5[T5 = int]() -> None: ...

0 commit comments

Comments
 (0)