
Refactor infer_shape method of Ops to find output shapes using gufunc_signature #1294

Open · wants to merge 8 commits into main
Conversation

@Aarsh-Wankar (Contributor) commented Mar 13, 2025:

Refactored the infer_shape method of Ops to compute output shapes from the gufunc_signature, via a newly defined function _gufunc_to_out_shape.

Description

Ops have an infer_shape method that computes the shapes of their outputs given the shapes of their inputs, and each Op provides its own implementation. However, many Ops also carry a gufunc_signature string that describes the relation between input and output shapes; in principle, this string alone is enough to derive the output shapes from the input shapes. This PR therefore adds a function _gufunc_to_out_shape, which computes the list of output shapes from the gufunc_signature and the list of input shapes, and replaces the Op-specific infer_shape implementations with calls to it.
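For illustration, here is a minimal self-contained sketch of the idea (not the PR's actual code, which reuses _parse_gufunc_signature from pytensor.tensor.utils): each dimension letter in the signature is bound to the corresponding input size, and the output shapes are then read off those bindings.

def infer_out_shapes(signature: str, in_shapes: list[tuple[int, ...]]):
    # Tolerate spacing such as "(m, n) -> (n, m)"
    signature = signature.replace(" ", "")

    def parse(side: str) -> list[tuple[str, ...]]:
        return [
            tuple(d for d in part.strip("()").split(",") if d)
            for part in side.split("),(")
        ]

    in_sig, out_sig = (parse(side) for side in signature.split("->"))
    dim_to_size: dict[str, int] = {}
    for names, shape in zip(in_sig, in_shapes, strict=True):
        for name, size in zip(names, shape, strict=True):
            dim_to_size.setdefault(name, size)
    return [tuple(dim_to_size[name] for name in names) for names in out_sig]

print(infer_out_shapes("(m,n),(n,p)->(m,p)", [(2, 3), (3, 4)]))  # [(2, 4)]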

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1294.org.readthedocs.build/en/1294/


codecov bot commented Mar 13, 2025:

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.99%. Comparing base (b75c18f) to head (c161452).


@@            Coverage Diff             @@
##             main    #1294      +/-   ##
==========================================
- Coverage   81.99%   81.99%   -0.01%     
==========================================
  Files         188      188              
  Lines       48506    48502       -4     
  Branches     8672     8671       -1     
==========================================
- Hits        39771    39767       -4     
  Misses       6583     6583              
  Partials     2152     2152              
Files with missing lines       Coverage Δ
pytensor/tensor/nlinalg.py     95.26% <100.00%> (ø)
pytensor/tensor/slinalg.py     93.33% <100.00%> (-0.06%) ⬇️

@Aarsh-Wankar (Contributor Author) commented:

@ricardoV94 Some Ops like BaseBlockDiagonal and SVD have gufunc_signatures that cannot be used to infer the output shapes. For example, in BaseBlockDiagonal the gufunc_signature is "(m0, n0), (m1, n1), (m2, n2)... -> (m, n)", where m is the sum of all m_i and n is the sum of all n_i, but those sums are not representable in the signature (e.g., block-diagonal inputs of shapes (2, 2) and (3, 3) produce a (5, 5) output). I therefore did not touch such Ops. How do you suggest I proceed?

@ricardoV94 (Member) commented:

That's fine; we know the gufunc signature won't cover all cases. The idea is to just cover the easy ones.

@@ -156,7 +157,7 @@ def R_op(self, inputs, eval_points):
         return [-matrix_dot(xi, ev, xi)]
 
     def infer_shape(self, fgraph, node, shapes):
-        return shapes
+        return _gufunc_to_out_shape(self.gufunc_signature, shapes)
@ricardoV94 (Member) commented:

The base class should do this if there's a signature, so we can remove the infer_shape methods of these Ops altogether.

@Aarsh-Wankar (Contributor Author) commented Mar 21, 2025:

Implemented this in my latest commit here. Now there is an infer_shape method in the base Op class, which works when a gufunc_signature is present. But is this the right way? Now all Ops have an infer_shape method by default.

@Aarsh-Wankar (Contributor Author) commented:

@ricardoV94 Adding an infer_shape method to the base class seems to change function graphs. In particular, the second assertion of a test here fails: it checks that the function graph of a function computing the output shape of a Blockwise Op (built from the MyTestOp core Op) contains no MyTestOp nodes. Ideally no MyTestOp nodes should be present, but when infer_shape is added to the base Op class, MyTestOp nodes appear in the function graph, and that assertion fails. How do you suggest I proceed?

Here is an example:

from tests.tensor.test_blockwise import Blockwise, MyTestOp, tensor
import pytensor as pt
inp = tensor(shape=(5, None, None))
test_op = MyTestOp()
op = Blockwise(test_op, signature="(m, n) -> (n, m)")
out = op(inp)
shape_fn = pt.function([inp], out.shape)
shape_fn.dprint()

With infer_shape in base Op:

MakeVector{dtype='int64'} [id A] 3
 ├─ 5 [id B]
 ├─ Shape_i{1} [id C] 2
 │  └─ Blockwise{MyTestOp, (m, n) -> (n, m)} [id D] 0
 │     └─ <Tensor3(float64, shape=(5, ?, ?))> [id E]
 └─ Shape_i{2} [id F] 1
    └─ Blockwise{MyTestOp, (m, n) -> (n, m)} [id D] 0
       └─ ···

Without infer_shape in base Op:

MakeVector{dtype='int64'} [id A] 2
 ├─ 5 [id B]
 ├─ Shape_i{2} [id C] 1
 │  └─ <Tensor3(float64, shape=(5, ?, ?))> [id D]
 └─ Shape_i{1} [id E] 0
    └─ <Tensor3(float64, shape=(5, ?, ?))> [id D]

For easy reference, my base infer_shape is as follows:

    def infer_shape(self, fgraph, node, input_shapes):
        if hasattr(self, "gufunc_signature"):
            from pytensor.tensor.utils import _gufunc_to_out_shape

            return _gufunc_to_out_shape(self.gufunc_signature, input_shapes)
        else:
            from pytensor.tensor.exceptions import ShapeError

            raise ShapeError(f"Op {self} does not implement infer_shape")

@ricardoV94 (Member) commented:

Somehow the Blockwise shape is not being inferred correctly. The change by itself looks fine, so we need to investigate further why the Blockwise does not disappear from the graph now. Does Blockwise.infer_shape even get triggered now?
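One hypothetical way to check that (not from the PR; Blockwise lives in pytensor.tensor.blockwise) is to temporarily instrument the method before compiling the shape function:

from pytensor.tensor.blockwise import Blockwise

_orig_infer_shape = Blockwise.infer_shape

def _traced_infer_shape(self, fgraph, node, input_shapes):
    # Report every call so we can see whether the rewrite reaches Blockwise
    print(f"Blockwise.infer_shape triggered for {node.op}")
    return _orig_infer_shape(self, fgraph, node, input_shapes)

Blockwise.infer_shape = _traced_infer_shape
# ...then rebuild shape_fn as above and watch for the print.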

@@ -202,6 +202,39 @@ def _parse_gufunc_signature(
     )
 
 
+def _gufunc_to_out_shape(
@ricardoV94 (Member) commented:

This should prioritize input dimensions that are constant when multiple inputs share the same letter, as that will generate a better shape graph.

Also fail explicitly if the shape can't be inferred from the signature alone.

@Aarsh-Wankar (Contributor Author) commented Mar 19, 2025:

Should I raise a ValueError when a single dimension in the gufunc_signature (say m) is assigned two different values by the input shapes? I am currently raising a ValueError when the shape cannot be inferred from the signature. Here is my new function:

def _gufunc_to_out_shape(
    signature: str, shapes: list[tuple[int, ...]]
) -> list[tuple[int, ...]]:
    """
    Compute the shape of the output of an Op given its gufunc signature and the
    shapes of its inputs.

    Parameters
    ----------
    signature : str
        The gufunc signature of the Op.
        eg: "(m,n),(n,p)->(m,p)".

    shapes : list of tuple of int
        The list of shapes of the inputs.

    Returns
    -------
    out_shape : list of tuple of int
        The list of shapes of the outputs.

    Raises
    ------
    ValueError
        If the signature is invalid for the shapes of the inputs.
    """
    parsed = _parse_gufunc_signature(signature)
    out_shape = []
    dim_to_size = dict()
    for i, input_shape in enumerate(parsed[0]):
        for j, dim in enumerate(input_shape):
            if dim not in dim_to_size:
                dim_to_size[dim] = shapes[i][j]
            elif dim_to_size[dim] != shapes[i][j]:
                raise ValueError(
                    f"Invalid signature {signature} for shapes {shapes}. "
                    f"Dimension {dim} is not consistent across inputs."
                )

    for i, output_shape in enumerate(parsed[1]):
        temp_list = []
        for j, dim in enumerate(output_shape):
            if dim not in dim_to_size:
                raise ValueError(
                    f"Invalid signature {signature} for shapes {shapes}. "
                    f"Dimension {dim} not in input dimensions."
                )
            else:
                temp_list.append(dim_to_size[dim])
        out_shape.append(tuple(temp_list))
    return out_shape
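For illustration (example calls not from the thread), with plain integer shapes this behaves as expected:

print(_gufunc_to_out_shape("(m,n),(n,p)->(m,p)", [(5, 3), (3, 2)]))
# [(5, 2)]
print(_gufunc_to_out_shape("(m,m)->(m),(m,m)", [(2, 2)]))
# [(2,), (2, 2)]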

@ricardoV94 (Member) commented:

Shapes are symbolic; you can't always compare them to know whether they match. You can only do that for constants.

@Aarsh-Wankar (Contributor Author) commented:

Okay, so if the shapes list can contain symbolic values, can we first convert each entry to a tensor variable using as_tensor_variable and then use the .equals method to compare them? Specifically, in the function above, we make this change:

# needs: from pytensor.tensor import as_tensor_variable
for i, input_shape in enumerate(parsed[0]):
    for j, dim in enumerate(input_shape):
        current_dim = as_tensor_variable(shapes[i][j])
        if dim not in dim_to_size:
            dim_to_size[dim] = current_dim
        elif not dim_to_size[dim].equals(current_dim):
            raise ValueError(
                f"Invalid signature {signature} for shapes {shapes}. "
                f"Dimension {dim} is not consistent across inputs."
            )

This is what the output looks like then (the shapes input contains plain integers here):

print(_gufunc_to_out_shape("(m,m)->(m),(m,m)", [(2, 2)]))
[(TensorConstant(TensorType(int8, shape=()), data=array(2, dtype=int8)),), (TensorConstant(TensorType(int8, shape=()), data=array(2, dtype=int8)), TensorConstant(TensorType(int8, shape=()), data=array(2, dtype=int8)))]

Is there a better way to do this?

@ricardoV94 (Member) commented Mar 20, 2025:

No. If they are symbolic, unless they are constants, you can't know whether they are equivalent, so raising an error is incorrect. A user may have defined x = pt.vector("x"); y = pt.vector("y"); out = x + y. You will not know whether x and y have identical shapes until the user compiles a function and provides values for them.

The logic for giving priority to constants is already in the Blockwise infer_shape; you should be able to just grab it and refactor it. We just want to simplify the returned graph when two inputs share the same letter and one of them has a constant shape; in that case we pick the constant one.

Checking whether two inputs agree is not the critical thing here, although we can do that. You can only do it if isinstance(shapes[0][j], Constant) and isinstance(shapes[1][j], Constant), in which case you can compare shapes[0][j].data == shapes[1][j].data if you want to raise an informative error when they are inconsistent.

Also, there is no reason to convert with as_tensor_variable; they should be ScalarVariables IIRC. Could be wrong here.
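A minimal sketch of the check described above (illustrative, not from the PR; Constant is importable from pytensor.graph.basic):

from pytensor.graph.basic import Constant

def sizes_conflict(a, b) -> bool:
    # Only when both sizes are constants can we decide that they disagree;
    # two symbolic sizes may still turn out equal at runtime.
    return isinstance(a, Constant) and isinstance(b, Constant) and a.data != b.data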

@Aarsh-Wankar (Contributor Author) commented:

I took the code from the infer_shape of the Blockwise class and used it to write the function. It gives priority to Constants. Is the output type list[tuple[Any, ...]] fine?

from typing import Any

from pytensor.graph.basic import Constant

# _parse_gufunc_signature is defined alongside this function in
# pytensor/tensor/utils.py.


def _gufunc_to_out_shape(
    signature: str, shapes: list[tuple[Any, ...]]
) -> list[tuple[Any, ...]]:
    """
    Compute the shape of the output of an Op given its gufunc signature and the
    shapes of its inputs.

    Parameters
    ----------
    signature : str
        The gufunc signature of the Op.
        e.g.: "(m,n),(n,p)->(m,p)".

    shapes : list of tuple of Any
        The list of shapes of the inputs.

    Returns
    -------
    out_shape : list of tuple of Any
        The list of shapes of the outputs.

    Raises
    ------
    ValueError
        If the signature is invalid for the shapes of the inputs.
    """
    input_sig, output_sig = _parse_gufunc_signature(signature)
    dim_to_size: dict[str, Any] = {}
    for input_shape, sig in zip(shapes, input_sig, strict=True):
        for size, dim_name in zip(input_shape, sig, strict=True):
            prev_size = dim_to_size.get(dim_name)
            if prev_size is None:
                dim_to_size[dim_name] = size
            # Prefer constants, as they produce a simpler shape graph
            elif not isinstance(prev_size, Constant):
                dim_to_size[dim_name] = size
            # Sizes can only be compared when both are constants;
            # symbolic sizes may still match at runtime
            elif isinstance(size, Constant) and prev_size.data != size.data:
                raise ValueError(
                    f"Invalid signature {signature} for shapes {shapes}. "
                    f"Dimension {dim_name} is not consistent across inputs."
                )
    out_shapes = []
    for output_shape in output_sig:
        temp_list = []
        for dim in output_shape:
            if dim not in dim_to_size:
                raise ValueError(
                    f"Invalid signature {signature} for shapes {shapes}. "
                    f"Dimension {dim} not in input dimensions."
                )
            temp_list.append(dim_to_size[dim])
        out_shapes.append(tuple(temp_list))
    return out_shapes
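For illustration (hypothetical example, not from the PR), mixing a symbolic and a constant size for the same letter picks the constant:

import pytensor.tensor as pt

m_sym = pt.lscalar("m")   # symbolic size, unknown until runtime
m_const = pt.constant(5)  # constant size
out = _gufunc_to_out_shape("(m),(m)->(m)", [(m_sym,), (m_const,)])
print(out)  # the constant size is preferred over the symbolic one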
    

Development

Successfully merging this pull request may close these issues.

Implement infer_shape automatically from gufunc_signature