Skip to content

Commit bf1e1e6

Browse files
committed
_at meta_xp
1 parent 0b7fce1 commit bf1e1e6

File tree

1 file changed

+12
-2
lines changed
  • src/array_api_extra/_lib

1 file changed

+12
-2
lines changed

src/array_api_extra/_lib/_at.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
is_jax_array,
1616
is_writeable_array,
1717
)
18+
from ._utils._helpers import meta_namespace
1819
from ._utils._typing import Array, Index
1920

2021

@@ -419,9 +420,16 @@ def min(
419420
xp: ModuleType | None = None,
420421
) -> Array: # numpydoc ignore=PR01,RT01
421422
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
423+
# On Dask, this function runs on the chunks, so we need to determine the
424+
# namespace that Dask is wrapping.
425+
# Note that da.minimum _incidentally_ works on numpy, cupy, and sparse
426+
# thanks to all these meta-namespaces implementing the __array_ufunc__
427+
# interface, but there's no guarantee that it will work for other
428+
# wrapped libraries in the future.
422429
xp = array_namespace(self._x) if xp is None else xp
430+
mxp = meta_namespace(self._x, xp=xp)
423431
y = xp.asarray(y)
424-
return self._op(_AtOp.MIN, xp.minimum, xp.minimum, y, copy=copy, xp=xp)
432+
return self._op(_AtOp.MIN, mxp.minimum, mxp.minimum, y, copy=copy, xp=xp)
425433

426434
def max(
427435
self,
@@ -431,6 +439,8 @@ def max(
431439
xp: ModuleType | None = None,
432440
) -> Array: # numpydoc ignore=PR01,RT01
433441
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
442+
# See note on min()
434443
xp = array_namespace(self._x) if xp is None else xp
444+
mxp = meta_namespace(self._x, xp=xp)
435445
y = xp.asarray(y)
436-
return self._op(_AtOp.MAX, xp.maximum, xp.maximum, y, copy=copy, xp=xp)
446+
return self._op(_AtOp.MAX, mxp.maximum, mxp.maximum, y, copy=copy, xp=xp)

0 commit comments

Comments
 (0)