Skip to content

Commit ddf88f3

Browse files
committed
improve cp-sat python types
1 parent dd765fb commit ddf88f3

16 files changed

+94
-106
lines changed

ortools/sat/docs/scheduling.md

+8-5
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def create_data_model() -> tuple[pd.DataFrame, pd.DataFrame]:
726726
return capacity_df, tasks_df
727727

728728

729-
def main():
729+
def main() -> None:
730730
"""Create the model and solves it."""
731731
capacity_df, tasks_df = create_data_model()
732732

@@ -834,7 +834,7 @@ def rank_tasks(
834834
starts: list[cp_model.IntVar],
835835
presences: list[cp_model.IntVar],
836836
ranks: list[cp_model.IntVar],
837-
):
837+
) -> None:
838838
"""This method adds constraints and variables to links tasks and ranks.
839839
840840
This method assumes that all starts are disjoint, meaning that all tasks have
@@ -852,7 +852,7 @@ def rank_tasks(
852852
all_tasks = range(num_tasks)
853853

854854
# Creates precedence variables between pairs of intervals.
855-
precedences = {}
855+
precedences: dict[tuple[int, int], cp_model.IntVar] = {}
856856
for i in all_tasks:
857857
for j in all_tasks:
858858
if i == j:
@@ -865,7 +865,10 @@ def rank_tasks(
865865
# Treats optional intervals.
866866
for i in range(num_tasks - 1):
867867
for j in range(i + 1, num_tasks):
868-
tmp_array = [precedences[(i, j)], precedences[(j, i)]]
868+
tmp_array: list[cp_model.LiteralT] = [
869+
precedences[(i, j)],
870+
precedences[(j, i)],
871+
]
869872
if not cp_model.object_is_a_true_literal(presences[i]):
870873
tmp_array.append(presences[i].negated())
871874
# Makes sure that if i is not performed, all precedences are false.
@@ -898,7 +901,7 @@ def rank_tasks(
898901
model.add(ranks[i] == sum(precedences[(j, i)] for j in all_tasks) - 1)
899902

900903

901-
def ranking_sample_sat():
904+
def ranking_sample_sat() -> None:
902905
"""Ranks tasks in a NoOverlap constraint."""
903906

904907
model = cp_model.CpModel()

ortools/sat/docs/troubleshooting.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ parallelism. Therefore, the number of workers must be set to 1.
101101
from ortools.sat.python import cp_model
102102

103103

104-
def main():
104+
def main() -> None:
105105
"""Showcases assumptions."""
106106
# Creates the model.
107107
model = cp_model.CpModel()

ortools/sat/python/cp_model.py

+51-66
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
Dict,
5757
Iterable,
5858
List,
59+
NoReturn,
5960
Optional,
6061
Sequence,
6162
Tuple,
@@ -80,48 +81,30 @@
8081
# usual arithmetic operators + - * / and with constant numbers, which makes the
8182
# python API very intuitive. See../ samples/*.py for examples.
8283

83-
INT_MIN = -9223372036854775808 # hardcoded to be platform independent.
84-
INT_MAX = 9223372036854775807
85-
INT32_MAX = 2147483647
86-
INT32_MIN = -2147483648
84+
INT_MIN = -(2**63) # hardcoded to be platform independent.
85+
INT_MAX = 2**63 - 1
86+
INT32_MIN = -(2**31)
87+
INT32_MAX = 2**31 - 1
8788

8889
# CpSolver status (exported to avoid importing cp_model_cp2).
89-
UNKNOWN: cp_model_pb2.CpSolverStatus = cp_model_pb2.UNKNOWN
90-
MODEL_INVALID: cp_model_pb2.CpSolverStatus = cp_model_pb2.MODEL_INVALID
91-
FEASIBLE: cp_model_pb2.CpSolverStatus = cp_model_pb2.FEASIBLE
92-
INFEASIBLE: cp_model_pb2.CpSolverStatus = cp_model_pb2.INFEASIBLE
93-
OPTIMAL: cp_model_pb2.CpSolverStatus = cp_model_pb2.OPTIMAL
90+
UNKNOWN = cp_model_pb2.UNKNOWN
91+
MODEL_INVALID = cp_model_pb2.MODEL_INVALID
92+
FEASIBLE = cp_model_pb2.FEASIBLE
93+
INFEASIBLE = cp_model_pb2.INFEASIBLE
94+
OPTIMAL = cp_model_pb2.OPTIMAL
9495

9596
# Variable selection strategy
96-
CHOOSE_FIRST: cp_model_pb2.DecisionStrategyProto.VariableSelectionStrategy = (
97-
cp_model_pb2.DecisionStrategyProto.CHOOSE_FIRST
98-
)
99-
CHOOSE_LOWEST_MIN: (
100-
cp_model_pb2.DecisionStrategyProto.VariableSelectionStrategy
101-
) = cp_model_pb2.DecisionStrategyProto.CHOOSE_LOWEST_MIN
102-
CHOOSE_HIGHEST_MAX: (
103-
cp_model_pb2.DecisionStrategyProto.VariableSelectionStrategy
104-
) = cp_model_pb2.DecisionStrategyProto.CHOOSE_HIGHEST_MAX
105-
CHOOSE_MIN_DOMAIN_SIZE: (
106-
cp_model_pb2.DecisionStrategyProto.VariableSelectionStrategy
107-
) = cp_model_pb2.DecisionStrategyProto.CHOOSE_MIN_DOMAIN_SIZE
108-
CHOOSE_MAX_DOMAIN_SIZE: (
109-
cp_model_pb2.DecisionStrategyProto.VariableSelectionStrategy
110-
) = cp_model_pb2.DecisionStrategyProto.CHOOSE_MAX_DOMAIN_SIZE
97+
CHOOSE_FIRST = cp_model_pb2.DecisionStrategyProto.CHOOSE_FIRST
98+
CHOOSE_LOWEST_MIN = cp_model_pb2.DecisionStrategyProto.CHOOSE_LOWEST_MIN
99+
CHOOSE_HIGHEST_MAX = cp_model_pb2.DecisionStrategyProto.CHOOSE_HIGHEST_MAX
100+
CHOOSE_MIN_DOMAIN_SIZE = cp_model_pb2.DecisionStrategyProto.CHOOSE_MIN_DOMAIN_SIZE
101+
CHOOSE_MAX_DOMAIN_SIZE = cp_model_pb2.DecisionStrategyProto.CHOOSE_MAX_DOMAIN_SIZE
111102

112103
# Domain reduction strategy
113-
SELECT_MIN_VALUE: cp_model_pb2.DecisionStrategyProto.DomainReductionStrategy = (
114-
cp_model_pb2.DecisionStrategyProto.SELECT_MIN_VALUE
115-
)
116-
SELECT_MAX_VALUE: cp_model_pb2.DecisionStrategyProto.DomainReductionStrategy = (
117-
cp_model_pb2.DecisionStrategyProto.SELECT_MAX_VALUE
118-
)
119-
SELECT_LOWER_HALF: (
120-
cp_model_pb2.DecisionStrategyProto.DomainReductionStrategy
121-
) = cp_model_pb2.DecisionStrategyProto.SELECT_LOWER_HALF
122-
SELECT_UPPER_HALF: (
123-
cp_model_pb2.DecisionStrategyProto.DomainReductionStrategy
124-
) = cp_model_pb2.DecisionStrategyProto.SELECT_UPPER_HALF
104+
SELECT_MIN_VALUE = cp_model_pb2.DecisionStrategyProto.SELECT_MIN_VALUE
105+
SELECT_MAX_VALUE = cp_model_pb2.DecisionStrategyProto.SELECT_MAX_VALUE
106+
SELECT_LOWER_HALF = cp_model_pb2.DecisionStrategyProto.SELECT_LOWER_HALF
107+
SELECT_UPPER_HALF = cp_model_pb2.DecisionStrategyProto.SELECT_UPPER_HALF
125108

126109
# Search branching
127110
AUTOMATIC_SEARCH = sat_parameters_pb2.SatParameters.AUTOMATIC_SEARCH
@@ -144,6 +127,7 @@
144127
VariableT = Union["IntVar", IntegralT]
145128
LinearExprT = Union["LinearExpr", "IntVar", IntegralT]
146129
ObjLinearExprT = Union["LinearExpr", "IntVar", NumberT]
130+
BoundedLinearExprT = Union["BoundedLinearExpression", bool]
147131
ArcT = Tuple[IntegralT, IntegralT, LiteralT]
148132
_IndexOrSeries = Union[pd.Index, pd.Series]
149133

@@ -405,110 +389,110 @@ def get_float_var_value_map(
405389
break
406390
return coeffs, constant, is_integer
407391

408-
def __hash__(self):
392+
def __hash__(self) -> int:
409393
return object.__hash__(self)
410394

411-
def __abs__(self):
395+
def __abs__(self) -> NoReturn:
412396
raise NotImplementedError(
413397
"calling abs() on a linear expression is not supported, "
414398
"please use CpModel.add_abs_equality"
415399
)
416400

417-
def __add__(self, arg):
401+
def __add__(self, arg) -> LinearExprT:
418402
if cmh.is_zero(arg):
419403
return self
420404
return _Sum(self, arg)
421405

422-
def __radd__(self, arg):
406+
def __radd__(self, arg) -> LinearExprT:
423407
if cmh.is_zero(arg):
424408
return self
425409
return _Sum(self, arg)
426410

427-
def __sub__(self, arg):
411+
def __sub__(self, arg) -> LinearExprT:
428412
if cmh.is_zero(arg):
429413
return self
430414
return _Sum(self, -arg)
431415

432-
def __rsub__(self, arg):
416+
def __rsub__(self, arg) -> LinearExprT:
433417
return _Sum(-self, arg)
434418

435-
def __mul__(self, arg):
419+
def __mul__(self, arg) -> LinearExprT:
436420
arg = cmh.assert_is_a_number(arg)
437421
if cmh.is_one(arg):
438422
return self
439423
elif cmh.is_zero(arg):
440424
return 0
441425
return _ProductCst(self, arg)
442426

443-
def __rmul__(self, arg):
427+
def __rmul__(self, arg) -> LinearExprT:
444428
arg = cmh.assert_is_a_number(arg)
445429
if cmh.is_one(arg):
446430
return self
447431
elif cmh.is_zero(arg):
448432
return 0
449433
return _ProductCst(self, arg)
450434

451-
def __div__(self, _):
435+
def __div__(self, _) -> NoReturn:
452436
raise NotImplementedError(
453437
"calling / on a linear expression is not supported, "
454438
"please use CpModel.add_division_equality"
455439
)
456440

457-
def __truediv__(self, _):
441+
def __truediv__(self, _) -> NoReturn:
458442
raise NotImplementedError(
459443
"calling // on a linear expression is not supported, "
460444
"please use CpModel.add_division_equality"
461445
)
462446

463-
def __mod__(self, _):
447+
def __mod__(self, _) -> NoReturn:
464448
raise NotImplementedError(
465449
"calling %% on a linear expression is not supported, "
466450
"please use CpModel.add_modulo_equality"
467451
)
468452

469-
def __pow__(self, _):
453+
def __pow__(self, _) -> NoReturn:
470454
raise NotImplementedError(
471455
"calling ** on a linear expression is not supported, "
472456
"please use CpModel.add_multiplication_equality"
473457
)
474458

475-
def __lshift__(self, _):
459+
def __lshift__(self, _) -> NoReturn:
476460
raise NotImplementedError(
477461
"calling left shift on a linear expression is not supported"
478462
)
479463

480-
def __rshift__(self, _):
464+
def __rshift__(self, _) -> NoReturn:
481465
raise NotImplementedError(
482466
"calling right shift on a linear expression is not supported"
483467
)
484468

485-
def __and__(self, _):
469+
def __and__(self, _) -> NoReturn:
486470
raise NotImplementedError(
487471
"calling and on a linear expression is not supported, "
488472
"please use CpModel.add_bool_and"
489473
)
490474

491-
def __or__(self, _):
475+
def __or__(self, _) -> NoReturn:
492476
raise NotImplementedError(
493477
"calling or on a linear expression is not supported, "
494478
"please use CpModel.add_bool_or"
495479
)
496480

497-
def __xor__(self, _):
481+
def __xor__(self, _) -> NoReturn:
498482
raise NotImplementedError(
499483
"calling xor on a linear expression is not supported, "
500484
"please use CpModel.add_bool_xor"
501485
)
502486

503-
def __neg__(self):
487+
def __neg__(self) -> LinearExprT:
504488
return _ProductCst(self, -1)
505489

506-
def __bool__(self):
490+
def __bool__(self) -> NoReturn:
507491
raise NotImplementedError(
508492
"Evaluating a LinearExpr instance as a Boolean is not implemented."
509493
)
510494

511-
def __eq__(self, arg):
495+
def __eq__(self, arg) -> BoundedLinearExprT:
512496
if arg is None:
513497
return False
514498
if cmh.is_integral(arg):
@@ -517,21 +501,21 @@ def __eq__(self, arg):
517501
else:
518502
return BoundedLinearExpression(self - arg, [0, 0])
519503

520-
def __ge__(self, arg):
504+
def __ge__(self, arg) -> BoundedLinearExprT:
521505
if cmh.is_integral(arg):
522506
arg = cmh.assert_is_int64(arg)
523507
return BoundedLinearExpression(self, [arg, INT_MAX])
524508
else:
525509
return BoundedLinearExpression(self - arg, [0, INT_MAX])
526510

527-
def __le__(self, arg):
511+
def __le__(self, arg) -> BoundedLinearExprT:
528512
if cmh.is_integral(arg):
529513
arg = cmh.assert_is_int64(arg)
530514
return BoundedLinearExpression(self, [INT_MIN, arg])
531515
else:
532516
return BoundedLinearExpression(self - arg, [INT_MIN, 0])
533517

534-
def __lt__(self, arg):
518+
def __lt__(self, arg) -> BoundedLinearExprT:
535519
if cmh.is_integral(arg):
536520
arg = cmh.assert_is_int64(arg)
537521
if arg == INT_MIN:
@@ -540,7 +524,7 @@ def __lt__(self, arg):
540524
else:
541525
return BoundedLinearExpression(self - arg, [INT_MIN, -1])
542526

543-
def __gt__(self, arg):
527+
def __gt__(self, arg) -> BoundedLinearExprT:
544528
if cmh.is_integral(arg):
545529
arg = cmh.assert_is_int64(arg)
546530
if arg == INT_MAX:
@@ -549,7 +533,7 @@ def __gt__(self, arg):
549533
else:
550534
return BoundedLinearExpression(self - arg, [1, INT_MAX])
551535

552-
def __ne__(self, arg):
536+
def __ne__(self, arg) -> BoundedLinearExprT:
553537
if arg is None:
554538
return True
555539
if cmh.is_integral(arg):
@@ -904,7 +888,7 @@ def __str__(self) -> str:
904888
def name(self) -> str:
905889
return "not(%s)" % str(self.__boolvar)
906890

907-
def __bool__(self) -> bool:
891+
def __bool__(self) -> NoReturn:
908892
raise NotImplementedError(
909893
"Evaluating a literal as a Boolean value is not implemented."
910894
)
@@ -964,8 +948,9 @@ def bounds(self) -> Sequence[int]:
964948
return self.__bounds
965949

966950
def __bool__(self) -> bool:
967-
if isinstance(self.__expr, LinearExpr):
968-
coeffs_map, constant = self.__expr.get_integer_var_value_map()
951+
expr = self.__expr
952+
if isinstance(expr, LinearExpr):
953+
coeffs_map, constant = expr.get_integer_var_value_map()
969954
all_coeffs = set(coeffs_map.values())
970955
same_var = set([0])
971956
eq_bounds = [0, 0]
@@ -3181,7 +3166,7 @@ def sufficient_assumptions_for_infeasibility(self) -> Sequence[int]:
31813166
"""Returns the indices of the infeasible assumptions."""
31823167
return self._solution.sufficient_assumptions_for_infeasibility
31833168

3184-
def status_name(self, status: Optional[cp_model_pb2.CpSolverStatus] = None) -> str:
3169+
def status_name(self, status: Optional[Any] = None) -> str:
31853170
"""Returns the name of the status returned by solve()."""
31863171
if status is None:
31873172
status = self._solution.status
@@ -3245,7 +3230,7 @@ def Solve(
32453230
def SolutionInfo(self) -> str:
32463231
return self.solution_info()
32473232

3248-
def StatusName(self, status: Optional[cp_model_pb2.CpSolverStatus] = None) -> str:
3233+
def StatusName(self, status: Optional[Any] = None) -> str:
32493234
return self.status_name(status)
32503235

32513236
def StopSearch(self) -> None:

ortools/sat/samples/assignment_groups_sat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# [END import]
2020

2121

22-
def main():
22+
def main() -> None:
2323
# Data
2424
# [START data]
2525
costs = [

ortools/sat/samples/assignment_sat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
# [END import]
2525

2626

27-
def main():
27+
def main() -> None:
2828
# Data
2929
# [START data_model]
3030
data_str = """

ortools/sat/samples/assignment_task_sizes_sat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# [END import]
2020

2121

22-
def main():
22+
def main() -> None:
2323
# Data
2424
# [START data]
2525
costs = [

ortools/sat/samples/assignment_teams_sat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# [END import]
2020

2121

22-
def main():
22+
def main() -> None:
2323
# Data
2424
# [START data]
2525
costs = [

ortools/sat/samples/assumptions_sample_sat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
# [END import]
2020

2121

22-
def main():
22+
def main() -> None:
2323
"""Showcases assumptions."""
2424
# Creates the model.
2525
# [START model]

0 commit comments

Comments
 (0)