Skip to content

Commit e3e9a83

Browse files
authored
Merge pull request #129 from crusaderky/setdiff1d
2 parents 7ec0b82 + e61c31f commit e3e9a83

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/array_api_extra/_lib/_funcs.py

+1
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,7 @@ def setdiff1d(
580580

581581
if assume_unique:
582582
x1 = xp.reshape(x1, (-1,))
583+
x2 = xp.reshape(x2, (-1,))
583584
else:
584585
x1 = xp.unique_values(x1)
585586
x2 = xp.unique_values(x2)

tests/test_funcs.py

+15
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,21 @@ def test_assume_unique(self, xp: ModuleType):
579579
actual = setdiff1d(x1, x2, assume_unique=True)
580580
xp_assert_equal(actual, expected)
581581

582+
@pytest.mark.parametrize("assume_unique", [True, False])
583+
@pytest.mark.parametrize("shape1", [(), (1,), (1, 1)])
584+
@pytest.mark.parametrize("shape2", [(), (1,), (1, 1)])
585+
def test_shapes(
586+
self,
587+
assume_unique: bool,
588+
shape1: tuple[int, ...],
589+
shape2: tuple[int, ...],
590+
xp: ModuleType,
591+
):
592+
x1 = xp.zeros(shape1)
593+
x2 = xp.zeros(shape2)
594+
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
595+
xp_assert_equal(actual, xp.empty((0,)))
596+
582597
def test_device(self, xp: ModuleType, device: Device):
583598
x1 = xp.asarray([3, 8, 20], device=device)
584599
x2 = xp.asarray([2, 3, 4], device=device)

0 commit comments

Comments
 (0)