Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions dpnp/dpnp_iface_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import dpnp.backend.extensions.ufunc._ufunc_impl as ufi
from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc

from .dpnp_array import dpnp_array
from .dpnp_utils import get_usm_allocations

__all__ = [
Expand All @@ -69,6 +70,7 @@
"iscomplexobj",
"isfinite",
"isfortran",
"isin",
"isinf",
"isnan",
"isneginf",
Expand Down Expand Up @@ -1196,6 +1198,120 @@ def isfortran(a):
return a.flags.fnc


def isin(
element,
test_elements,
assume_unique=False, # pylint: disable=unused-argument
invert=False,
*,
kind=None, # pylint: disable=unused-argument
):
"""
Calculates ``element in test_elements``, broadcasting over `element` only.
Returns a boolean array of the same shape as `element` that is ``True``
where an element of `element` is in `test_elements` and ``False``
otherwise.

For full documentation refer to :obj:`numpy.isin`.

Parameters
----------
element : {dpnp.ndarray, usm_ndarray, scalar}
Input array.
test_elements : {dpnp.ndarray, usm_ndarray, scalar}
The values against which to test each value of `element`.
This argument is flattened if it is an array.
assume_unique : bool, optional
Ignored, as no performance benefit is gained by assuming the
input arrays are unique. Included for compatibility with NumPy.

Default: ``False``.
invert : bool, optional
If ``True``, the values in the returned array are inverted, as if
calculating ``element not in test_elements``.
``dpnp.isin(a, b, invert=True)`` is equivalent to (but faster
than) ``dpnp.invert(dpnp.isin(a, b))``.

Default: ``False``.
kind : {None, "sort"}, optional
Ignored, as the only algorithm implemented is ``"sort"``. Included for
compatibility with NumPy.

Default: ``None``.

Returns
-------
isin : dpnp.ndarray of bool dtype
Has the same shape as `element`. The values `element[isin]`
are in `test_elements`.

Examples
--------
>>> import dpnp as np
>>> element = 2*np.arange(4).reshape((2, 2))
>>> element
array([[0, 2],
[4, 6]])
>>> test_elements = [1, 2, 4, 8]
>>> mask = np.isin(element, test_elements)
>>> mask
array([[False, True],
[ True, False]])
>>> element[mask]
array([2, 4])

The indices of the matched values can be obtained with `nonzero`:

>>> np.nonzero(mask)
(array([0, 1]), array([1, 0]))

The test can also be inverted:

>>> mask = np.isin(element, test_elements, invert=True)
>>> mask
array([[ True, False],
[False, True]])
>>> element[mask]
array([0, 6])

"""

dpnp.check_supported_arrays_type(element, test_elements, scalar_type=True)
if dpnp.isscalar(element):
usm_element = dpnp.as_usm_ndarray(
element,
usm_type=test_elements.usm_type,
sycl_queue=test_elements.sycl_queue,
)
usm_test = dpnp.get_usm_ndarray(test_elements)
elif dpnp.isscalar(test_elements):
usm_test = dpnp.as_usm_ndarray(
test_elements,
usm_type=element.usm_type,
sycl_queue=element.sycl_queue,
)
usm_element = dpnp.get_usm_ndarray(element)
else:
if (
dpu.get_execution_queue(
(element.sycl_queue, test_elements.sycl_queue)
)
is None
):
raise dpu.ExecutionPlacementError(
"Input arrays have incompatible allocation queues"
)
usm_element = dpnp.get_usm_ndarray(element)
usm_test = dpnp.get_usm_ndarray(test_elements)
return dpnp_array._create_from_usm_ndarray(
dpt.isin(
usm_element,
usm_test,
invert=invert,
)
)


_ISINF_DOCSTRING = """
Tests each element :math:`x_i` of the input array `x` to determine if equal to
positive or negative infinity.
Expand Down
102 changes: 102 additions & 0 deletions dpnp/tests/test_logic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import dpctl
import numpy
import pytest
from dpctl.utils import ExecutionPlacementError
from numpy.testing import (
assert_allclose,
assert_array_equal,
Expand Down Expand Up @@ -795,3 +797,103 @@ def test_array_equal_nan(a):
result = dpnp.array_equal(dpnp.array(a), dpnp.array(b), equal_nan=True)
expected = numpy.array_equal(a, b, equal_nan=True)
assert_equal(result, expected)


class TestIsin:
@pytest.mark.parametrize(
"a",
[
numpy.array([1, 2, 3, 4]),
numpy.array([[1, 2], [3, 4]]),
],
)
@pytest.mark.parametrize(
"b",
[
numpy.array([2, 4, 6]),
numpy.array([[1, 3], [5, 7]]),
],
)
def test_isin_basic(self, a, b):
dp_a = dpnp.array(a)
dp_b = dpnp.array(b)

expected = numpy.isin(a, b)
result = dpnp.isin(dp_a, dp_b)
assert_equal(result, expected)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
def test_isin_dtype(self, dtype):
a = numpy.array([1, 2, 3, 4], dtype=dtype)
b = numpy.array([2, 4], dtype=dtype)

dp_a = dpnp.array(a, dtype=dtype)
dp_b = dpnp.array(b, dtype=dtype)

expected = numpy.isin(a, b)
result = dpnp.isin(dp_a, dp_b)
assert_equal(result, expected)

@pytest.mark.parametrize(
"sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))]
)
def test_isin_broadcast(self, sh_a, sh_b):
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)

dp_a = dpnp.array(a)
dp_b = dpnp.array(b)

expected = numpy.isin(a, b)
result = dpnp.isin(dp_a, dp_b)
assert_equal(result, expected)

def test_isin_scalar_elements(self):
a = numpy.array([1, 2, 3])
b = 2

dp_a = dpnp.array(a)
dp_b = dpnp.array(b)

expected = numpy.isin(a, b)
result = dpnp.isin(dp_a, dp_b)
assert_equal(result, expected)

def test_isin_scalar_test_elements(self):
a = 2
b = numpy.array([1, 2, 3])

dp_a = dpnp.array(a)
dp_b = dpnp.array(b)

expected = numpy.isin(a, b)
result = dpnp.isin(dp_a, dp_b)
assert_equal(result, expected)

def test_isin_empty(self):
a = numpy.array([], dtype=int)
b = numpy.array([1, 2, 3])

dp_a = dpnp.array(a)
dp_b = dpnp.array(b)

expected = numpy.isin(a, b)
result = dpnp.isin(dp_a, dp_b)
assert_equal(result, expected)

def test_isin_errors(self):
q1 = dpctl.SyclQueue()
q2 = dpctl.SyclQueue()

a = dpnp.arange(5, sycl_queue=q1)
b = dpnp.arange(3, sycl_queue=q2)

# unsupported type for elements or test_elements
with pytest.raises(TypeError):
dpnp.isin(dict(), a)

with pytest.raises(TypeError):
dpnp.isin(a, dict())

with pytest.raises(ExecutionPlacementError):
dpnp.isin(a, b)
1 change: 1 addition & 0 deletions dpnp/tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,7 @@ def test_logic_op_1in(op, device):
"greater",
"greater_equal",
"isclose",
"isin",
"less",
"less_equal",
"logical_and",
Expand Down
1 change: 1 addition & 0 deletions dpnp/tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@ def test_logic_op_1in(op, usm_type_x):
"greater",
"greater_equal",
"isclose",
"isin",
"less",
"less_equal",
"logical_and",
Expand Down
1 change: 0 additions & 1 deletion dpnp/tests/third_party/cupy/logic_tests/test_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def test_with_out(self, xp, dtype):
return out


@pytest.mark.skip("isin() is not supported yet")
@testing.parameterize(
*testing.product(
{
Expand Down
Loading