Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eig Rewrites #1126

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@
MatrixPinv,
SLogDet,
det,
eig,
inv,
kron,
pinv,
@@ -1013,3 +1014,104 @@
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
}
return replacements


@register_canonicalize
@register_stabilize
@node_rewriter([eig])
def rewrite_eig_eye(fgraph, node):
    """
    Rewrite ``eig(eye(n))`` into constant outputs.

    For an identity matrix all eigenvalues are 1 and the eigenvectors are the
    standard basis, so the expensive ``eig`` decomposition can be replaced by
    ``ones`` and ``eye`` directly.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # Only rewrite when the input to Eig is an Eye op.
    potential_eye = node.inputs[0]
    if not (potential_eye.owner and isinstance(potential_eye.owner.op, Eye)):
        return None

    # The diagonal offset (k) is the last input to Eye. It is only known when
    # constant; a symbolic offset has no `.data`, in which case we must bail
    # out instead of crashing (the previous `getattr(..., -1).item()` raised
    # AttributeError because a plain int has no `.item()`).
    offset = getattr(potential_eye.owner.inputs[-1], "data", None)
    if offset is None or offset.item() != 0:
        return None

    eigval_rewritten = pt.ones(potential_eye.shape[-1])
    eigvec_rewritten = pt.eye(potential_eye.shape[-1])

    return [eigval_rewritten, eigvec_rewritten]


@register_canonicalize
@register_stabilize
@node_rewriter([eig])
def rewrite_eig_diag(fgraph, node):
    """
    Rewrite ``eig`` of a diagonal matrix into its diagonal and the identity.

    For a diagonal matrix the eigenvalues are simply the diagonal elements and
    the eigenvectors are the standard basis.

    The presence of a diagonal matrix is detected by inspecting the graph: it
    either comes from ``pt.diag`` (``AllocDiag``) or arises as the result of
    elementwise multiplication with an identity matrix. Specialized
    computation is used depending on whether that multiplication was with a
    scalar, vector or matrix.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    inputs = node.inputs[0]

    # Case 1: the matrix was built with pt.diag / AllocDiag with offset 0.
    if (
        inputs.owner
        and isinstance(inputs.owner.op, AllocDiag)
        and AllocDiag.is_offset_zero(inputs.owner)
    ):
        # Reuse the original diagonal vector directly instead of re-extracting
        # it from the constructed matrix with pt.diag(inputs).
        eigval_rewritten = inputs.owner.inputs[0]
        eigvec_rewritten = pt.eye(inputs.shape[-1])
        return [eigval_rewritten, eigvec_rewritten]

    # Case 2: elemwise multiplication with an identity matrix also yields a
    # diagonal matrix.
    inputs_or_none = _find_diag_from_eye_mul(inputs)
    if inputs_or_none is None:
        return None

    eye_input, non_eye_inputs = inputs_or_none

    # Only handle a single non-eye factor.
    if len(non_eye_inputs) != 1:
        return None

    non_eye_input = non_eye_inputs[0]

    # The eigenvectors of a diagonal matrix are the standard basis, i.e. the
    # identity matrix already present in the graph.
    eigvec_rewritten = eye_input

    # Specialize on whether the non-eye factor was scalar/vector/matrix,
    # judged by its broadcastable pattern on the last two axes.
    if non_eye_input.type.broadcastable[-2:] == (True, True):
        # Scalar (broadcast over both axes): repeat it along the diagonal.
        eigval_rewritten = pt.full(
            (eye_input.shape[0],), non_eye_input.squeeze(axis=(-1, -2))
        )
    elif non_eye_input.type.broadcastable[-2:] == (False, False):
        # Matrix: only its diagonal survives multiplication with eye.
        eigval_rewritten = pt.diag(non_eye_input)
    else:
        # Row/column vector: drop the broadcasted axis.
        eigval_rewritten = non_eye_input.squeeze()

    return [eigval_rewritten, eigvec_rewritten]
105 changes: 105 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
from pytensor.tensor.nlinalg import (
SVD,
Det,
Eig,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
@@ -996,3 +997,107 @@ def test_slogdet_specialization():
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, SLogDet) for node in nodes)


@pytest.mark.parametrize(
    "shape",
    [(), (7,), (1, 7), (7, 1), (7, 7)],
    ids=["scalar", "vector", "row_vec", "col_vec", "matrix"],
)
def test_eig_diag_from_eye_mul(shape):
    """eig(eye(7) * x) is rewritten away for scalar/vector/matrix x."""
    # Initializing x based on scalar/vector/matrix
    x = pt.tensor("x", shape=shape)
    y = pt.eye(7) * x

    # Calculating eigval and eigvec using pt.linalg.eig
    eigval, eigvec = pt.linalg.eig(y)

    # REWRITE TEST: no Eig op (plain or wrapped as a core_op) may remain.
    f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

    assert not any(
        isinstance(node.op, Eig) or isinstance(getattr(node.op, "core_op", None), Eig)
        for node in nodes
    )

    # NUMERIC VALUE TEST
    # np.random.rand() with no args returns a plain float (no .astype), so the
    # scalar case needs wrapping; the vector and matrix cases were previously
    # two identical branches, collapsed here into one.
    if len(shape) == 0:
        x_test = np.array(np.random.rand()).astype(config.floatX)
    else:
        x_test = np.random.rand(*shape).astype(config.floatX)

    x_test_matrix = np.eye(7) * x_test
    eigval, eigvec = np.linalg.eig(x_test_matrix)
    rewritten_eigval, rewritten_eigvec = f_rewritten(x_test)

    assert_allclose(
        eigval,
        rewritten_eigval,
        atol=1e-3 if config.floatX == "float32" else 1e-8,
        rtol=1e-3 if config.floatX == "float32" else 1e-8,
    )
    assert_allclose(
        eigvec,
        rewritten_eigvec,
        atol=1e-3 if config.floatX == "float32" else 1e-8,
        rtol=1e-3 if config.floatX == "float32" else 1e-8,
    )


def test_eig_eye():
    """eig(eye(10)) is rewritten to constants (ones, eye)."""
    eigval, eigvec = pt.linalg.eig(pt.eye(10))

    # REWRITE TEST: the compiled graph must contain no Eig node.
    f_rewritten = function([], [eigval, eigvec], mode="FAST_RUN")
    assert not any(
        isinstance(node.op, Eig) for node in f_rewritten.maker.fgraph.apply_nodes
    )

    # NUMERIC VALUE TEST: compare against numpy's decomposition of eye(10).
    expected_val, expected_vec = np.linalg.eig(np.eye(10))
    actual_val, actual_vec = f_rewritten()
    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(expected_val, actual_val, atol=tol, rtol=tol)
    assert_allclose(expected_vec, actual_vec, atol=tol, rtol=tol)


def test_eig_diag():
    """eig(diag(x)) is rewritten to (x, eye)."""
    x = pt.tensor("x", shape=(None,))
    eigval, eigvec = pt.linalg.eig(pt.diag(x))

    # REWRITE TEST: the compiled graph must contain no Eig node.
    f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN")
    assert not any(
        isinstance(node.op, Eig) for node in f_rewritten.maker.fgraph.apply_nodes
    )

    # NUMERIC VALUE TEST: compare against numpy on a random diagonal.
    x_test = np.random.rand(7).astype(config.floatX)
    expected_val, expected_vec = np.linalg.eig(np.eye(7) * x_test)
    actual_val, actual_vec = f_rewritten(x_test)
    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(expected_val, actual_val, atol=tol, rtol=tol)
    assert_allclose(expected_vec, actual_vec, atol=tol, rtol=tol)