Skip to content

Commit 55bee20

Browse files
authored
Make disable_type_names a context manager (#11716)
1 parent e8cf960 commit 55bee20

File tree

3 files changed

+38
-23
lines changed

3 files changed

+38
-23
lines changed

mypy/checkexpr.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2064,11 +2064,19 @@ def check_union_call(self,
20642064
arg_names: Optional[Sequence[Optional[str]]],
20652065
context: Context,
20662066
arg_messages: MessageBuilder) -> Tuple[Type, Type]:
2067-
self.msg.disable_type_names += 1
2068-
results = [self.check_call(subtype, args, arg_kinds, context, arg_names,
2069-
arg_messages=arg_messages)
2070-
for subtype in callee.relevant_items()]
2071-
self.msg.disable_type_names -= 1
2067+
with self.msg.disable_type_names():
2068+
results = [
2069+
self.check_call(
2070+
subtype,
2071+
args,
2072+
arg_kinds,
2073+
context,
2074+
arg_names,
2075+
arg_messages=arg_messages,
2076+
)
2077+
for subtype in callee.relevant_items()
2078+
]
2079+
20722080
return (make_simplified_union([res[0] for res in results]),
20732081
callee)
20742082

@@ -2462,11 +2470,11 @@ def check_union_method_call_by_name(self,
24622470
for typ in base_type.relevant_items():
24632471
# Format error messages consistently with
24642472
# mypy.checkmember.analyze_union_member_access().
2465-
local_errors.disable_type_names += 1
2466-
item, meth_item = self.check_method_call_by_name(method, typ, args, arg_kinds,
2467-
context, local_errors,
2468-
original_type)
2469-
local_errors.disable_type_names -= 1
2473+
with local_errors.disable_type_names():
2474+
item, meth_item = self.check_method_call_by_name(
2475+
method, typ, args, arg_kinds,
2476+
context, local_errors, original_type,
2477+
)
24702478
res.append(item)
24712479
meth_res.append(meth_item)
24722480
return make_simplified_union(res), make_simplified_union(meth_res)

mypy/checkmember.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -311,13 +311,12 @@ def analyze_type_type_member_access(name: str,
311311

312312

313313
def analyze_union_member_access(name: str, typ: UnionType, mx: MemberContext) -> Type:
314-
mx.msg.disable_type_names += 1
315-
results = []
316-
for subtype in typ.relevant_items():
317-
# Self types should be bound to every individual item of a union.
318-
item_mx = mx.copy_modified(self_type=subtype)
319-
results.append(_analyze_member_access(name, subtype, item_mx))
320-
mx.msg.disable_type_names -= 1
314+
with mx.msg.disable_type_names():
315+
results = []
316+
for subtype in typ.relevant_items():
317+
# Self types should be bound to every individual item of a union.
318+
item_mx = mx.copy_modified(self_type=subtype)
319+
results.append(_analyze_member_access(name, subtype, item_mx))
321320
return make_simplified_union(results)
322321

323322

mypy/messages.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,13 @@ class MessageBuilder:
107107
disable_count = 0
108108

109109
# Hack to deduplicate error messages from union types
110-
disable_type_names = 0
110+
disable_type_names_count = 0
111111

112112
def __init__(self, errors: Errors, modules: Dict[str, MypyFile]) -> None:
113113
self.errors = errors
114114
self.modules = modules
115115
self.disable_count = 0
116-
self.disable_type_names = 0
116+
self.disable_type_names_count = 0
117117

118118
#
119119
# Helpers
@@ -122,7 +122,7 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile]) -> None:
122122
def copy(self) -> 'MessageBuilder':
123123
new = MessageBuilder(self.errors.copy(), self.modules)
124124
new.disable_count = self.disable_count
125-
new.disable_type_names = self.disable_type_names
125+
new.disable_type_names_count = self.disable_type_names_count
126126
return new
127127

128128
def clean_copy(self) -> 'MessageBuilder':
@@ -145,6 +145,14 @@ def disable_errors(self) -> Iterator[None]:
145145
finally:
146146
self.disable_count -= 1
147147

148+
@contextmanager
149+
def disable_type_names(self) -> Iterator[None]:
150+
self.disable_type_names_count += 1
151+
try:
152+
yield
153+
finally:
154+
self.disable_type_names_count -= 1
155+
148156
def is_errors(self) -> bool:
149157
return self.errors.is_errors()
150158

@@ -298,7 +306,7 @@ def has_no_attr(self,
298306
extra = ' (not iterable)'
299307
elif member == '__aiter__':
300308
extra = ' (not async iterable)'
301-
if not self.disable_type_names:
309+
if not self.disable_type_names_count:
302310
failed = False
303311
if isinstance(original_type, Instance) and original_type.type.names:
304312
alternatives = set(original_type.type.names.keys())
@@ -380,7 +388,7 @@ def unsupported_operand_types(self,
380388
else:
381389
right_str = format_type(right_type)
382390

383-
if self.disable_type_names:
391+
if self.disable_type_names_count:
384392
msg = 'Unsupported operand types for {} (likely involving Union)'.format(op)
385393
else:
386394
msg = 'Unsupported operand types for {} ({} and {})'.format(
@@ -389,7 +397,7 @@ def unsupported_operand_types(self,
389397

390398
def unsupported_left_operand(self, op: str, typ: Type,
391399
context: Context) -> None:
392-
if self.disable_type_names:
400+
if self.disable_type_names_count:
393401
msg = 'Unsupported left operand type for {} (some union)'.format(op)
394402
else:
395403
msg = 'Unsupported left operand type for {} ({})'.format(

0 commit comments

Comments
 (0)