Skip to content

Commit fe260b8

Browse files
authored
Merge pull request #74 from asmeurer/sign-fix
Fix the definition of sign() for complex numbers
2 parents e775240 + a0161d0 commit fe260b8

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

array_api_strict/_elementwise_functions.py

+2
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,8 @@ def sign(x: Array, /) -> Array:
855855
"""
856856
if x.dtype not in _numeric_dtypes:
857857
raise TypeError("Only numeric dtypes are allowed in sign")
858+
if x.dtype in _complex_floating_dtypes:
859+
return x/abs(x)
858860
return Array._new(np.sign(x._array), device=x.device)
859861

860862

0 commit comments

Comments
 (0)