Skip to content

Commit 55e3f71

Browse files
committed
Fix cupy sign nan handling
1 parent 2539057 commit 55e3f71

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

array_api_compat/common/_aliases.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from typing import NamedTuple
1313
import inspect
1414

15-
from ._helpers import array_namespace, _check_device, device, is_torch_array
15+
from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace
1616

1717
# These functions are modified from the NumPy versions.
1818

@@ -541,7 +541,8 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
541541
out = xp.sign(x, **kwargs)
542542
# CuPy sign() does not propagate nans. See
543543
# https://github.com/data-apis/array-api-compat/issues/136
544-
out[xp.isnan(x)] = xp.nan
544+
if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
545+
out[xp.isnan(x)] = xp.nan
545546
return out[()]
546547

547548
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',

0 commit comments

Comments
 (0)