@@ -355,34 +355,37 @@ def local_lift_through_linalg(
355
355
"""
356
356
357
357
# TODO: Simplify this if we end up Blockwising KroneckerProduct
358
- if isinstance (node .op .core_op , MatrixInverse | Cholesky | MatrixPinv ):
359
- y = node .inputs [0 ]
360
- outer_op = node .op
361
-
362
- if y .owner and (
363
- isinstance (y .owner .op , Blockwise )
364
- and isinstance (y .owner .op .core_op , BlockDiagonal )
365
- or isinstance (y .owner .op , KroneckerProduct )
366
- ):
367
- input_matrices = y .owner .inputs
368
-
369
- if isinstance (outer_op .core_op , MatrixInverse ):
370
- outer_f = cast (Callable , inv )
371
- elif isinstance (outer_op .core_op , Cholesky ):
372
- outer_f = cast (Callable , cholesky )
373
- elif isinstance (outer_op .core_op , MatrixPinv ):
374
- outer_f = cast (Callable , pinv )
375
- else :
376
- raise NotImplementedError # pragma: no cover
358
+ if not isinstance (node .op .core_op , MatrixInverse | Cholesky | MatrixPinv ):
359
+ return None
377
360
378
- inner_matrices = [cast (TensorVariable , outer_f (m )) for m in input_matrices ]
361
+ y = node .inputs [0 ]
362
+ outer_op = node .op
379
363
380
- if isinstance (y .owner .op , KroneckerProduct ):
381
- return [kron (* inner_matrices )]
382
- elif isinstance (y .owner .op .core_op , BlockDiagonal ):
383
- return [block_diag (* inner_matrices )]
384
- else :
385
- raise NotImplementedError # pragma: no cover
364
+ if y .owner and (
365
+ isinstance (y .owner .op , Blockwise )
366
+ and isinstance (y .owner .op .core_op , BlockDiagonal )
367
+ or isinstance (y .owner .op , KroneckerProduct )
368
+ ):
369
+ input_matrices = y .owner .inputs
370
+
371
+ if isinstance (outer_op .core_op , MatrixInverse ):
372
+ outer_f = cast (Callable , inv )
373
+ elif isinstance (outer_op .core_op , Cholesky ):
374
+ outer_f = cast (Callable , cholesky )
375
+ elif isinstance (outer_op .core_op , MatrixPinv ):
376
+ outer_f = cast (Callable , pinv )
377
+ else :
378
+ raise NotImplementedError # pragma: no cover
379
+
380
+ inner_matrices = [cast (TensorVariable , outer_f (m )) for m in input_matrices ]
381
+
382
+ if isinstance (y .owner .op , KroneckerProduct ):
383
+ return [kron (* inner_matrices )]
384
+ elif isinstance (y .owner .op .core_op , BlockDiagonal ):
385
+ return [block_diag (* inner_matrices )]
386
+ else :
387
+ raise NotImplementedError # pragma: no cover
388
+ return None
386
389
387
390
388
391
def _find_diag_from_eye_mul (potential_mul_input ):
0 commit comments