Skip to content

Commit f1c3ed2

Browse files
authored
Merge pull request #311 from ev-br/astype_complex
ENH: test astype with complex inputs
2 parents a71b4c0 + 3f73913 commit f1c3ed2

File tree

1 file changed

+28
-6
lines changed

1 file changed

+28
-6
lines changed

array_api_tests/test_data_type_functions.py

+28-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import Union
33

44
import pytest
5-
from hypothesis import given
5+
from hypothesis import given, assume
66
from hypothesis import strategies as st
77

88
from . import _array_module as xp
@@ -23,26 +23,43 @@ def float32(n: Union[int, float]) -> float:
2323
return struct.unpack("!f", struct.pack("!f", float(n)))[0]
2424

2525

26+
def _float_match_complex(complex_dtype):
27+
return xp.float32 if complex_dtype == xp.complex64 else xp.float64
28+
29+
2630
@given(
27-
x_dtype=non_complex_dtypes(),
28-
dtype=non_complex_dtypes(),
31+
x_dtype=hh.all_dtypes,
32+
dtype=hh.all_dtypes,
2933
kw=hh.kwargs(copy=st.booleans()),
3034
data=st.data(),
3135
)
3236
def test_astype(x_dtype, dtype, kw, data):
37+
_complex_dtypes = (xp.complex64, xp.complex128)
38+
3339
if xp.bool in (x_dtype, dtype):
3440
elements_strat = hh.from_dtype(x_dtype)
3541
else:
36-
m1, M1 = dh.dtype_ranges[x_dtype]
37-
m2, M2 = dh.dtype_ranges[dtype]
42+
3843
if dh.is_int_dtype(x_dtype):
3944
cast = int
40-
elif x_dtype == xp.float32:
45+
elif x_dtype in (xp.float32, xp.complex64):
4146
cast = float32
4247
else:
4348
cast = float
49+
50+
real_dtype = x_dtype
51+
if x_dtype in _complex_dtypes:
52+
real_dtype = _float_match_complex(x_dtype)
53+
m1, M1 = dh.dtype_ranges[real_dtype]
54+
55+
real_dtype = dtype
56+
if dtype in _complex_dtypes:
57+
real_dtype = _float_match_complex(x_dtype)
58+
m2, M2 = dh.dtype_ranges[real_dtype]
59+
4460
min_value = cast(max(m1, m2))
4561
max_value = cast(min(M1, M2))
62+
4663
elements_strat = hh.from_dtype(
4764
x_dtype,
4865
min_value=min_value,
@@ -54,6 +71,11 @@ def test_astype(x_dtype, dtype, kw, data):
5471
hh.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
5572
)
5673

74+
# according to the spec, "Casting a complex floating-point array to a real-valued
75+
# data type should not be permitted."
76+
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
77+
assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes)))
78+
5779
out = xp.astype(x, dtype, **kw)
5880

5981
ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype)

0 commit comments

Comments
 (0)