We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 7ec0b82 + e61c31f commit e3e9a83Copy full SHA for e3e9a83
src/array_api_extra/_lib/_funcs.py
@@ -580,6 +580,7 @@ def setdiff1d(
580
581
if assume_unique:
582
x1 = xp.reshape(x1, (-1,))
583
+ x2 = xp.reshape(x2, (-1,))
584
else:
585
x1 = xp.unique_values(x1)
586
x2 = xp.unique_values(x2)
tests/test_funcs.py
@@ -579,6 +579,21 @@ def test_assume_unique(self, xp: ModuleType):
579
actual = setdiff1d(x1, x2, assume_unique=True)
xp_assert_equal(actual, expected)
+ @pytest.mark.parametrize("assume_unique", [True, False])
+ @pytest.mark.parametrize("shape1", [(), (1,), (1, 1)])
+ @pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
+ def test_shapes(
+ self,
587
+ assume_unique: bool,
588
+ shape1: tuple[int, ...],
589
+ shape2: tuple[int, ...],
590
+ xp: ModuleType,
591
+ ):
592
+ x1 = xp.zeros(shape1)
593
+ x2 = xp.zeros(shape2)
594
+ actual = setdiff1d(x1, x2, assume_unique=assume_unique)
595
+ xp_assert_equal(actual, xp.empty((0,)))
596
+
597
def test_device(self, xp: ModuleType, device: Device):
598
x1 = xp.asarray([3, 8, 20], device=device)
599
x2 = xp.asarray([2, 3, 4], device=device)
0 commit comments