Skip to content

Commit 6cf729b

Browse files
committed
Fix mypy in tensor/slinalg.py
1 parent 7d54c5e commit 6cf729b

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

pytensor/tensor/slinalg.py

+9-12
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import typing
33
import warnings
44
from functools import reduce
5-
from typing import TYPE_CHECKING, Literal
5+
from typing import TYPE_CHECKING, Literal, cast
66

77
import numpy as np
88
import scipy.linalg
@@ -141,7 +141,7 @@ def cholesky(x, lower=True, on_error="raise", check_finite=False):
141141
class SolveBase(Op):
142142
"""Base class for `scipy.linalg` matrix equation solvers."""
143143

144-
__props__ = (
144+
__props__: tuple[str, ...] = (
145145
"lower",
146146
"check_finite",
147147
"b_ndim",
@@ -352,7 +352,7 @@ def solve_triangular(
352352
This will influence how batched dimensions are interpreted.
353353
"""
354354
b_ndim = _default_b_ndim(b, b_ndim)
355-
return Blockwise(
355+
ret = Blockwise(
356356
SolveTriangular(
357357
lower=lower,
358358
trans=trans,
@@ -361,6 +361,7 @@ def solve_triangular(
361361
b_ndim=b_ndim,
362362
)
363363
)(a, b)
364+
return cast(TensorVariable, ret)
364365

365366

366367
class Solve(SolveBase):
@@ -714,9 +715,7 @@ def grad(self, inputs, output_grads):
714715

715716

716717
_solve_continuous_lyapunov = SolveContinuousLyapunov()
717-
_solve_bilinear_direct_lyapunov = typing.cast(
718-
typing.Callable, BilinearSolveDiscreteLyapunov()
719-
)
718+
_solve_bilinear_direct_lyapunov = cast(typing.Callable, BilinearSolveDiscreteLyapunov())
720719

721720

722721
def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
@@ -729,7 +728,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorV
729728
AA = kron(A_, A_)
730729

731730
X = solve(pt.eye(AA.shape[0]) - AA, Q_.ravel())
732-
return typing.cast(TensorVariable, reshape(X, Q_.shape))
731+
return cast(TensorVariable, reshape(X, Q_.shape))
733732

734733

735734
def solve_discrete_lyapunov(
@@ -765,7 +764,7 @@ def solve_discrete_lyapunov(
765764
if method == "direct":
766765
return _direct_solve_discrete_lyapunov(A, Q)
767766
if method == "bilinear":
768-
return typing.cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
767+
return cast(TensorVariable, _solve_bilinear_direct_lyapunov(A, Q))
769768

770769

771770
def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariable:
@@ -785,7 +784,7 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
785784
786785
"""
787786

788-
return typing.cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
787+
return cast(TensorVariable, _solve_continuous_lyapunov(A, Q))
789788

790789

791790
class SolveDiscreteARE(pt.Op):
@@ -866,9 +865,7 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
866865
Square matrix of shape M x M, representing the solution to the DARE
867866
"""
868867

869-
return typing.cast(
870-
TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R)
871-
)
868+
return cast(TensorVariable, SolveDiscreteARE(enforce_Q_symmetric)(A, B, Q, R))
872869

873870

874871
def _largest_common_dtype(tensors: typing.Sequence[TensorVariable]) -> np.dtype:

scripts/mypy-failing.txt

-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ pytensor/tensor/random/basic.py
2424
pytensor/tensor/random/op.py
2525
pytensor/tensor/random/utils.py
2626
pytensor/tensor/rewriting/basic.py
27-
pytensor/tensor/slinalg.py
2827
pytensor/tensor/type.py
2928
pytensor/tensor/type_other.py
3029
pytensor/tensor/variable.py

0 commit comments

Comments
 (0)