Skip to content

Commit 4c8c8b6

Browse files
brendan-m-murphyricardoV94
authored andcommitted
Preserve numpy < 2.0 Unique inverse output shape
In numpy 2.0, if axis=None, then np.unique does not flatten the inverse indices returned if return_inverse=True A helper function has been added to npy_2_compat.py to mimic the output of `np.unique` from version of numpy before 2.0
1 parent 4d74d13 commit 4c8c8b6

File tree

3 files changed

+47
-11
lines changed

3 files changed

+47
-11
lines changed

pytensor/npy_2_compat.py

+22
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,28 @@
6363
numpy_maxdims = 64 if using_numpy_2 else 32
6464

6565

66+
# function that replicates np.unique from numpy < 2.0
67+
def old_np_unique(
68+
arr, return_index=False, return_inverse=False, return_counts=False, axis=None
69+
):
70+
"""Replicate np.unique from numpy versions < 2.0"""
71+
if not return_inverse or not using_numpy_2:
72+
return np.unique(arr, return_index, return_inverse, return_counts, axis)
73+
74+
outs = list(np.unique(arr, return_index, return_inverse, return_counts, axis))
75+
76+
inv_idx = 2 if return_index else 1
77+
78+
if axis is None:
79+
outs[inv_idx] = np.ravel(outs[inv_idx])
80+
else:
81+
inv_shape = (arr.shape[axis],)
82+
outs[inv_idx] = outs[inv_idx].reshape(inv_shape)
83+
84+
return tuple(outs)
85+
86+
87+
# compatibility header for C code
6688
def npy_2_compat_header() -> str:
6789
"""Compatibility header that Numpy suggests is vendored with code that uses Numpy < 2.0 and Numpy 2.x"""
6890
return dedent("""

pytensor/tensor/extra_ops.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
normalize_axis_index,
2121
npy_2_compat_header,
2222
numpy_axis_is_none_flag,
23+
old_np_unique,
2324
)
2425
from pytensor.raise_op import Assert
2526
from pytensor.scalar import int64 as int_t
@@ -1226,6 +1227,9 @@ class Unique(Op):
12261227
"""
12271228
Wraps `numpy.unique`.
12281229
1230+
The indices returned when `return_inverse` is True are ravelled
1231+
to match the behavior of `numpy.unique` from before numpy version 2.0.
1232+
12291233
Examples
12301234
--------
12311235
>>> import numpy as np
@@ -1271,17 +1275,21 @@ def make_node(self, x):
12711275

12721276
outputs = [TensorType(dtype=x.dtype, shape=out_shape)()]
12731277
typ = TensorType(dtype="int64", shape=(None,))
1278+
12741279
if self.return_index:
12751280
outputs.append(typ())
1281+
12761282
if self.return_inverse:
12771283
outputs.append(typ())
1284+
12781285
if self.return_counts:
12791286
outputs.append(typ())
1287+
12801288
return Apply(self, [x], outputs)
12811289

12821290
def perform(self, node, inputs, output_storage):
12831291
[x] = inputs
1284-
outs = np.unique(
1292+
outs = old_np_unique(
12851293
x,
12861294
return_index=self.return_index,
12871295
return_inverse=self.return_inverse,
@@ -1306,9 +1314,14 @@ def infer_shape(self, fgraph, node, i0_shapes):
13061314
out_shapes[0] = tuple(shape)
13071315

13081316
if self.return_inverse:
1309-
shape = prod(x_shape) if self.axis is None else x_shape[axis]
13101317
return_index_out_idx = 2 if self.return_index else 1
1311-
out_shapes[return_index_out_idx] = (shape,)
1318+
1319+
if self.axis is not None:
1320+
shape = (x_shape[axis],)
1321+
else:
1322+
shape = (prod(x_shape),)
1323+
1324+
out_shapes[return_index_out_idx] = shape
13121325

13131326
return out_shapes
13141327

tests/tensor/test_extra_ops.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytensor.compile.mode import Mode
1010
from pytensor.configdefaults import config
1111
from pytensor.graph.basic import Constant, applys_between, equal_computations
12+
from pytensor.npy_2_compat import old_np_unique
1213
from pytensor.raise_op import Assert
1314
from pytensor.tensor import alloc
1415
from pytensor.tensor.elemwise import DimShuffle
@@ -899,14 +900,14 @@ def setup_method(self):
899900
)
900901
def test_basic_vector(self, x, inp, axis):
901902
list_outs_expected = [
902-
np.unique(inp, axis=axis),
903-
np.unique(inp, True, axis=axis),
904-
np.unique(inp, False, True, axis=axis),
905-
np.unique(inp, True, True, axis=axis),
906-
np.unique(inp, False, False, True, axis=axis),
907-
np.unique(inp, True, False, True, axis=axis),
908-
np.unique(inp, False, True, True, axis=axis),
909-
np.unique(inp, True, True, True, axis=axis),
903+
old_np_unique(inp, axis=axis),
904+
old_np_unique(inp, True, axis=axis),
905+
old_np_unique(inp, False, True, axis=axis),
906+
old_np_unique(inp, True, True, axis=axis),
907+
old_np_unique(inp, False, False, True, axis=axis),
908+
old_np_unique(inp, True, False, True, axis=axis),
909+
old_np_unique(inp, False, True, True, axis=axis),
910+
old_np_unique(inp, True, True, True, axis=axis),
910911
]
911912
for params, outs_expected in zip(
912913
self.op_params, list_outs_expected, strict=True

0 commit comments

Comments
 (0)