7
7
import pytest
8
8
9
9
from array_api_extra ._lib ._backends import Backend
10
- from array_api_extra ._lib ._testing import xp_assert_close , xp_assert_equal
10
+ from array_api_extra ._lib ._testing import (
11
+ xp_assert_close ,
12
+ xp_assert_equal ,
13
+ xp_assert_less ,
14
+ )
11
15
from array_api_extra ._lib ._utils ._compat import (
12
16
array_namespace ,
13
17
is_dask_namespace ,
23
27
"func" ,
24
28
[
25
29
xp_assert_equal ,
30
+ xp_assert_less ,
26
31
pytest .param (
27
32
xp_assert_close ,
28
33
marks = pytest .mark .xfail_xp_backend (
33
38
)
34
39
35
40
36
- @param_assert_equal_close
41
+ @pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" , strict = False )
42
+ @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close ])
37
43
def test_assert_close_equal_basic (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
38
44
func (xp .asarray (0 ), xp .asarray (0 ))
39
45
func (xp .asarray ([1 , 2 ]), xp .asarray ([1 , 2 ]))
@@ -53,8 +59,8 @@ def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): #
53
59
54
60
@pytest .mark .skip_xp_backend (Backend .NUMPY , reason = "test other ns vs. numpy" )
55
61
@pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "test other ns vs. numpy" )
56
- @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close ])
57
- def test_assert_close_equal_namespace (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
62
+ @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close , xp_assert_less ])
63
+ def test_assert_close_equal_less_namespace (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
58
64
with pytest .raises (AssertionError , match = "namespaces do not match" ):
59
65
func (xp .asarray (0 ), np .asarray (0 ))
60
66
with pytest .raises (TypeError , match = "Unrecognized array input" ):
@@ -65,7 +71,7 @@ def test_assert_close_equal_namespace(xp: ModuleType, func: Callable[..., None])
65
71
66
72
@param_assert_equal_close
67
73
@pytest .mark .parametrize ("check_shape" , [False , True ])
68
- def test_assert_close_equal_shape ( # type: ignore[explicit-any]
74
+ def test_assert_close_equal_less_shape ( # type: ignore[explicit-any]
69
75
xp : ModuleType ,
70
76
func : Callable [..., None ],
71
77
check_shape : bool ,
@@ -76,12 +82,12 @@ def test_assert_close_equal_shape( # type: ignore[explicit-any]
76
82
else nullcontext ()
77
83
)
78
84
with context :
79
- func (xp .asarray ([0 , 0 ]), xp .asarray (0 ), check_shape = check_shape )
85
+ func (xp .asarray ([xp . nan , xp . nan ]), xp .asarray (xp . nan ), check_shape = check_shape )
80
86
81
87
82
88
@param_assert_equal_close
83
89
@pytest .mark .parametrize ("check_dtype" , [False , True ])
84
- def test_assert_close_equal_dtype ( # type: ignore[explicit-any]
90
+ def test_assert_close_equal_less_dtype ( # type: ignore[explicit-any]
85
91
xp : ModuleType ,
86
92
func : Callable [..., None ],
87
93
check_dtype : bool ,
@@ -92,12 +98,17 @@ def test_assert_close_equal_dtype( # type: ignore[explicit-any]
92
98
else nullcontext ()
93
99
)
94
100
with context :
95
- func (xp .asarray (0.0 ), xp .asarray (0 ), check_dtype = check_dtype )
101
+ func (
102
+ xp .asarray (xp .nan , dtype = xp .float32 ),
103
+ xp .asarray (xp .nan , dtype = xp .float64 ),
104
+ check_dtype = check_dtype ,
105
+ )
96
106
97
107
98
- @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close ])
108
+ @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close , xp_assert_less ])
99
109
@pytest .mark .parametrize ("check_scalar" , [False , True ])
100
- def test_assert_close_equal_scalar ( # type: ignore[explicit-any]
110
+ def test_assert_close_equal_less_scalar ( # type: ignore[explicit-any]
111
+ xp : ModuleType ,
101
112
func : Callable [..., None ],
102
113
check_scalar : bool ,
103
114
):
@@ -107,7 +118,7 @@ def test_assert_close_equal_scalar( # type: ignore[explicit-any]
107
118
else nullcontext ()
108
119
)
109
120
with context :
110
- func (np .asarray (0 ), np .asarray (0 )[()], check_scalar = check_scalar )
121
+ func (np .asarray (xp . nan ), np .asarray (xp . nan )[()], check_scalar = check_scalar )
111
122
112
123
113
124
@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "no isdtype" )
@@ -121,9 +132,18 @@ def test_assert_close_tolerance(xp: ModuleType):
121
132
xp_assert_close (xp .asarray ([100.0 ]), xp .asarray ([102.0 ]), atol = 1 )
122
133
123
134
124
- @param_assert_equal_close
135
+ def test_assert_less_basic (xp : ModuleType ):
136
+ xp_assert_less (xp .asarray (- 1 ), xp .asarray (0 ))
137
+ xp_assert_less (xp .asarray ([1 , 2 ]), xp .asarray ([2 , 3 ]))
138
+ with pytest .raises (AssertionError ):
139
+ xp_assert_less (xp .asarray ([1 , 1 ]), xp .asarray ([2 , 1 ]))
140
+ with pytest .raises (AssertionError , match = "hello" ):
141
+ xp_assert_less (xp .asarray ([1 , 1 ]), xp .asarray ([2 , 1 ]), err_msg = "hello" )
142
+
143
+
125
144
@pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "index by sparse array" )
126
145
@pytest .mark .skip_xp_backend (Backend .ARRAY_API_STRICTEST , reason = "boolean indexing" )
146
+ @pytest .mark .parametrize ("func" , [xp_assert_equal , xp_assert_close ])
127
147
def test_assert_close_equal_none_shape (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
128
148
"""On Dask and other lazy backends, test that a shape with NaN's or None's
129
149
can be compared to a real shape.
0 commit comments