2
2
import typing
3
3
import warnings
4
4
from functools import reduce
5
- from typing import TYPE_CHECKING , Literal
5
+ from typing import TYPE_CHECKING , Literal , cast
6
6
7
7
import numpy as np
8
8
import scipy .linalg
@@ -141,7 +141,7 @@ def cholesky(x, lower=True, on_error="raise", check_finite=False):
141
141
class SolveBase (Op ):
142
142
"""Base class for `scipy.linalg` matrix equation solvers."""
143
143
144
- __props__ = (
144
+ __props__ : tuple [ str , ...] = (
145
145
"lower" ,
146
146
"check_finite" ,
147
147
"b_ndim" ,
@@ -352,7 +352,7 @@ def solve_triangular(
352
352
This will influence how batched dimensions are interpreted.
353
353
"""
354
354
b_ndim = _default_b_ndim (b , b_ndim )
355
- return Blockwise (
355
+ ret = Blockwise (
356
356
SolveTriangular (
357
357
lower = lower ,
358
358
trans = trans ,
@@ -361,6 +361,7 @@ def solve_triangular(
361
361
b_ndim = b_ndim ,
362
362
)
363
363
)(a , b )
364
+ return cast (TensorVariable , ret )
364
365
365
366
366
367
class Solve (SolveBase ):
@@ -714,9 +715,7 @@ def grad(self, inputs, output_grads):
714
715
715
716
716
717
_solve_continuous_lyapunov = SolveContinuousLyapunov ()
717
- _solve_bilinear_direct_lyapunov = typing .cast (
718
- typing .Callable , BilinearSolveDiscreteLyapunov ()
719
- )
718
+ _solve_bilinear_direct_lyapunov = cast (typing .Callable , BilinearSolveDiscreteLyapunov ())
720
719
721
720
722
721
def _direct_solve_discrete_lyapunov (A : "TensorLike" , Q : "TensorLike" ) -> TensorVariable :
@@ -729,7 +728,7 @@ def _direct_solve_discrete_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorV
729
728
AA = kron (A_ , A_ )
730
729
731
730
X = solve (pt .eye (AA .shape [0 ]) - AA , Q_ .ravel ())
732
- return typing . cast (TensorVariable , reshape (X , Q_ .shape ))
731
+ return cast (TensorVariable , reshape (X , Q_ .shape ))
733
732
734
733
735
734
def solve_discrete_lyapunov (
@@ -765,7 +764,7 @@ def solve_discrete_lyapunov(
765
764
if method == "direct" :
766
765
return _direct_solve_discrete_lyapunov (A , Q )
767
766
if method == "bilinear" :
768
- return typing . cast (TensorVariable , _solve_bilinear_direct_lyapunov (A , Q ))
767
+ return cast (TensorVariable , _solve_bilinear_direct_lyapunov (A , Q ))
769
768
770
769
771
770
def solve_continuous_lyapunov (A : "TensorLike" , Q : "TensorLike" ) -> TensorVariable :
@@ -785,7 +784,7 @@ def solve_continuous_lyapunov(A: "TensorLike", Q: "TensorLike") -> TensorVariabl
785
784
786
785
"""
787
786
788
- return typing . cast (TensorVariable , _solve_continuous_lyapunov (A , Q ))
787
+ return cast (TensorVariable , _solve_continuous_lyapunov (A , Q ))
789
788
790
789
791
790
class SolveDiscreteARE (pt .Op ):
@@ -866,9 +865,7 @@ def solve_discrete_are(A, B, Q, R, enforce_Q_symmetric=False) -> TensorVariable:
866
865
Square matrix of shape M x M, representing the solution to the DARE
867
866
"""
868
867
869
- return typing .cast (
870
- TensorVariable , SolveDiscreteARE (enforce_Q_symmetric )(A , B , Q , R )
871
- )
868
+ return cast (TensorVariable , SolveDiscreteARE (enforce_Q_symmetric )(A , B , Q , R ))
872
869
873
870
874
871
def _largest_common_dtype (tensors : typing .Sequence [TensorVariable ]) -> np .dtype :
0 commit comments