You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements.
656
+
This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix
657
+
658
+
Parameters
659
+
----------
660
+
fgraph: FunctionGraph
661
+
Function graph being optimized
662
+
node: Apply
663
+
Node of the function graph to be optimized
664
+
665
+
Returns
666
+
-------
667
+
list of Variable, optional
668
+
List of optimized variables, or None if no optimization was performed
669
+
"""
670
+
core_op=node.op.core_op
671
+
ifnot (isinstance(core_op, ALL_INVERSE_OPS)):
672
+
returnNone
673
+
674
+
inputs=node.inputs[0]
675
+
# Check for use of pt.diag first
676
+
if (
677
+
inputs.owner
678
+
andisinstance(inputs.owner.op, AllocDiag)
679
+
andAllocDiag.is_offset_zero(inputs.owner)
680
+
):
681
+
inv_input=inputs.owner.inputs[0]
682
+
inv_val=pt.diag(1/inv_input)
683
+
return [inv_val]
684
+
685
+
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
686
+
inputs_or_none=_find_diag_from_eye_mul(inputs)
687
+
ifinputs_or_noneisNone:
688
+
returnNone
689
+
690
+
eye_input, non_eye_inputs=inputs_or_none
691
+
692
+
# Dealing with only one other input
693
+
iflen(non_eye_inputs) !=1:
694
+
returnNone
695
+
696
+
non_eye_input=non_eye_inputs[0]
697
+
698
+
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
0 commit comments