Skip to content

Commit 94e9ef0

Browse files
Rewrite determinant of diagonal matrix as product of diagonal (#797)
* Added det-diag rewrite * fixed pt.diagonal error * Added test for rewrite * Added test for rewrite * fixed test * added check for verifying rewrite * fixed other failing test * added docstring * updated docstring * fixed mypy error * added det_diag_from_diag and test * fixed node rewriter name * added row/col tests * updated check for eye * updated rewrite and tests * added check for eye_input and new test for cases where not to apply rewrite * does not apply rewrite to specific cases * typecasted test variable * typecast variables * removed shape known check; fails for rectangle eye * added new tests for (1,1) eye and rectangle eye * added helper function for diag from eye_mul * updated case for no rewrite which was failing tests * cleaned code; updated rectangle_eye test which is an invalid rewrite * add check for k in pt.eye * Update pytensor/tensor/rewriting/linalg.py Co-authored-by: Ricardo Vieira <[email protected]> * typecasted det_val * fixed final typecasting * fixed merge * fixed failing rectangle eye test * fixed typo --------- Co-authored-by: Ricardo Vieira <[email protected]>
1 parent bf8a1b5 commit 94e9ef0

File tree

2 files changed

+192
-2
lines changed

2 files changed

+192
-2
lines changed

pytensor/tensor/rewriting/linalg.py

+103-2
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from pytensor import Variable
66
from pytensor.graph import Apply, FunctionGraph
77
from pytensor.graph.rewriting.basic import (
8+
PatternNodeRewriter,
89
copy_stack_trace,
910
node_rewriter,
1011
)
11-
from pytensor.tensor.basic import TensorVariable, diagonal
12+
from pytensor.scalar.basic import Mul
13+
from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal
1214
from pytensor.tensor.blas import Dot22
1315
from pytensor.tensor.blockwise import Blockwise
14-
from pytensor.tensor.elemwise import DimShuffle
16+
from pytensor.tensor.elemwise import DimShuffle, Elemwise
1517
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
1618
from pytensor.tensor.nlinalg import (
1719
SVD,
@@ -39,6 +41,7 @@
3941
solve,
4042
solve_triangular,
4143
)
44+
from pytensor.tensor.subtensor import advanced_set_subtensor
4245

4346

4447
logger = logging.getLogger(__name__)
@@ -384,6 +387,104 @@ def local_lift_through_linalg(
384387
raise NotImplementedError # pragma: no cover
385388

386389

390+
def _find_diag_from_eye_mul(potential_mul_input):
391+
# Check if the op is Elemwise and mul
392+
if not (
393+
potential_mul_input.owner is not None
394+
and isinstance(potential_mul_input.owner.op, Elemwise)
395+
and isinstance(potential_mul_input.owner.op.scalar_op, Mul)
396+
):
397+
return None
398+
399+
# Find whether any of the inputs to mul is Eye
400+
inputs_to_mul = potential_mul_input.owner.inputs
401+
eye_input = [
402+
mul_input
403+
for mul_input in inputs_to_mul
404+
if mul_input.owner and isinstance(mul_input.owner.op, Eye)
405+
]
406+
407+
# Check if 1's are being put on the main diagonal only (k = 0)
408+
if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0:
409+
return None
410+
411+
# If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite
412+
if eye_input and eye_input[0].broadcastable[-2:] != (False, False):
413+
return None
414+
415+
# Get all non Eye inputs (scalars/matrices/vectors)
416+
non_eye_inputs = list(set(inputs_to_mul) - set(eye_input))
417+
return eye_input, non_eye_inputs
418+
419+
420+
@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([det])
def rewrite_det_diag_from_eye_mul(fgraph, node):
    """
    This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements.

    The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    eye_non_eye_inputs = _find_diag_from_eye_mul(node.inputs[0])
    if eye_non_eye_inputs is None:
        return None
    eye_input, non_eye_inputs = eye_non_eye_inputs

    # The rewrite needs an identity factor and exactly one other factor.
    # Guarding against an empty Eye list avoids an IndexError below when the
    # multiplication does not involve an Eye at all.
    if not eye_input or len(non_eye_inputs) != 1:
        return None

    useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0]

    # The broadcast pattern of the last two dims tells whether the other
    # factor was a scalar, a matrix or a (row/column) vector
    core_broadcastable = useful_non_eye.type.broadcastable[-2:]
    if core_broadcastable == (True, True):
        # Scalar: every diagonal entry equals x, so det = x ** n
        det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0])
    elif core_broadcastable == (False, False):
        # Matrix: only the diagonal of x survives the multiplication
        det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1)
    else:
        # Row or column vector: one of the two core dims has length 1, so the
        # product over both dims is the product of the vector entries
        det_val = useful_non_eye.prod(axis=(-1, -2))

    # Keep the dtype of the original determinant output
    det_val = det_val.astype(node.outputs[0].type.dtype)
    return [det_val]
465+
466+
467+
# int64 ``ARange`` Op instance used to describe the index inputs in the pattern below
arange = ARange("int64")

# Sub-patterns: a zero-filled allocation of shape (sh1, sh2), and the index
# vector ``arange(0, stop, 1)`` that is used for both the row and the column index.
_zeros_alloc_pattern = (alloc, 0, "sh1", "sh2")
_diag_index_pattern = (arange, 0, "stop", 1)

# Replace ``det`` of a zero matrix whose [arange, arange] entries were set to
# ``x`` by ``prod(x)``.
# NOTE(review): presumably this is the graph produced by ``pt.diag(x)`` for a
# vector ``x`` -- confirm against the ``diag`` implementation.
det_diag_from_diag = PatternNodeRewriter(
    (
        det,
        (
            advanced_set_subtensor,
            _zeros_alloc_pattern,
            "x",
            _diag_index_pattern,
            _diag_index_pattern,
        ),
    ),
    (prod, "x"),
    name="det_diag_from_diag",
    allow_multiple_clients=True,
)

# Register the rewrite in the same order as before: canonicalize, stabilize, specialize
for _register in (register_canonicalize, register_stabilize, register_specialize):
    _register(det_diag_from_diag)
486+
487+
387488
@register_canonicalize
388489
@register_stabilize
389490
@register_specialize

tests/tensor/rewriting/test_linalg.py

+89
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,95 @@ def test_local_lift_through_linalg(constructor, f_op, f, g_op, g):
394394
np.testing.assert_allclose(f1(*test_vals), f2(*test_vals), atol=1e-8)
395395

396396

397+
@pytest.mark.parametrize(
    "shape",
    [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
    ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
)
def test_det_diag_from_eye_mul(shape):
    """``det(eye(7) * x)`` is compiled without a ``Det`` node and matches NumPy."""
    # Symbolic input of the parametrized shape (scalar/vector/matrix/batched)
    x = pt.tensor("x", shape=shape)
    z_det = pt.linalg.det(pt.eye(7) * x)

    # REWRITE TEST: no Det Op may remain in the compiled graph
    f_rewritten = function([x], z_det, mode="FAST_RUN")
    compiled_nodes = f_rewritten.maker.fgraph.apply_nodes
    assert not any(isinstance(node.op, Det) for node in compiled_nodes)

    # NUMERIC VALUE TEST
    # ``np.random.rand()`` without dimensions returns a Python float, so wrap
    # the draw in ``np.asarray`` to be able to cast it for every shape
    x_test = np.asarray(np.random.rand(*shape)).astype(config.floatX)
    expected = np.linalg.det(np.eye(7) * x_test)
    actual = f_rewritten(x_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        expected,
        actual,
        atol=tol,
        rtol=tol,
    )
431+
432+
433+
def test_det_diag_from_diag():
    """``det(diag(x))`` is compiled without a ``Det`` node and matches NumPy."""
    x = pt.tensor("x", shape=(None,))
    y = pt.linalg.det(pt.diag(x))

    # REWRITE TEST: no Det Op may remain in the compiled graph
    f_rewritten = function([x], y, mode="FAST_RUN")
    compiled_nodes = f_rewritten.maker.fgraph.apply_nodes
    assert not any(isinstance(node.op, Det) for node in compiled_nodes)

    # NUMERIC VALUE TEST
    x_test = np.random.rand(7).astype(config.floatX)
    expected = np.linalg.det(np.eye(7) * x_test)
    actual = f_rewritten(x_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        expected,
        actual,
        atol=tol,
        rtol=tol,
    )
455+
456+
457+
def test_dont_apply_det_diag_rewrite_for_1_1():
    """A ``(1, 1)`` eye broadcasts against ``x``, so ``Det`` must stay in the graph."""
    x = pt.matrix("x")
    y = pt.linalg.det(pt.eye(1, 1) * x)
    f_rewritten = function([x], y, mode="FAST_RUN")

    # The rewrite must NOT have been applied: a Det Op is still present
    compiled_nodes = f_rewritten.maker.fgraph.apply_nodes
    assert any(isinstance(node.op, Det) for node in compiled_nodes)

    # Numeric Value test
    x_test = np.random.normal(size=(3, 3)).astype(config.floatX)
    expected = np.linalg.det(np.eye(1, 1) * x_test)
    actual = f_rewritten(x_test)

    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(
        expected,
        actual,
        atol=tol,
        rtol=tol,
    )
477+
478+
479+
def test_det_diag_incorrect_for_rectangle_eye():
    """Building ``det`` of a product with a rectangular eye raises ``ValueError``."""
    x = pt.matrix("x")
    rectangular = pt.eye(7, 5) * x
    # The error is expected when the det graph is built, not when it is compiled
    with pytest.raises(ValueError, match="Determinant not defined"):
        pt.linalg.det(rectangular)
484+
485+
397486
def test_svd_uv_merge():
398487
a = matrix("a")
399488
s_1 = svd(a, full_matrices=False, compute_uv=False)

0 commit comments

Comments
 (0)