Skip to content

Commit a0a494a

Browse files
committed
Vectorize ScalarFromTensor
1 parent 6dd6172 commit a0a494a

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

pytensor/tensor/basic.py

+11
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,17 @@ def c_code_cache_version(self):
710710
scalar_from_tensor = ScalarFromTensor()
711711

712712

713+
@_vectorize_node.register(ScalarFromTensor)
714+
def vectorize_scalar_from_tensor(op, node, batch_x):
715+
if batch_x.ndim == 0:
716+
return scalar_from_tensor(batch_x).owner
717+
if batch_x.owner is not None:
718+
return batch_x.owner
719+
720+
# Needed until we fix https://github.com/pymc-devs/pytensor/issues/902
721+
return batch_x.copy().owner
722+
723+
713724
# to be removed as we get the epydoc routine-documenting thing going
714725
# -JB 20080924
715726
def _conversion(real_value: Op, name: str) -> Op:

0 commit comments

Comments
 (0)