15
15
is_jax_array ,
16
16
is_writeable_array ,
17
17
)
18
+ from ._utils ._helpers import meta_namespace
18
19
from ._utils ._typing import Array , Index
19
20
20
21
@@ -419,9 +420,16 @@ def min(
419
420
xp : ModuleType | None = None ,
420
421
) -> Array : # numpydoc ignore=PR01,RT01
421
422
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
423
+ # On Dask, this function runs on the chunks, so we need to determine the
424
+ # namespace that Dask is wrapping.
425
+ # Note that da.minimum _incidentally_ works on numpy, cupy, and sparse
426
+ # thanks to all these meta-namespaces implementing the __array_ufunc__
427
+ # interface, but there's no guarantee that it will work for other
428
+ # wrapped libraries in the future.
422
429
xp = array_namespace (self ._x ) if xp is None else xp
430
+ mxp = meta_namespace (self ._x , xp = xp )
423
431
y = xp .asarray (y )
424
- return self ._op (_AtOp .MIN , xp .minimum , xp .minimum , y , copy = copy , xp = xp )
432
+ return self ._op (_AtOp .MIN , mxp .minimum , mxp .minimum , y , copy = copy , xp = xp )
425
433
426
434
def max (
427
435
self ,
@@ -431,6 +439,8 @@ def max(
431
439
xp : ModuleType | None = None ,
432
440
) -> Array : # numpydoc ignore=PR01,RT01
433
441
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
442
+ # See note on min()
434
443
xp = array_namespace (self ._x ) if xp is None else xp
444
+ mxp = meta_namespace (self ._x , xp = xp )
435
445
y = xp .asarray (y )
436
- return self ._op (_AtOp .MAX , xp .maximum , xp .maximum , y , copy = copy , xp = xp )
446
+ return self ._op (_AtOp .MAX , mxp .maximum , mxp .maximum , y , copy = copy , xp = xp )
0 commit comments