Skip to content

Commit 9511daa

Browse files
sobolevnpre-commit-ci[bot]ilevkivskyi
authored
Support better __post_init__ method signature for dataclasses (#15503)
Now we use a similar approach to #14849 First, we generate a private name to store in a metadata (with `-`, so - no conflicts, ever). Next, we check override to be compatible: we take the currect signature and compare it to the ideal one we have. Simple and it works :) Closes #15498 Closes #9254 --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ivan Levkivskyi <[email protected]>
1 parent 9ad3f38 commit 9511daa

File tree

6 files changed

+318
-17
lines changed

6 files changed

+318
-17
lines changed

mypy/checker.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@
136136
from mypy.options import Options
137137
from mypy.patterns import AsPattern, StarredPattern
138138
from mypy.plugin import CheckerPluginInterface, Plugin
139+
from mypy.plugins import dataclasses as dataclasses_plugin
139140
from mypy.scope import Scope
140141
from mypy.semanal import is_trivial_body, refers_to_fullname, set_callable_name
141142
from mypy.semanal_enum import ENUM_BASES, ENUM_SPECIAL_PROPS
@@ -1044,6 +1045,9 @@ def check_func_item(
10441045

10451046
if name == "__exit__":
10461047
self.check__exit__return_type(defn)
1048+
if name == "__post_init__":
1049+
if dataclasses_plugin.is_processed_dataclass(defn.info):
1050+
dataclasses_plugin.check_post_init(self, defn, defn.info)
10471051

10481052
@contextmanager
10491053
def enter_attribute_inference_context(self) -> Iterator[None]:
@@ -1851,7 +1855,7 @@ def check_method_or_accessor_override_for_base(
18511855
found_base_method = True
18521856

18531857
# Check the type of override.
1854-
if name not in ("__init__", "__new__", "__init_subclass__"):
1858+
if name not in ("__init__", "__new__", "__init_subclass__", "__post_init__"):
18551859
# Check method override
18561860
# (__init__, __new__, __init_subclass__ are special).
18571861
if self.check_method_override_for_base_with_name(defn, name, base):
@@ -2812,6 +2816,9 @@ def check_assignment(
28122816
if name == "__match_args__" and inferred is not None:
28132817
typ = self.expr_checker.accept(rvalue)
28142818
self.check_match_args(inferred, typ, lvalue)
2819+
if name == "__post_init__":
2820+
if dataclasses_plugin.is_processed_dataclass(self.scope.active_class()):
2821+
self.fail(message_registry.DATACLASS_POST_INIT_MUST_BE_A_FUNCTION, rvalue)
28152822

28162823
# Defer PartialType's super type checking.
28172824
if (

mypy/message_registry.py

+1
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
277277
DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL: Final = (
278278
'"alias" argument to dataclass field must be a string literal'
279279
)
280+
DATACLASS_POST_INIT_MUST_BE_A_FUNCTION: Final = '"__post_init__" method must be an instance method'
280281

281282
# fastparse
282283
FAILED_TO_MERGE_OVERLOADS: Final = ErrorMessage(

mypy/messages.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -1253,18 +1253,21 @@ def argument_incompatible_with_supertype(
12531253
code=codes.OVERRIDE,
12541254
secondary_context=secondary_context,
12551255
)
1256-
self.note(
1257-
"This violates the Liskov substitution principle",
1258-
context,
1259-
code=codes.OVERRIDE,
1260-
secondary_context=secondary_context,
1261-
)
1262-
self.note(
1263-
"See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides",
1264-
context,
1265-
code=codes.OVERRIDE,
1266-
secondary_context=secondary_context,
1267-
)
1256+
if name != "__post_init__":
1257+
# `__post_init__` is special, it can be incompatible by design.
1258+
# So, this note is misleading.
1259+
self.note(
1260+
"This violates the Liskov substitution principle",
1261+
context,
1262+
code=codes.OVERRIDE,
1263+
secondary_context=secondary_context,
1264+
)
1265+
self.note(
1266+
"See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides",
1267+
context,
1268+
code=codes.OVERRIDE,
1269+
secondary_context=secondary_context,
1270+
)
12681271

12691272
if name == "__eq__" and type_name:
12701273
multiline_msg = self.comparison_method_example_msg(class_name=type_name)

mypy/plugins/dataclasses.py

+82-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Iterator, Optional
5+
from typing import TYPE_CHECKING, Iterator, Optional
66
from typing_extensions import Final
77

88
from mypy import errorcodes, message_registry
@@ -26,6 +26,7 @@
2626
DataclassTransformSpec,
2727
Expression,
2828
FuncDef,
29+
FuncItem,
2930
IfStmt,
3031
JsonDict,
3132
NameExpr,
@@ -55,6 +56,7 @@
5556
from mypy.types import (
5657
AnyType,
5758
CallableType,
59+
FunctionLike,
5860
Instance,
5961
LiteralType,
6062
NoneType,
@@ -69,19 +71,23 @@
6971
)
7072
from mypy.typevars import fill_typevars
7173

74+
if TYPE_CHECKING:
75+
from mypy.checker import TypeChecker
76+
7277
# The set of decorators that generate dataclasses.
7378
dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"}
7479

7580

7681
SELF_TVAR_NAME: Final = "_DT"
77-
_TRANSFORM_SPEC_FOR_DATACLASSES = DataclassTransformSpec(
82+
_TRANSFORM_SPEC_FOR_DATACLASSES: Final = DataclassTransformSpec(
7883
eq_default=True,
7984
order_default=False,
8085
kw_only_default=False,
8186
frozen_default=False,
8287
field_specifiers=("dataclasses.Field", "dataclasses.field"),
8388
)
84-
_INTERNAL_REPLACE_SYM_NAME = "__mypy-replace"
89+
_INTERNAL_REPLACE_SYM_NAME: Final = "__mypy-replace"
90+
_INTERNAL_POST_INIT_SYM_NAME: Final = "__mypy-__post_init__"
8591

8692

8793
class DataclassAttribute:
@@ -350,6 +356,8 @@ def transform(self) -> bool:
350356

351357
if self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
352358
self._add_internal_replace_method(attributes)
359+
if "__post_init__" in info.names:
360+
self._add_internal_post_init_method(attributes)
353361

354362
info.metadata["dataclass"] = {
355363
"attributes": [attr.serialize() for attr in attributes],
@@ -385,7 +393,47 @@ def _add_internal_replace_method(self, attributes: list[DataclassAttribute]) ->
385393
fallback=self._api.named_type("builtins.function"),
386394
)
387395

388-
self._cls.info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
396+
info.names[_INTERNAL_REPLACE_SYM_NAME] = SymbolTableNode(
397+
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
398+
)
399+
400+
def _add_internal_post_init_method(self, attributes: list[DataclassAttribute]) -> None:
401+
arg_types: list[Type] = [fill_typevars(self._cls.info)]
402+
arg_kinds = [ARG_POS]
403+
arg_names: list[str | None] = ["self"]
404+
405+
info = self._cls.info
406+
for attr in attributes:
407+
if not attr.is_init_var:
408+
continue
409+
attr_type = attr.expand_type(info)
410+
assert attr_type is not None
411+
arg_types.append(attr_type)
412+
# We always use `ARG_POS` without a default value, because it is practical.
413+
# Consider this case:
414+
#
415+
# @dataclass
416+
# class My:
417+
# y: dataclasses.InitVar[str] = 'a'
418+
# def __post_init__(self, y: str) -> None: ...
419+
#
420+
# We would be *required* to specify `y: str = ...` if default is added here.
421+
# But, most people won't care about adding default values to `__post_init__`,
422+
# because it is not designed to be called directly, and duplicating default values
423+
# for the sake of type-checking is unpleasant.
424+
arg_kinds.append(ARG_POS)
425+
arg_names.append(attr.name)
426+
427+
signature = CallableType(
428+
arg_types=arg_types,
429+
arg_kinds=arg_kinds,
430+
arg_names=arg_names,
431+
ret_type=NoneType(),
432+
fallback=self._api.named_type("builtins.function"),
433+
name="__post_init__",
434+
)
435+
436+
info.names[_INTERNAL_POST_INIT_SYM_NAME] = SymbolTableNode(
389437
kind=MDEF, node=FuncDef(typ=signature), plugin_generated=True
390438
)
391439

@@ -1052,3 +1100,33 @@ def replace_function_sig_callback(ctx: FunctionSigContext) -> CallableType:
10521100
fallback=ctx.default_signature.fallback,
10531101
name=f"{ctx.default_signature.name} of {inst_type_str}",
10541102
)
1103+
1104+
1105+
def is_processed_dataclass(info: TypeInfo | None) -> bool:
1106+
return info is not None and "dataclass" in info.metadata
1107+
1108+
1109+
def check_post_init(api: TypeChecker, defn: FuncItem, info: TypeInfo) -> None:
1110+
if defn.type is None:
1111+
return
1112+
1113+
ideal_sig = info.get_method(_INTERNAL_POST_INIT_SYM_NAME)
1114+
if ideal_sig is None or ideal_sig.type is None:
1115+
return
1116+
1117+
# We set it ourself, so it is always fine:
1118+
assert isinstance(ideal_sig.type, ProperType)
1119+
assert isinstance(ideal_sig.type, FunctionLike)
1120+
# Type of `FuncItem` is always `FunctionLike`:
1121+
assert isinstance(defn.type, FunctionLike)
1122+
1123+
api.check_override(
1124+
override=defn.type,
1125+
original=ideal_sig.type,
1126+
name="__post_init__",
1127+
name_in_super="__post_init__",
1128+
supertype="dataclass",
1129+
original_class_or_static=False,
1130+
override_class_or_static=False,
1131+
node=defn,
1132+
)

0 commit comments

Comments
 (0)