diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 54691d6..c11b17c 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -378,8 +378,8 @@ def conj(x: Array, /) -> Array: See its docstring for more information. """ - if x.dtype not in _complex_floating_dtypes: - raise TypeError("Only complex floating-point dtypes are allowed in conj") + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in conj") return Array._new(np.conj(x._array), device=x.device) @@ -568,8 +568,8 @@ def real(x: Array, /) -> Array: See its docstring for more information. """ - if x.dtype not in _complex_floating_dtypes: - raise TypeError("Only complex floating-point dtypes are allowed in real") + if x.dtype not in _numeric_dtypes: + raise TypeError("Only numeric dtypes are allowed in real") return Array._new(np.real(x._array), device=x.device) diff --git a/array_api_strict/tests/test_elementwise_functions.py b/array_api_strict/tests/test_elementwise_functions.py index 0b90f0b..f38cdb9 100644 --- a/array_api_strict/tests/test_elementwise_functions.py +++ b/array_api_strict/tests/test_elementwise_functions.py @@ -52,7 +52,7 @@ def nargs(func): "bitwise_xor": "integer or boolean", "ceil": "real numeric", "clip": "real numeric", - "conj": "complex floating-point", + "conj": "numeric", "copysign": "real floating-point", "cos": "floating-point", "cosh": "floating-point", @@ -88,7 +88,7 @@ def nargs(func): "not_equal": "all", "positive": "numeric", "pow": "numeric", - "real": "complex floating-point", + "real": "numeric", "reciprocal": "floating-point", "remainder": "real numeric", "round": "numeric",