`get_vector_length` incorrectly returns for shared variable without static shape #1239
Comments
@ricardoV94 This issue arises due to this registration of `_get_vector_length` for `TensorSharedVariable`, which reads the length of the shared variable's current value even when the variable's type has no static shape:

```python
@_get_vector_length.register(TensorSharedVariable)
def _get_vector_length_TensorSharedVariable(var_inst, var):
    return len(var.get_value(borrow=True))
```

How should I move forward? Should I remove this registration (which will cause the snippet you provided to raise an error), or should it be left as is?
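A plain-Python toy model (no PyTensor required) makes the hazard concrete. `SharedVector`, `VectorType` and `vector_length` below are hypothetical stand-ins, and the handler takes one argument instead of the real `(var_inst, var)` pair; the only thing it mimics is that a handler registered for the shared-variable class reads the current value. Because a shared variable without a static shape may later be given a value of a different length, a length computed at graph-construction time can go stale:

```python
from functools import singledispatch


class VectorType:
    """Hypothetical stand-in for a rank-1 tensor type with an optional static length."""

    def __init__(self, length=None):
        # `None` means the length is unknown until run time
        self.shape = (length,)


class SharedVector:
    """Hypothetical stand-in for a shared variable holding a mutable value."""

    def __init__(self, value, static_length=None):
        self.type = VectorType(static_length)
        self._value = list(value)

    def get_value(self):
        return self._value

    def set_value(self, value):
        # Without a static shape, the new value may have a different length
        self._value = list(value)


@singledispatch
def vector_length(var):
    # Base case: no special handler applies
    raise ValueError(f"Length of {var!r} cannot be determined")


@vector_length.register(SharedVector)
def _(var):
    # Mimics the problematic registration: it reads the *current* value
    return len(var.get_value())


x = SharedVector([0.0, 0.0, 0.0])        # type shape is (None,)
length_at_build_time = vector_length(x)  # 3, baked into the "graph"
x.set_value([0.0] * 5)                   # allowed, since the shape is not static
print(length_at_build_time, len(x.get_value()))  # 3 5: the build-time length is stale
```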
Yes, we should remove that, but make sure that it still works when the shape of the shared variable is static. I suspect we don't need to do anything and the dispatch on the base class handles that case.
@ricardoV94 This should not be an issue: the implementation of `get_vector_length` returns the static shape before it dispatches:

```python
def get_vector_length(v: TensorLike) -> int:
    """Return the run-time length of a symbolic vector, when possible.

    Parameters
    ----------
    v
        A rank-1 `TensorType` variable.

    Raises
    ------
    TypeError
        `v` hasn't the proper type.
    ValueError
        No special case applies, the length is not known.

    In general this is not possible, but for a number of special cases
    the length can be determined at compile / graph-construction time.
    This function implements these special cases.

    """
    v = as_tensor_variable(v)

    if v.type.ndim != 1:
        raise TypeError(f"Argument must be a vector; got {v.type}")

    static_shape: int | None = v.type.shape[0]
    if static_shape is not None:
        return static_shape

    return _get_vector_length(getattr(v.owner, "op", v), v)
```

A small test for the same, with the `TensorSharedVariable` registration removed:

```python
import numpy as np
import pytensor
from pytensor import tensor as pt

val = np.zeros(3, dtype="float32")

# Static shape (3,): the length is known at graph-construction time
x = pytensor.shared(val, name="x", shape=val.shape, strict=True)
print(pt.get_vector_length(x))  # returns 3

# No static shape, i.e. shape=(None,): raises
# ValueError: Length of <Vector(float32, shape=(?,))> cannot be determined
pt.get_vector_length(pytensor.shared(val))
```
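The same reasoning can be checked with a plain-Python toy model that mirrors the quoted `get_vector_length` (the `Variable` and `VectorType` classes and the `_vector_length` dispatcher are hypothetical stand-ins, no PyTensor required). The static-shape check runs before any dispatch, so once no handler is registered for shared variables, a vector with static shape `(3,)` still reports 3, while one with shape `(None,)` falls through to the base case and raises `ValueError`:

```python
from functools import singledispatch


class VectorType:
    """Hypothetical stand-in for a tensor type; `shape` holds static dims or None."""

    def __init__(self, shape):
        self.shape = tuple(shape)
        self.ndim = len(self.shape)


class Variable:
    """Hypothetical stand-in for a graph variable; `owner` is None for shared inputs."""

    def __init__(self, type_, owner=None):
        self.type = type_
        self.owner = owner


@singledispatch
def _vector_length(op_or_var, var):
    # Base case: nothing is registered for shared variables any more
    raise ValueError(f"Length of vector with shape {var.type.shape} cannot be determined")


def vector_length(v):
    if v.type.ndim != 1:
        raise TypeError(f"Argument must be a vector; got shape {v.type.shape}")
    static_shape = v.type.shape[0]
    if static_shape is not None:
        # Static shape wins, so shared variables created with shape=(3,) still work
        return static_shape
    # Otherwise dispatch on the owner's op, or on the variable itself if it has no owner
    return _vector_length(getattr(v.owner, "op", v), v)


static_x = Variable(VectorType((3,)))
dynamic_x = Variable(VectorType((None,)))
print(vector_length(static_x))  # 3
try:
    vector_length(dynamic_x)
except ValueError as err:
    print(err)
```

In the real library the base case is whatever `_get_vector_length` does for types with no registered handler; the toy simply raises to match the `ValueError` reported in the test above.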
Description
It should raise because the variable has `type.shape=(None,)`.