Skip to content

Commit 7d54c5e

Browse files
committed
Enable mypy's warn_no_return lint
1 parent ee4d4f7 commit 7d54c5e

File tree

3 files changed

+38
-36
lines changed

3 files changed

+38
-36
lines changed

pyproject.toml

-5
Original file line numberDiff line numberDiff line change
@@ -156,17 +156,12 @@ lines-after-imports = 2
156156
[tool.mypy]
157157
python_version = "3.10"
158158
ignore_missing_imports = true
159-
no_implicit_optional = true
160-
check_untyped_defs = false
161159
strict_equality = true
162160
warn_redundant_casts = true
163161
warn_unused_configs = true
164162
warn_unused_ignores = true
165163
warn_return_any = true
166-
warn_no_return = false
167164
warn_unreachable = true
168-
show_error_codes = true
169-
allow_redefinition = false
170165
files = ["pytensor", "tests"]
171166
plugins = ["numpy.typing.mypy_plugin"]
172167

pytensor/graph/basic.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -909,6 +909,7 @@ def ancestors(
909909
def expand(r: Variable) -> Iterator[Variable] | None:
910910
if r.owner and (not blockers or r not in blockers):
911911
return reversed(r.owner.inputs)
912+
return None
912913

913914
yield from cast(Generator[Variable, None, None], walk(graphs, expand, False))
914915

@@ -1011,6 +1012,7 @@ def vars_between(
10111012
def expand(r: Variable) -> Iterable[Variable] | None:
10121013
if r.owner and r not in ins:
10131014
return reversed(r.owner.inputs + r.owner.outputs)
1015+
return None
10141016

10151017
yield from cast(Generator[Variable, None, None], walk(outs, expand))
10161018

@@ -2039,13 +2041,15 @@ def get_var_by_name(
20392041
from pytensor.graph.op import HasInnerGraph
20402042

20412043
def expand(r) -> list[Variable] | None:
2042-
if r.owner:
2043-
res = list(r.owner.inputs)
2044+
if not r.owner:
2045+
return None
2046+
2047+
res = list(r.owner.inputs)
20442048

2045-
if isinstance(r.owner.op, HasInnerGraph):
2046-
res.extend(r.owner.op.inner_outputs)
2049+
if isinstance(r.owner.op, HasInnerGraph):
2050+
res.extend(r.owner.op.inner_outputs)
20472051

2048-
return res
2052+
return res
20492053

20502054
results: tuple[Variable, ...] = ()
20512055
for var in walk(graphs, expand, False):

pytensor/tensor/rewriting/linalg.py

+29-26
Original file line numberDiff line numberDiff line change
@@ -355,34 +355,37 @@ def local_lift_through_linalg(
355355
"""
356356

357357
# TODO: Simplify this if we end up Blockwising KroneckerProduct
358-
if isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
359-
y = node.inputs[0]
360-
outer_op = node.op
361-
362-
if y.owner and (
363-
isinstance(y.owner.op, Blockwise)
364-
and isinstance(y.owner.op.core_op, BlockDiagonal)
365-
or isinstance(y.owner.op, KroneckerProduct)
366-
):
367-
input_matrices = y.owner.inputs
368-
369-
if isinstance(outer_op.core_op, MatrixInverse):
370-
outer_f = cast(Callable, inv)
371-
elif isinstance(outer_op.core_op, Cholesky):
372-
outer_f = cast(Callable, cholesky)
373-
elif isinstance(outer_op.core_op, MatrixPinv):
374-
outer_f = cast(Callable, pinv)
375-
else:
376-
raise NotImplementedError # pragma: no cover
358+
if not isinstance(node.op.core_op, MatrixInverse | Cholesky | MatrixPinv):
359+
return None
377360

378-
inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices]
361+
y = node.inputs[0]
362+
outer_op = node.op
379363

380-
if isinstance(y.owner.op, KroneckerProduct):
381-
return [kron(*inner_matrices)]
382-
elif isinstance(y.owner.op.core_op, BlockDiagonal):
383-
return [block_diag(*inner_matrices)]
384-
else:
385-
raise NotImplementedError # pragma: no cover
364+
if y.owner and (
365+
isinstance(y.owner.op, Blockwise)
366+
and isinstance(y.owner.op.core_op, BlockDiagonal)
367+
or isinstance(y.owner.op, KroneckerProduct)
368+
):
369+
input_matrices = y.owner.inputs
370+
371+
if isinstance(outer_op.core_op, MatrixInverse):
372+
outer_f = cast(Callable, inv)
373+
elif isinstance(outer_op.core_op, Cholesky):
374+
outer_f = cast(Callable, cholesky)
375+
elif isinstance(outer_op.core_op, MatrixPinv):
376+
outer_f = cast(Callable, pinv)
377+
else:
378+
raise NotImplementedError # pragma: no cover
379+
380+
inner_matrices = [cast(TensorVariable, outer_f(m)) for m in input_matrices]
381+
382+
if isinstance(y.owner.op, KroneckerProduct):
383+
return [kron(*inner_matrices)]
384+
elif isinstance(y.owner.op.core_op, BlockDiagonal):
385+
return [block_diag(*inner_matrices)]
386+
else:
387+
raise NotImplementedError # pragma: no cover
388+
return None
386389

387390

388391
def _find_diag_from_eye_mul(potential_mul_input):

0 commit comments

Comments
 (0)