
Refactor infer_shape method of Ops to find output shapes using gufunc_signature #1294


Open
wants to merge 8 commits into base: main
10 changes: 10 additions & 0 deletions pytensor/graph/op.py
@@ -596,6 +596,16 @@ def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
        # By default, do nothing
        return self

    def infer_shape(self, fgraph, node, input_shapes):
        if hasattr(self, "gufunc_signature"):
            from pytensor.tensor.utils import _gufunc_to_out_shape

            return _gufunc_to_out_shape(self.gufunc_signature, input_shapes)
        else:
            from pytensor.tensor.exceptions import ShapeError

            raise ShapeError(f"Op {self} does not implement infer_shape")

    def __str__(self):
        return getattr(type(self), "__name__", super().__str__())

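For illustration, a minimal sketch of how an Op is expected to pick up this default (the MatMulLike Op below is hypothetical and assumes this branch is installed):

# Hypothetical Op relying on the new default infer_shape; not part of the diff.
from pytensor.graph.op import Op


class MatMulLike(Op):
    gufunc_signature = "(m,n),(n,p)->(m,p)"

    def perform(self, node, inputs, output_storage):  # dummy, never called here
        raise NotImplementedError


# fgraph and node are not used by the default implementation, so None is fine here.
print(MatMulLike().infer_shape(None, None, [(2, 3), (3, 5)]))  # [(2, 5)]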
16 changes: 0 additions & 16 deletions pytensor/tensor/nlinalg.py
@@ -62,9 +62,6 @@ def L_op(self, inputs, outputs, g_outputs):
        ).T
        return [grad]

    def infer_shape(self, fgraph, node, shapes):
        return [list(reversed(shapes[0]))]


def pinv(x, hermitian=False):
"""Computes the pseudo-inverse of a matrix :math:`A`.
@@ -155,9 +152,6 @@ def R_op(self, inputs, eval_points):
            return [None]
        return [-matrix_dot(xi, ev, xi)]

    def infer_shape(self, fgraph, node, shapes):
        return shapes


inv = matrix_inverse = Blockwise(MatrixInverse())

@@ -224,9 +218,6 @@ def grad(self, inputs, g_outputs):
        (x,) = inputs
        return [gz * self(x) * matrix_inverse(x).T]

    def infer_shape(self, fgraph, node, shapes):
        return [()]

    def __str__(self):
        return "Det"

@@ -258,9 +249,6 @@ def perform(self, node, inputs, outputs):
        except Exception as e:
            raise ValueError("Failed to compute determinant", x) from e

    def infer_shape(self, fgraph, node, shapes):
        return [(), ()]

    def __str__(self):
        return "SLogDet"

@@ -316,10 +304,6 @@ def perform(self, node, inputs, outputs):
        (w, v) = outputs
        w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))

    def infer_shape(self, fgraph, node, shapes):
        n = shapes[0][0]
        return [(n,), (n, n)]


eig = Blockwise(Eig())

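As a sanity check, a sketch assuming Eig declares the signature "(m,m)->(m),(m,m)" (the same one used in the review thread below) and that this branch is installed; the new helper reproduces the shapes the deleted Eig.infer_shape used to return:

# Sketch only: the signature-based helper yields the same [(n,), (n, n)] as before.
from pytensor.tensor.utils import _gufunc_to_out_shape

print(_gufunc_to_out_shape("(m,m)->(m),(m,m)", [(4, 4)]))  # [(4,), (4, 4)]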
21 changes: 0 additions & 21 deletions pytensor/tensor/slinalg.py
@@ -50,9 +50,6 @@ def __init__(
        if self.overwrite_a:
            self.destroy_map = {0: [0]}

    def infer_shape(self, fgraph, node, shapes):
        return [shapes[0]]

    def make_node(self, x):
        x = as_tensor_variable(x)
        if x.type.ndim != 2:
@@ -268,15 +265,6 @@ def make_node(self, A, b):
        x = tensor(dtype=o_dtype, shape=b.type.shape)
        return Apply(self, [A, b], [x])

    def infer_shape(self, fgraph, node, shapes):
        Ashape, Bshape = shapes
        rows = Ashape[1]
        if len(Bshape) == 1:
            return [(rows,)]
        else:
            cols = Bshape[1]
            return [(rows, cols)]

    def L_op(self, inputs, outputs, output_gradients):
        r"""Reverse-mode gradient updates for matrix solve operation :math:`c = A^{-1} b`.

@@ -890,9 +878,6 @@ def perform(self, node, inputs, output_storage):
        out_dtype = node.outputs[0].type.dtype
        X[0] = scipy_linalg.solve_continuous_lyapunov(A, B).astype(out_dtype)

    def infer_shape(self, fgraph, node, shapes):
        return [shapes[0]]

    def grad(self, inputs, output_grads):
        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
        # Note that they write the equation as AX + XA.H + Q = 0, while scipy uses AX + XA^H = Q,
@@ -962,9 +947,6 @@ def perform(self, node, inputs, output_storage):
            out_dtype
        )

    def infer_shape(self, fgraph, node, shapes):
        return [shapes[0]]

    def grad(self, inputs, output_grads):
        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
        A, Q = inputs
@@ -1082,9 +1064,6 @@ def perform(self, node, inputs, output_storage):
        out_dtype = node.outputs[0].type.dtype
        X[0] = scipy_linalg.solve_discrete_are(A, B, Q, R).astype(out_dtype)

    def infer_shape(self, fgraph, node, shapes):
        return [shapes[0]]

    def grad(self, inputs, output_grads):
        # Gradient computations come from Kao and Hennequin (2020), https://arxiv.org/pdf/2011.11430.pdf
        A, B, Q, R = inputs
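The branchy Solve.infer_shape removed above is likewise covered by the signatures the solve Ops already declare (assuming "(m,m),(m)->(m)" when b is a vector and "(m,m),(m,n)->(m,n)" when b is a matrix); a quick sketch with the new helper, assuming this branch is installed:

# Sketch only: both the vector-b and matrix-b cases fall out of the signature.
from pytensor.tensor.utils import _gufunc_to_out_shape

print(_gufunc_to_out_shape("(m,m),(m)->(m)", [(4, 4), (4,)]))        # [(4,)]
print(_gufunc_to_out_shape("(m,m),(m,n)->(m,n)", [(4, 4), (4, 2)]))  # [(4, 2)]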
53 changes: 53 additions & 0 deletions pytensor/tensor/utils.py
@@ -7,6 +7,7 @@
import pytensor
from pytensor.graph import FunctionGraph, Variable
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor import Any, Constant
from pytensor.utils import hash_from_code


@@ -202,6 +203,58 @@ def _parse_gufunc_signature(
)


def _gufunc_to_out_shape(
Member:
This should try to prioritize input dimensions that are constant when there are multiple ones with the same letter, as it will generate a better shape graph.

Also fail explicitly if the shape can't be inferred from the signature alone
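
A small sketch of the scenario being described (the "(m),(m)->(m)" core signature is hypothetical, used only to illustrate why constant entries give a simpler shape graph):

import pytensor.tensor as pt

x = pt.vector("x")              # length only known at runtime
y = pt.vector("y", shape=(3,))  # length statically known

print(x.type.shape, y.type.shape)  # (None,) (3,)

# Both inputs bind the letter "m".  Preferring the entry that comes from `y`
# lets the inferred output shape be the constant 3 rather than a symbolic
# Shape_i(x) node, i.e. a simpler shape graph.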

Aarsh-Wankar (Contributor, author), Mar 19, 2025:
Should I raise a ValueError when a single dimension in the gufunc_signature (say m) is assigned two values in the input shapes? I am raising a ValueError when the shape cannot be inferred from the signature. Here is my new function:

def _gufunc_to_out_shape(
    signature: str, shapes: list[tuple[int, ...]]
) -> list[tuple[int, ...]]:
    """
    Compute the shape of the output of an Op given its gufunc signature and the
    shapes of its inputs.

    Parameters
    ----------
    signature : str
        The gufunc signature of the Op.
        eg: "(m,n),(n,p)->(m,p)".

    shapes : list of tuple of int
        The list of shapes of the inputs.

    Returns
    -------
    out_shape : list of tuple of int
        The list of shapes of the outputs.

    Raises
    ------
    ValueError
        If the signature is invalid for the shapes of the inputs.
    """
    parsed = _parse_gufunc_signature(signature)
    out_shape = []
    dim_to_size = dict()
    for i, input_shape in enumerate(parsed[0]):
        for j, dim in enumerate(input_shape):
            if dim not in dim_to_size:
                dim_to_size[dim] = shapes[i][j]
            elif dim_to_size[dim] != shapes[i][j]:
                raise ValueError(
                    f"Invalid signature {signature} for shapes {shapes}. "
                    f"Dimension {dim} is not consistent across inputs."
                )

    for i, output_shape in enumerate(parsed[1]):
        temp_list = []
        for j, dim in enumerate(output_shape):
            if dim not in dim_to_size:
                raise ValueError(
                    f"Invalid signature {signature} for shapes {shapes}. "
                    f"Dimension {dim} not in input dimensions."
                )
            else:
                temp_list.append(dim_to_size[dim])
        out_shape.append(tuple(temp_list))
    return out_shape

Member:
Shapes are symbolic, so you can't always compare them to know whether they match; you can only do that for constants.

Aarsh-Wankar (Contributor, author):
Okay, so if the shapes list can contain symbolic values, can we first convert each symbolic value to a tensor variable using as_tensor_variable and then use the .equals method to compare them? Specifically, in the function above, we make this change:

for i, input_shape in enumerate(parsed[0]):
    for j, dim in enumerate(input_shape):
        current_dim = as_tensor_variable(shapes[i][j])
        if dim not in dim_to_size:
            dim_to_size[dim] = current_dim
        elif not dim_to_size[dim].equals(current_dim):
            raise ValueError(
                f"Invalid signature {signature} for shapes {shapes}. "
                f"Dimension {dim} is not consistent across inputs."
            )

This is what the output then looks like (the shapes input contains integers here):

print(_gufunc_to_out_shape("(m,m)->(m),(m,m)", [(2, 2)]))
[(TensorConstant(TensorType(int8, shape=()), data=array(2, dtype=int8)),), (TensorConstant(TensorType(int8, shape=()), data=array(2, dtype=int8)), TensorConstant(TensorType(int8, shape=()), data=array(2, dtype=int8)))]

Can there be a better way to do this?

ricardoV94 (Member), Mar 20, 2025:
No, if they are symbolic then, unless they are constants, you can't know whether they are equivalent, so raising an error is incorrect. The user may have defined x = pt.vector("x"); y = pt.vector("y"); out = x + y. You will not know whether x and y have identical shapes until the user compiles a function and provides values for them.

The logic for giving priority to constants is already in the Blockwise infer_shape; you should be able to just grab it and refactor it. We just want to simplify the graph returned when there are two inputs with the same letter and one of them has a constant shape. In that case we pick the constant one.

Checking whether two inputs agree is not the critical thing here, although we can do that. You can only do it if isinstance(shapes[0][j], Constant) and isinstance(shapes[1][j], Constant), in which case you can then check shapes[0][j].data == shapes[1][j].data if you want to raise an informative error when they are inconsistent.

Also, there is no reason to convert with as_tensor_variable; they should be ScalarVariables IIRC. Could be wrong here.
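
A sketch of the check described here (the helper name and variables are made up):

from pytensor.graph.basic import Constant


def known_inconsistent(size_a, size_b) -> bool:
    # Only when both shape entries are constants can we decide they disagree.
    return (
        isinstance(size_a, Constant)
        and isinstance(size_b, Constant)
        and size_a.data != size_b.data
    )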

Aarsh-Wankar (Contributor, author):
I took the code from the infer_shape of the Blockwise class and used it to write the function. It gives priority to Constants. Is the output type list[tuple[Any, ...]] fine?

def _gufunc_to_out_shape(
    signature: str, shapes
) -> list[tuple[Any, ...]]:
    """
    Compute the shape of the output of an Op given its gufunc signature and the
    shapes of its inputs.

    Parameters
    ----------
    signature : str
        The gufunc signature of the Op.
        eg: "(m,n),(n,p)->(m,p)".

    shapes : list of tuple of Any
        The list of shapes of the inputs.

    Returns
    -------
    out_shape : list of tuple of Any
        The list of shapes of the outputs.

    Raises
    ------
    ValueError
        If the signature is invalid for the shapes of the inputs.
    """
    input_sig, output_sig = _parse_gufunc_signature(signature)
    dim_to_size : dict[str, Any] = {}
    for input_shape, sig in zip(shapes, input_sig, strict = True):
        for size, dim_name in zip(input_shape, sig, strict=True):
            prev_size = dim_to_size.get(dim_name)
            if prev_size is None:
                dim_to_size[dim_name] = size
            # Prefer constants
            elif not isinstance(prev_size, Constant):
                dim_to_size[dim_name] = size
            elif prev_size.data != size:
                raise ValueError(
                    f"Invalid signature {signature} for shapes {shapes}. "
                    f"Dimension {dim_name} is not consistent across inputs."
                )
    out_shapes = []
    for output_shape in output_sig:
        temp_list = []
        for dim in output_shape:
            if dim not in dim_to_size:
                raise ValueError(
                    f"Invalid signature {signature} for shapes {shapes}. "
                    f"Dimension {dim} not in input dimensions."
                )
            else:
                temp_list.append(dim_to_size[dim])
        out_shapes.append((*temp_list,))       
    return out_shapes
    

    signature: str, shapes: list[tuple[Any, ...]]
) -> list[tuple[Any, ...]]:
    """
    Compute the shape of the output of an Op given its gufunc signature and the
    shapes of its inputs.

    Parameters
    ----------
    signature : str
        The gufunc signature of the Op.
        eg: "(m,n),(n,p)->(m,p)".

    shapes : list of tuple of Any
        The list of shapes of the inputs.

    Returns
    -------
    out_shape : list of tuple of Any
        The list of shapes of the outputs.

    Raises
    ------
    ValueError
        If the signature is invalid for the shapes of the inputs.
    """
    input_sig, output_sig = _parse_gufunc_signature(signature)
    dim_to_size: dict[str, Any] = {}
    for input_shape, sig in zip(shapes, input_sig, strict=True):
        for size, dim_name in zip(input_shape, sig, strict=True):
            prev_size = dim_to_size.get(dim_name)
            if prev_size is None:
                dim_to_size[dim_name] = size
            # Prefer constants
            elif not isinstance(prev_size, Constant):
                dim_to_size[dim_name] = size

    out_shapes = []
    for output_shape in output_sig:
        temp_list = []
        for dim in output_shape:
            if dim not in dim_to_size:
                raise ValueError(
                    f"Invalid signature {signature} for shapes {shapes}. "
                    f"Dimension {dim} not in input dimensions."
                )
            else:
                temp_list.append(dim_to_size[dim])
        out_shapes.append((*temp_list,))
    return out_shapes
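
A quick usage sketch against the function above (assumes this branch is installed; the "(m,n)->(k)" signature is made up just to show the failure mode):

from pytensor.tensor.utils import _gufunc_to_out_shape

print(_gufunc_to_out_shape("(m,n),(n,p)->(m,p)", [(2, 3), (3, 5)]))  # [(2, 5)]

# An output dimension that never appears in any input cannot be inferred:
try:
    _gufunc_to_out_shape("(m,n)->(k)", [(2, 3)])
except ValueError as err:
    print(err)  # ... Dimension k not in input dimensions.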


def safe_signature(
    core_inputs_ndim: Sequence[int],
    core_outputs_ndim: Sequence[int],