Skip to content

Commit bbc9cce

Browse files
authored
Add signature for attr.evolve (#14526)
Validate `attr.evolve` calls to specify correct arguments and types. The implementation makes it so that at every point where `attr.evolve` is called, the signature is modified to expect the attrs class' initializer's arguments (but so that they're all kw-only and optional). Notes: - Added `class dict: pass` to some fixtures files since our attrs type stubs now have **kwargs and that triggers a `builtin.dict` lookup in dozens of attrs tests. - Looking up the type of the 1st argument with `ctx.api.expr_checker.accept(inst_arg)` which is a hack since it's not part of the plugin API. This is a compromise for due to #10216. Fixes #14525.
1 parent 2ab1d82 commit bbc9cce

File tree

5 files changed

+159
-1
lines changed

5 files changed

+159
-1
lines changed

mypy/plugins/attrs.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
from typing_extensions import Final, Literal
77

88
import mypy.plugin # To avoid circular imports.
9+
from mypy.checker import TypeChecker
910
from mypy.errorcodes import LITERAL_REQ
1011
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
12+
from mypy.messages import format_type_bare
1113
from mypy.nodes import (
1214
ARG_NAMED,
1315
ARG_NAMED_OPT,
@@ -77,6 +79,7 @@
7779
SELF_TVAR_NAME: Final = "_AT"
7880
MAGIC_ATTR_NAME: Final = "__attrs_attrs__"
7981
MAGIC_ATTR_CLS_NAME_TEMPLATE: Final = "__{}_AttrsAttributes__" # The tuple subclass pattern.
82+
ATTRS_INIT_NAME: Final = "__attrs_init__"
8083

8184

8285
class Converter:
@@ -330,7 +333,7 @@ def attr_class_maker_callback(
330333

331334
adder = MethodAdder(ctx)
332335
# If __init__ is not being generated, attrs still generates it as __attrs_init__ instead.
333-
_add_init(ctx, attributes, adder, "__init__" if init else "__attrs_init__")
336+
_add_init(ctx, attributes, adder, "__init__" if init else ATTRS_INIT_NAME)
334337
if order:
335338
_add_order(ctx, adder)
336339
if frozen:
@@ -888,3 +891,64 @@ def add_method(
888891
"""
889892
self_type = self_type if self_type is not None else self.self_type
890893
add_method(self.ctx, method_name, args, ret_type, self_type, tvd)
894+
895+
896+
def _get_attrs_init_type(typ: Type) -> CallableType | None:
897+
"""
898+
If `typ` refers to an attrs class, gets the type of its initializer method.
899+
"""
900+
typ = get_proper_type(typ)
901+
if not isinstance(typ, Instance):
902+
return None
903+
magic_attr = typ.type.get(MAGIC_ATTR_NAME)
904+
if magic_attr is None or not magic_attr.plugin_generated:
905+
return None
906+
init_method = typ.type.get_method("__init__") or typ.type.get_method(ATTRS_INIT_NAME)
907+
if not isinstance(init_method, FuncDef) or not isinstance(init_method.type, CallableType):
908+
return None
909+
return init_method.type
910+
911+
912+
def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
913+
"""
914+
Generates a signature for the 'attr.evolve' function that's specific to the call site
915+
and dependent on the type of the first argument.
916+
"""
917+
if len(ctx.args) != 2:
918+
# Ideally the name and context should be callee's, but we don't have it in FunctionSigContext.
919+
ctx.api.fail(f'"{ctx.default_signature.name}" has unexpected type annotation', ctx.context)
920+
return ctx.default_signature
921+
922+
if len(ctx.args[0]) != 1:
923+
return ctx.default_signature # leave it to the type checker to complain
924+
925+
inst_arg = ctx.args[0][0]
926+
927+
# <hack>
928+
assert isinstance(ctx.api, TypeChecker)
929+
inst_type = ctx.api.expr_checker.accept(inst_arg)
930+
# </hack>
931+
932+
inst_type = get_proper_type(inst_type)
933+
if isinstance(inst_type, AnyType):
934+
return ctx.default_signature
935+
inst_type_str = format_type_bare(inst_type)
936+
937+
attrs_init_type = _get_attrs_init_type(inst_type)
938+
if not attrs_init_type:
939+
ctx.api.fail(
940+
f'Argument 1 to "evolve" has incompatible type "{inst_type_str}"; expected an attrs class',
941+
ctx.context,
942+
)
943+
return ctx.default_signature
944+
945+
# AttrClass.__init__ has the following signature (or similar, if having kw-only & defaults):
946+
# def __init__(self, attr1: Type1, attr2: Type2) -> None:
947+
# We want to generate a signature for evolve that looks like this:
948+
# def evolve(inst: AttrClass, *, attr1: Type1 = ..., attr2: Type2 = ...) -> AttrClass:
949+
return attrs_init_type.copy_modified(
950+
arg_names=["inst"] + attrs_init_type.arg_names[1:],
951+
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT for _ in attrs_init_type.arg_kinds[1:]],
952+
ret_type=inst_type,
953+
name=f"{ctx.default_signature.name} of {inst_type_str}",
954+
)

mypy/plugins/default.py

+10
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
AttributeContext,
1111
ClassDefContext,
1212
FunctionContext,
13+
FunctionSigContext,
1314
MethodContext,
1415
MethodSigContext,
1516
Plugin,
@@ -46,6 +47,15 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type]
4647
return singledispatch.create_singledispatch_function_callback
4748
return None
4849

50+
def get_function_signature_hook(
51+
self, fullname: str
52+
) -> Callable[[FunctionSigContext], FunctionLike] | None:
53+
from mypy.plugins import attrs
54+
55+
if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
56+
return attrs.evolve_function_sig_callback
57+
return None
58+
4959
def get_method_signature_hook(
5060
self, fullname: str
5161
) -> Callable[[MethodSigContext], FunctionLike] | None:

test-data/unit/check-attr.test

+78
Original file line numberDiff line numberDiff line change
@@ -1867,3 +1867,81 @@ D(1, "").a = 2 # E: Cannot assign to final attribute "a"
18671867
D(1, "").b = "2" # E: Cannot assign to final attribute "b"
18681868

18691869
[builtins fixtures/property.pyi]
1870+
1871+
[case testEvolve]
1872+
import attr
1873+
1874+
class Base:
1875+
pass
1876+
1877+
class Derived(Base):
1878+
pass
1879+
1880+
class Other:
1881+
pass
1882+
1883+
@attr.s(auto_attribs=True)
1884+
class C:
1885+
name: str
1886+
b: Base
1887+
1888+
c = C(name='foo', b=Derived())
1889+
c = attr.evolve(c)
1890+
c = attr.evolve(c, name='foo')
1891+
c = attr.evolve(c, 'foo') # E: Too many positional arguments for "evolve" of "C"
1892+
c = attr.evolve(c, b=Derived())
1893+
c = attr.evolve(c, b=Base())
1894+
c = attr.evolve(c, b=Other()) # E: Argument "b" to "evolve" of "C" has incompatible type "Other"; expected "Base"
1895+
c = attr.evolve(c, name=42) # E: Argument "name" to "evolve" of "C" has incompatible type "int"; expected "str"
1896+
c = attr.evolve(c, foobar=42) # E: Unexpected keyword argument "foobar" for "evolve" of "C"
1897+
1898+
# test passing instance as 'inst' kw
1899+
c = attr.evolve(inst=c, name='foo')
1900+
c = attr.evolve(not_inst=c, name='foo') # E: Missing positional argument "inst" in call to "evolve"
1901+
1902+
# test determining type of first argument's expression from something that's not NameExpr
1903+
def f() -> C:
1904+
return c
1905+
1906+
c = attr.evolve(f(), name='foo')
1907+
1908+
[builtins fixtures/attr.pyi]
1909+
1910+
[case testEvolveFromNonAttrs]
1911+
import attr
1912+
1913+
attr.evolve(42, name='foo') # E: Argument 1 to "evolve" has incompatible type "int"; expected an attrs class
1914+
attr.evolve(None, name='foo') # E: Argument 1 to "evolve" has incompatible type "None"; expected an attrs class
1915+
[case testEvolveFromAny]
1916+
from typing import Any
1917+
import attr
1918+
1919+
any: Any = 42
1920+
ret = attr.evolve(any, name='foo')
1921+
reveal_type(ret) # N: Revealed type is "Any"
1922+
1923+
[typing fixtures/typing-medium.pyi]
1924+
1925+
[case testEvolveVariants]
1926+
from typing import Any
1927+
import attr
1928+
import attrs
1929+
1930+
1931+
@attr.s(auto_attribs=True)
1932+
class C:
1933+
name: str
1934+
1935+
c = C(name='foo')
1936+
1937+
c = attr.assoc(c, name='test')
1938+
c = attr.assoc(c, name=42) # E: Argument "name" to "assoc" of "C" has incompatible type "int"; expected "str"
1939+
1940+
c = attrs.evolve(c, name='test')
1941+
c = attrs.evolve(c, name=42) # E: Argument "name" to "evolve" of "C" has incompatible type "int"; expected "str"
1942+
1943+
c = attrs.assoc(c, name='test')
1944+
c = attrs.assoc(c, name=42) # E: Argument "name" to "assoc" of "C" has incompatible type "int"; expected "str"
1945+
1946+
[builtins fixtures/attr.pyi]
1947+
[typing fixtures/typing-medium.pyi]

test-data/unit/lib-stub/attr/__init__.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,6 @@ def field(
244244
order: Optional[bool] = ...,
245245
on_setattr: Optional[object] = ...,
246246
) -> Any: ...
247+
248+
def evolve(inst: _T, **changes: Any) -> _T: ...
249+
def assoc(inst: _T, **changes: Any) -> _T: ...

test-data/unit/lib-stub/attrs/__init__.pyi

+3
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,6 @@ def field(
126126
order: Optional[bool] = ...,
127127
on_setattr: Optional[object] = ...,
128128
) -> Any: ...
129+
130+
def evolve(inst: _T, **changes: Any) -> _T: ...
131+
def assoc(inst: _T, **changes: Any) -> _T: ...

0 commit comments

Comments
 (0)