Skip to content

Commit 4d74d13

Browse files
brendan-m-murphyricardoV94
authored andcommitted
Updated doctests
From numpy PR numpy/numpy#22449, the repr of scalar values has changed, e.g. from "1" to "np.int64(1)", which caused two doctests to fail.
1 parent 999a62c commit 4d74d13

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

pytensor/tensor/einsum.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def _general_dot(
256256
257257
.. testoutput::
258258
259-
(3, 4, 2)
259+
(np.int64(3), np.int64(4), np.int64(2))
260260
"""
261261
# Shortcut for non batched case
262262
if not batch_axes[0] and not batch_axes[1]:

pytensor/tensor/subtensor.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -757,13 +757,15 @@ def get_constant_idx(
757757
Example usage where `v` and `a` are appropriately typed PyTensor variables :
758758
>>> from pytensor.scalar import int64
759759
>>> from pytensor.tensor import matrix
760+
>>> import numpy as np
761+
>>>
760762
>>> v = int64("v")
761763
>>> a = matrix("a")
762764
>>> b = a[v, 1:3]
763765
>>> b.owner.op.idx_list
764766
(ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None))
765767
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
766-
[v, slice(1, 3, None)]
768+
[v, slice(np.int64(1), np.int64(3), None)]
767769
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
768770
Traceback (most recent call last):
769771
pytensor.tensor.exceptions.NotScalarConstantError

0 commit comments

Comments
 (0)