2
2
from typing import Union
3
3
4
4
import pytest
5
- from hypothesis import given
5
+ from hypothesis import given , assume
6
6
from hypothesis import strategies as st
7
7
8
8
from . import _array_module as xp
@@ -23,26 +23,43 @@ def float32(n: Union[int, float]) -> float:
23
23
return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
24
24
25
25
26
+ def _float_match_complex (complex_dtype ):
27
+ return xp .float32 if complex_dtype == xp .complex64 else xp .float64
28
+
29
+
26
30
@given (
27
- x_dtype = non_complex_dtypes () ,
28
- dtype = non_complex_dtypes () ,
31
+ x_dtype = hh . all_dtypes ,
32
+ dtype = hh . all_dtypes ,
29
33
kw = hh .kwargs (copy = st .booleans ()),
30
34
data = st .data (),
31
35
)
32
36
def test_astype (x_dtype , dtype , kw , data ):
37
+ _complex_dtypes = (xp .complex64 , xp .complex128 )
38
+
33
39
if xp .bool in (x_dtype , dtype ):
34
40
elements_strat = hh .from_dtype (x_dtype )
35
41
else :
36
- m1 , M1 = dh .dtype_ranges [x_dtype ]
37
- m2 , M2 = dh .dtype_ranges [dtype ]
42
+
38
43
if dh .is_int_dtype (x_dtype ):
39
44
cast = int
40
- elif x_dtype == xp .float32 :
45
+ elif x_dtype in ( xp .float32 , xp . complex64 ) :
41
46
cast = float32
42
47
else :
43
48
cast = float
49
+
50
+ real_dtype = x_dtype
51
+ if x_dtype in _complex_dtypes :
52
+ real_dtype = _float_match_complex (x_dtype )
53
+ m1 , M1 = dh .dtype_ranges [real_dtype ]
54
+
55
+ real_dtype = dtype
56
+ if dtype in _complex_dtypes :
57
+ real_dtype = _float_match_complex (x_dtype )
58
+ m2 , M2 = dh .dtype_ranges [real_dtype ]
59
+
44
60
min_value = cast (max (m1 , m2 ))
45
61
max_value = cast (min (M1 , M2 ))
62
+
46
63
elements_strat = hh .from_dtype (
47
64
x_dtype ,
48
65
min_value = min_value ,
@@ -54,6 +71,11 @@ def test_astype(x_dtype, dtype, kw, data):
54
71
hh .arrays (dtype = x_dtype , shape = hh .shapes (), elements = elements_strat ), label = "x"
55
72
)
56
73
74
+ # according to the spec, "Casting a complex floating-point array to a real-valued
75
+ # data type should not be permitted."
76
+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
77
+ assume (not ((x_dtype in _complex_dtypes ) and (dtype not in _complex_dtypes )))
78
+
57
79
out = xp .astype (x , dtype , ** kw )
58
80
59
81
ph .assert_kw_dtype ("astype" , kw_dtype = dtype , out_dtype = out .dtype )
0 commit comments