We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 6dd6172 commit a0a494aCopy full SHA for a0a494a
pytensor/tensor/basic.py
@@ -710,6 +710,17 @@ def c_code_cache_version(self):
710
scalar_from_tensor = ScalarFromTensor()
711
712
713
+@_vectorize_node.register(ScalarFromTensor)
714
+def vectorize_scalar_from_tensor(op, node, batch_x):
715
+ if batch_x.ndim == 0:
716
+ return scalar_from_tensor(batch_x).owner
717
+ if batch_x.owner is not None:
718
+ return batch_x.owner
719
+
720
+ # Needed until we fix https://github.com/pymc-devs/pytensor/issues/902
721
+ return batch_x.copy().owner
722
723
724
# to be removed as we get the epydoc routine-documenting thing going
725
# -JB 20080924
726
def _conversion(real_value: Op, name: str) -> Op:
0 commit comments