Skip to content

Commit 4caff28

Browse files
authored
Merge pull request #283 from cbourjau/fix-default-complex
Fix way to determine default_complex
2 parents db95e67 + 4da61e5 commit 4caff28

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
322322
default_float = xp.asarray(float()).dtype
323323
if default_float not in real_float_dtypes:
324324
warn(f"inferred default float is {default_float!r}, which is not a float")
325-
if api_version > "2021.12":
325+
if api_version > "2021.12" and ({'complex64', 'complex128'} - set(skip_dtypes)):
326326
default_complex = xp.asarray(complex()).dtype
327327
if default_complex not in complex_dtypes:
328328
warn(

array_api_tests/hypothesis_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,11 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
186186
real_floating_dtypes = sampled_from(dh.real_float_dtypes)
187187
numeric_dtypes = sampled_from(dh.numeric_dtypes)
188188
# Note: this always returns complex dtypes, even if api_version < 2022.12
189-
complex_dtypes = sampled_from(dh.complex_dtypes)
189+
complex_dtypes: SearchStrategy[Any] | None = sampled_from(dh.complex_dtypes) if dh.complex_dtypes else None
190190

191191
def all_floating_dtypes() -> SearchStrategy[DataType]:
192192
strat = floating_dtypes
193-
if api_version >= "2022.12":
193+
if api_version >= "2022.12" and complex_dtypes is not None:
194194
strat |= complex_dtypes
195195
return strat
196196

0 commit comments

Comments
 (0)