|
5 | 5 | from pytensor import Variable
|
6 | 6 | from pytensor.graph import Apply, FunctionGraph
|
7 | 7 | from pytensor.graph.rewriting.basic import (
|
| 8 | + PatternNodeRewriter, |
8 | 9 | copy_stack_trace,
|
9 | 10 | node_rewriter,
|
10 | 11 | )
|
11 |
| -from pytensor.tensor.basic import TensorVariable, diagonal |
| 12 | +from pytensor.scalar.basic import Mul |
| 13 | +from pytensor.tensor.basic import ARange, Eye, TensorVariable, alloc, diagonal |
12 | 14 | from pytensor.tensor.blas import Dot22
|
13 | 15 | from pytensor.tensor.blockwise import Blockwise
|
14 |
| -from pytensor.tensor.elemwise import DimShuffle |
| 16 | +from pytensor.tensor.elemwise import DimShuffle, Elemwise |
15 | 17 | from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod
|
16 | 18 | from pytensor.tensor.nlinalg import (
|
17 | 19 | SVD,
|
|
39 | 41 | solve,
|
40 | 42 | solve_triangular,
|
41 | 43 | )
|
| 44 | +from pytensor.tensor.subtensor import advanced_set_subtensor |
42 | 45 |
|
43 | 46 |
|
44 | 47 | logger = logging.getLogger(__name__)
|
@@ -384,6 +387,104 @@ def local_lift_through_linalg(
|
384 | 387 | raise NotImplementedError # pragma: no cover
|
385 | 388 |
|
386 | 389 |
|
| 390 | +def _find_diag_from_eye_mul(potential_mul_input): |
| 391 | + # Check if the op is Elemwise and mul |
| 392 | + if not ( |
| 393 | + potential_mul_input.owner is not None |
| 394 | + and isinstance(potential_mul_input.owner.op, Elemwise) |
| 395 | + and isinstance(potential_mul_input.owner.op.scalar_op, Mul) |
| 396 | + ): |
| 397 | + return None |
| 398 | + |
| 399 | + # Find whether any of the inputs to mul is Eye |
| 400 | + inputs_to_mul = potential_mul_input.owner.inputs |
| 401 | + eye_input = [ |
| 402 | + mul_input |
| 403 | + for mul_input in inputs_to_mul |
| 404 | + if mul_input.owner and isinstance(mul_input.owner.op, Eye) |
| 405 | + ] |
| 406 | + |
| 407 | + # Check if 1's are being put on the main diagonal only (k = 0) |
| 408 | + if eye_input and getattr(eye_input[0].owner.inputs[-1], "data", -1).item() != 0: |
| 409 | + return None |
| 410 | + |
| 411 | + # If the broadcast pattern of eye_input is not (False, False), we do not get a diagonal matrix and thus, dont need to apply the rewrite |
| 412 | + if eye_input and eye_input[0].broadcastable[-2:] != (False, False): |
| 413 | + return None |
| 414 | + |
| 415 | + # Get all non Eye inputs (scalars/matrices/vectors) |
| 416 | + non_eye_inputs = list(set(inputs_to_mul) - set(eye_input)) |
| 417 | + return eye_input, non_eye_inputs |
| 418 | + |
| 419 | + |
| 420 | +@register_canonicalize("shape_unsafe") |
| 421 | +@register_stabilize("shape_unsafe") |
| 422 | +@node_rewriter([det]) |
| 423 | +def rewrite_det_diag_from_eye_mul(fgraph, node): |
| 424 | + """ |
| 425 | + This rewrite takes advantage of the fact that for a diagonal matrix, the determinant value is the product of its diagonal elements. |
| 426 | +
|
| 427 | + The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to make this rewrite as efficient as possible, depending on whether the multiplication was with a scalar, vector or a matrix. |
| 428 | +
|
| 429 | + Parameters |
| 430 | + ---------- |
| 431 | + fgraph: FunctionGraph |
| 432 | + Function graph being optimized |
| 433 | + node: Apply |
| 434 | + Node of the function graph to be optimized |
| 435 | +
|
| 436 | + Returns |
| 437 | + ------- |
| 438 | + list of Variable, optional |
| 439 | + List of optimized variables, or None if no optimization was performed |
| 440 | + """ |
| 441 | + potential_mul_input = node.inputs[0] |
| 442 | + eye_non_eye_inputs = _find_diag_from_eye_mul(potential_mul_input) |
| 443 | + if eye_non_eye_inputs is None: |
| 444 | + return None |
| 445 | + eye_input, non_eye_inputs = eye_non_eye_inputs |
| 446 | + |
| 447 | + # Dealing with only one other input |
| 448 | + if len(non_eye_inputs) != 1: |
| 449 | + return None |
| 450 | + |
| 451 | + useful_eye, useful_non_eye = eye_input[0], non_eye_inputs[0] |
| 452 | + |
| 453 | + # Checking if original x was scalar/vector/matrix |
| 454 | + if useful_non_eye.type.broadcastable[-2:] == (True, True): |
| 455 | + # For scalar |
| 456 | + det_val = useful_non_eye.squeeze(axis=(-1, -2)) ** (useful_eye.shape[0]) |
| 457 | + elif useful_non_eye.type.broadcastable[-2:] == (False, False): |
| 458 | + # For Matrix |
| 459 | + det_val = useful_non_eye.diagonal(axis1=-1, axis2=-2).prod(axis=-1) |
| 460 | + else: |
| 461 | + # For vector |
| 462 | + det_val = useful_non_eye.prod(axis=(-1, -2)) |
| 463 | + det_val = det_val.astype(node.outputs[0].type.dtype) |
| 464 | + return [det_val] |
| 465 | + |
| 466 | + |
| 467 | +arange = ARange("int64") |
| 468 | +det_diag_from_diag = PatternNodeRewriter( |
| 469 | + ( |
| 470 | + det, |
| 471 | + ( |
| 472 | + advanced_set_subtensor, |
| 473 | + (alloc, 0, "sh1", "sh2"), |
| 474 | + "x", |
| 475 | + (arange, 0, "stop", 1), |
| 476 | + (arange, 0, "stop", 1), |
| 477 | + ), |
| 478 | + ), |
| 479 | + (prod, "x"), |
| 480 | + name="det_diag_from_diag", |
| 481 | + allow_multiple_clients=True, |
| 482 | +) |
| 483 | +register_canonicalize(det_diag_from_diag) |
| 484 | +register_stabilize(det_diag_from_diag) |
| 485 | +register_specialize(det_diag_from_diag) |
| 486 | + |
| 487 | + |
387 | 488 | @register_canonicalize
|
388 | 489 | @register_stabilize
|
389 | 490 | @register_specialize
|
|
0 commit comments