Skip to content

Commit d982a62

Browse files
authored
Merge pull request #313 from cbourjau/guard-tests-no-complex-dtypes
Avoid calling hh.arrays with dtype=None for complex types
2 parents a882502 + c43c83a commit d982a62

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

Diff for: array_api_tests/hypothesis_helpers.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from hypothesis import assume, reject
1111
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
1212
integers, just, lists, none, one_of,
13-
sampled_from, shared, builds)
13+
sampled_from, shared, builds, nothing)
1414

1515
from . import _array_module as xp, api_version
1616
from . import array_helpers as ah
@@ -200,11 +200,11 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
200200
real_floating_dtypes = sampled_from(dh.real_float_dtypes)
201201
numeric_dtypes = sampled_from(dh.numeric_dtypes)
202202
# Note: this always returns complex dtypes, even if api_version < 2022.12
203-
complex_dtypes: SearchStrategy[Any] | None = sampled_from(dh.complex_dtypes) if dh.complex_dtypes else None
203+
complex_dtypes: SearchStrategy[Any] = sampled_from(dh.complex_dtypes) if dh.complex_dtypes else nothing()
204204

205205
def all_floating_dtypes() -> SearchStrategy[DataType]:
206206
strat = floating_dtypes
207-
if api_version >= "2022.12" and complex_dtypes is not None:
207+
if api_version >= "2022.12" and not complex_dtypes.is_empty:
208208
strat |= complex_dtypes
209209
return strat
210210

Diff for: array_api_tests/test_operators_and_elementwise_functions.py

+3
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,7 @@ def refimpl(_x, _min, _max):
10621062

10631063

10641064
@pytest.mark.min_version("2022.12")
1065+
@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from")
10651066
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
10661067
def test_conj(x):
10671068
out = xp.conj(x)
@@ -1264,6 +1265,7 @@ def test_hypot(x1, x2):
12641265

12651266

12661267
@pytest.mark.min_version("2022.12")
1268+
@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from")
12671269
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
12681270
def test_imag(x):
12691271
out = xp.imag(x)
@@ -1559,6 +1561,7 @@ def test_pow(ctx, data):
15591561

15601562

15611563
@pytest.mark.min_version("2022.12")
1564+
@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from")
15621565
@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
15631566
def test_real(x):
15641567
out = xp.real(x)

0 commit comments

Comments
 (0)