Skip to content

Commit cee0030

Browse files
authored
Constant fold more unary and binary expressions (#15202)
Now mypy can constant fold these additional operations: - Float arithmetic - Mixed int and float arithmetic - String multiplication - Complex plus or minus a literal real (eg. 1+j2) While this can be useful with literal types, the main goal is to improve constant folding in mypyc (mypyc/mypyc#772). mypyc can also fold bytes addition and multiplication, but mypy cannot as byte values can't be easily stored anywhere.
1 parent 2bb7078 commit cee0030

10 files changed

+412
-97
lines changed

mypy/constant_fold.py

+90-28
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,21 @@
88
from typing import Union
99
from typing_extensions import Final
1010

11-
from mypy.nodes import Expression, FloatExpr, IntExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var
11+
from mypy.nodes import (
12+
ComplexExpr,
13+
Expression,
14+
FloatExpr,
15+
IntExpr,
16+
NameExpr,
17+
OpExpr,
18+
StrExpr,
19+
UnaryExpr,
20+
Var,
21+
)
1222

1323
# All possible result types of constant folding
14-
ConstantValue = Union[int, bool, float, str]
15-
CONST_TYPES: Final = (int, bool, float, str)
24+
ConstantValue = Union[int, bool, float, complex, str]
25+
CONST_TYPES: Final = (int, bool, float, complex, str)
1626

1727

1828
def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | None:
@@ -39,6 +49,8 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non
3949
return expr.value
4050
if isinstance(expr, FloatExpr):
4151
return expr.value
52+
if isinstance(expr, ComplexExpr):
53+
return expr.value
4254
elif isinstance(expr, NameExpr):
4355
if expr.name == "True":
4456
return True
@@ -56,26 +68,60 @@ def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | Non
5668
elif isinstance(expr, OpExpr):
5769
left = constant_fold_expr(expr.left, cur_mod_id)
5870
right = constant_fold_expr(expr.right, cur_mod_id)
59-
if isinstance(left, int) and isinstance(right, int):
60-
return constant_fold_binary_int_op(expr.op, left, right)
61-
elif isinstance(left, str) and isinstance(right, str):
62-
return constant_fold_binary_str_op(expr.op, left, right)
71+
if left is not None and right is not None:
72+
return constant_fold_binary_op(expr.op, left, right)
6373
elif isinstance(expr, UnaryExpr):
6474
value = constant_fold_expr(expr.expr, cur_mod_id)
65-
if isinstance(value, int):
66-
return constant_fold_unary_int_op(expr.op, value)
67-
if isinstance(value, float):
68-
return constant_fold_unary_float_op(expr.op, value)
75+
if value is not None:
76+
return constant_fold_unary_op(expr.op, value)
6977
return None
7078

7179

72-
def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
80+
def constant_fold_binary_op(
81+
op: str, left: ConstantValue, right: ConstantValue
82+
) -> ConstantValue | None:
83+
if isinstance(left, int) and isinstance(right, int):
84+
return constant_fold_binary_int_op(op, left, right)
85+
86+
# Float and mixed int/float arithmetic.
87+
if isinstance(left, float) and isinstance(right, float):
88+
return constant_fold_binary_float_op(op, left, right)
89+
elif isinstance(left, float) and isinstance(right, int):
90+
return constant_fold_binary_float_op(op, left, right)
91+
elif isinstance(left, int) and isinstance(right, float):
92+
return constant_fold_binary_float_op(op, left, right)
93+
94+
# String concatenation and multiplication.
95+
if op == "+" and isinstance(left, str) and isinstance(right, str):
96+
return left + right
97+
elif op == "*" and isinstance(left, str) and isinstance(right, int):
98+
return left * right
99+
elif op == "*" and isinstance(left, int) and isinstance(right, str):
100+
return left * right
101+
102+
# Complex construction.
103+
if op == "+" and isinstance(left, (int, float)) and isinstance(right, complex):
104+
return left + right
105+
elif op == "+" and isinstance(left, complex) and isinstance(right, (int, float)):
106+
return left + right
107+
elif op == "-" and isinstance(left, (int, float)) and isinstance(right, complex):
108+
return left - right
109+
elif op == "-" and isinstance(left, complex) and isinstance(right, (int, float)):
110+
return left - right
111+
112+
return None
113+
114+
115+
def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | float | None:
73116
if op == "+":
74117
return left + right
75118
if op == "-":
76119
return left - right
77120
elif op == "*":
78121
return left * right
122+
elif op == "/":
123+
if right != 0:
124+
return left / right
79125
elif op == "//":
80126
if right != 0:
81127
return left // right
@@ -102,25 +148,41 @@ def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | None:
102148
return None
103149

104150

105-
def constant_fold_unary_int_op(op: str, value: int) -> int | None:
106-
if op == "-":
107-
return -value
108-
elif op == "~":
109-
return ~value
110-
elif op == "+":
111-
return value
151+
def constant_fold_binary_float_op(op: str, left: int | float, right: int | float) -> float | None:
152+
assert not (isinstance(left, int) and isinstance(right, int)), (op, left, right)
153+
if op == "+":
154+
return left + right
155+
elif op == "-":
156+
return left - right
157+
elif op == "*":
158+
return left * right
159+
elif op == "/":
160+
if right != 0:
161+
return left / right
162+
elif op == "//":
163+
if right != 0:
164+
return left // right
165+
elif op == "%":
166+
if right != 0:
167+
return left % right
168+
elif op == "**":
169+
if (left < 0 and isinstance(right, int)) or left > 0:
170+
try:
171+
ret = left**right
172+
except OverflowError:
173+
return None
174+
else:
175+
assert isinstance(ret, float), ret
176+
return ret
177+
112178
return None
113179

114180

115-
def constant_fold_unary_float_op(op: str, value: float) -> float | None:
116-
if op == "-":
181+
def constant_fold_unary_op(op: str, value: ConstantValue) -> int | float | None:
182+
if op == "-" and isinstance(value, (int, float)):
117183
return -value
118-
elif op == "+":
184+
elif op == "~" and isinstance(value, int):
185+
return ~value
186+
elif op == "+" and isinstance(value, (int, float)):
119187
return value
120188
return None
121-
122-
123-
def constant_fold_binary_str_op(op: str, left: str, right: str) -> str | None:
124-
if op == "+":
125-
return left + right
126-
return None

mypy/nodes.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
10021002
# If constant value is a simple literal,
10031003
# store the literal value (unboxed) for the benefit of
10041004
# tools like mypyc.
1005-
self.final_value: int | float | bool | str | None = None
1005+
self.final_value: int | float | complex | bool | str | None = None
10061006
# Where the value was set (only for class attributes)
10071007
self.final_unset_in_class = False
10081008
self.final_set_in_init = False

mypy/semanal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3394,7 +3394,7 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Typ
33943394
return None
33953395

33963396
value = constant_fold_expr(rvalue, self.cur_mod_id)
3397-
if value is None:
3397+
if value is None or isinstance(value, complex):
33983398
return None
33993399

34003400
if isinstance(value, bool):

mypyc/irbuild/builder.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -535,25 +535,25 @@ def load_final_static(
535535
error_msg=f'value for final name "{error_name}" was not set',
536536
)
537537

538-
def load_final_literal_value(self, val: int | str | bytes | float | bool, line: int) -> Value:
539-
"""Load value of a final name or class-level attribute."""
538+
def load_literal_value(self, val: int | str | bytes | float | complex | bool) -> Value:
539+
"""Load value of a final name, class-level attribute, or constant folded expression."""
540540
if isinstance(val, bool):
541541
if val:
542542
return self.true()
543543
else:
544544
return self.false()
545545
elif isinstance(val, int):
546-
# TODO: take care of negative integer initializers
547-
# (probably easier to fix this in mypy itself).
548546
return self.builder.load_int(val)
549547
elif isinstance(val, float):
550548
return self.builder.load_float(val)
551549
elif isinstance(val, str):
552550
return self.builder.load_str(val)
553551
elif isinstance(val, bytes):
554552
return self.builder.load_bytes(val)
553+
elif isinstance(val, complex):
554+
return self.builder.load_complex(val)
555555
else:
556-
assert False, "Unsupported final literal value"
556+
assert False, "Unsupported literal value"
557557

558558
def get_assignment_target(
559559
self, lvalue: Lvalue, line: int = -1, *, for_read: bool = False
@@ -1013,7 +1013,7 @@ def emit_load_final(
10131013
line: line number where loading occurs
10141014
"""
10151015
if final_var.final_value is not None: # this is safe even for non-native names
1016-
return self.load_final_literal_value(final_var.final_value, line)
1016+
return self.load_literal_value(final_var.final_value)
10171017
elif native:
10181018
return self.load_final_static(fullname, self.mapper.type_to_rtype(typ), line, name)
10191019
else:

mypyc/irbuild/callable_class.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
def setup_callable_class(builder: IRBuilder) -> None:
20-
"""Generate an (incomplete) callable class representing function.
20+
"""Generate an (incomplete) callable class representing a function.
2121
2222
This can be a nested function or a function within a non-extension
2323
class. Also set up the 'self' variable for that class.

mypyc/irbuild/constant_fold.py

+41-23
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,10 @@
1313
from typing import Union
1414
from typing_extensions import Final
1515

16-
from mypy.constant_fold import (
17-
constant_fold_binary_int_op,
18-
constant_fold_binary_str_op,
19-
constant_fold_unary_float_op,
20-
constant_fold_unary_int_op,
21-
)
16+
from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op
2217
from mypy.nodes import (
18+
BytesExpr,
19+
ComplexExpr,
2320
Expression,
2421
FloatExpr,
2522
IntExpr,
@@ -31,10 +28,11 @@
3128
Var,
3229
)
3330
from mypyc.irbuild.builder import IRBuilder
31+
from mypyc.irbuild.util import bytes_from_str
3432

3533
# All possible result types of constant folding
36-
ConstantValue = Union[int, str, float]
37-
CONST_TYPES: Final = (int, str, float)
34+
ConstantValue = Union[int, float, complex, str, bytes]
35+
CONST_TYPES: Final = (int, float, complex, str, bytes)
3836

3937

4038
def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
@@ -44,35 +42,55 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue |
4442
"""
4543
if isinstance(expr, IntExpr):
4644
return expr.value
45+
if isinstance(expr, FloatExpr):
46+
return expr.value
4747
if isinstance(expr, StrExpr):
4848
return expr.value
49-
if isinstance(expr, FloatExpr):
49+
if isinstance(expr, BytesExpr):
50+
return bytes_from_str(expr.value)
51+
if isinstance(expr, ComplexExpr):
5052
return expr.value
5153
elif isinstance(expr, NameExpr):
5254
node = expr.node
5355
if isinstance(node, Var) and node.is_final:
54-
value = node.final_value
55-
if isinstance(value, (CONST_TYPES)):
56-
return value
56+
final_value = node.final_value
57+
if isinstance(final_value, (CONST_TYPES)):
58+
return final_value
5759
elif isinstance(expr, MemberExpr):
5860
final = builder.get_final_ref(expr)
5961
if final is not None:
6062
fn, final_var, native = final
6163
if final_var.is_final:
62-
value = final_var.final_value
63-
if isinstance(value, (CONST_TYPES)):
64-
return value
64+
final_value = final_var.final_value
65+
if isinstance(final_value, (CONST_TYPES)):
66+
return final_value
6567
elif isinstance(expr, OpExpr):
6668
left = constant_fold_expr(builder, expr.left)
6769
right = constant_fold_expr(builder, expr.right)
68-
if isinstance(left, int) and isinstance(right, int):
69-
return constant_fold_binary_int_op(expr.op, left, right)
70-
elif isinstance(left, str) and isinstance(right, str):
71-
return constant_fold_binary_str_op(expr.op, left, right)
70+
if left is not None and right is not None:
71+
return constant_fold_binary_op_extended(expr.op, left, right)
7272
elif isinstance(expr, UnaryExpr):
7373
value = constant_fold_expr(builder, expr.expr)
74-
if isinstance(value, int):
75-
return constant_fold_unary_int_op(expr.op, value)
76-
if isinstance(value, float):
77-
return constant_fold_unary_float_op(expr.op, value)
74+
if value is not None and not isinstance(value, bytes):
75+
return constant_fold_unary_op(expr.op, value)
76+
return None
77+
78+
79+
def constant_fold_binary_op_extended(
80+
op: str, left: ConstantValue, right: ConstantValue
81+
) -> ConstantValue | None:
82+
"""Like mypy's constant_fold_binary_op(), but includes bytes support.
83+
84+
mypy cannot use constant folded bytes easily so it's simpler to only support them in mypyc.
85+
"""
86+
if not isinstance(left, bytes) and not isinstance(right, bytes):
87+
return constant_fold_binary_op(op, left, right)
88+
89+
if op == "+" and isinstance(left, bytes) and isinstance(right, bytes):
90+
return left + right
91+
elif op == "*" and isinstance(left, bytes) and isinstance(right, int):
92+
return left * right
93+
elif op == "*" and isinstance(left, int) and isinstance(right, bytes):
94+
return left * right
95+
7896
return None

mypyc/irbuild/expression.py

+2-13
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
Assign,
5656
BasicBlock,
5757
ComparisonOp,
58-
Float,
5958
Integer,
6059
LoadAddress,
6160
LoadLiteral,
@@ -92,7 +91,6 @@
9291
tokenizer_printf_style,
9392
)
9493
from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization
95-
from mypyc.irbuild.util import bytes_from_str
9694
from mypyc.primitives.bytes_ops import bytes_slice_op
9795
from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op
9896
from mypyc.primitives.generic_ops import iter_op
@@ -575,12 +573,8 @@ def try_constant_fold(builder: IRBuilder, expr: Expression) -> Value | None:
575573
Return None otherwise.
576574
"""
577575
value = constant_fold_expr(builder, expr)
578-
if isinstance(value, int):
579-
return builder.load_int(value)
580-
elif isinstance(value, str):
581-
return builder.load_str(value)
582-
elif isinstance(value, float):
583-
return Float(value)
576+
if value is not None:
577+
return builder.load_literal_value(value)
584578
return None
585579

586580

@@ -662,10 +656,6 @@ def set_literal_values(builder: IRBuilder, items: Sequence[Expression]) -> list[
662656
values.append(True)
663657
elif item.fullname == "builtins.False":
664658
values.append(False)
665-
elif isinstance(item, (BytesExpr, FloatExpr, ComplexExpr)):
666-
# constant_fold_expr() doesn't handle these (yet?)
667-
v = bytes_from_str(item.value) if isinstance(item, BytesExpr) else item.value
668-
values.append(v)
669659
elif isinstance(item, TupleExpr):
670660
tuple_values = set_literal_values(builder, item.items)
671661
if tuple_values is not None:
@@ -685,7 +675,6 @@ def precompute_set_literal(builder: IRBuilder, s: SetExpr) -> Value | None:
685675
Supported items:
686676
- Anything supported by irbuild.constant_fold.constant_fold_expr()
687677
- None, True, and False
688-
- Float, byte, and complex literals
689678
- Tuple literals with only items listed above
690679
"""
691680
values = set_literal_values(builder, s.items)

0 commit comments

Comments
 (0)