Skip to content

Commit 2d9254d

Browse files
committed
TST: add workarounds for funcs which do not support 1j<op>float_array
1 parent 06a5351 commit 2d9254d

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

array_api_strict/tests/test_array_object.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT):
108108
# - a Python int or float for real floating-point array dtypes
109109
# - a Python int, float, or complex for complex floating-point array dtypes
110110

111+
# an exception: complex scalar <op> floating array
112+
scalar_types_for_float = [float, int]
113+
if not (func_name.startswith("__i")
114+
or (func_name in ["__floordiv__", "__rfloordiv__", "__mod__", "__rmod__"]
115+
and type(s) == complex)
116+
):
117+
scalar_types_for_float += [complex]
118+
111119
if ((dtypes == "all"
112120
or dtypes == "numeric" and a.dtype in _numeric_dtypes
113121
or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes
@@ -121,7 +129,7 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT):
121129
# isinstance here.
122130
and (a.dtype in _boolean_dtypes and type(s) == bool
123131
or a.dtype in _integer_dtypes and type(s) == int
124-
or a.dtype in _real_floating_dtypes and type(s) in [float, int]
132+
or a.dtype in _real_floating_dtypes and type(s) in scalar_types_for_float
125133
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
126134
)):
127135
if a.dtype in _integer_dtypes and s == BIG_INT:

array_api_strict/tests/test_elementwise_functions.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,25 @@ def _array_vals():
233233
if nargs(func) != 2:
234234
continue
235235

236+
nocomplex = [
237+
'atan2', 'copysign', 'floor_divide', 'hypot', 'logaddexp', 'nextafter',
238+
'remainder',
239+
'greater', 'less', 'greater_equal', 'less_equal', 'maximum', 'minimum',
240+
]
241+
236242
for s in [1, 1.0, 1j, BIG_INT, False]:
237243
for a in _array_vals():
238244
for func1 in [lambda s: func(a, s), lambda s: func(s, a)]:
239-
allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name)
245+
246+
if func_name in nocomplex and type(s) == complex:
247+
allowed = False
248+
else:
249+
allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name)
240250

241251
# only check `func(array, scalar) == `func(array, array)` if
242252
# the former is legal under the promotion rules
243253
if allowed:
244-
conv_scalar = asarray(s, dtype=a.dtype)
254+
conv_scalar = a._promote_scalar(s)
245255

246256
with suppress_warnings() as sup:
247257
# ignore warnings from pow(BIG_INT)

0 commit comments

Comments
 (0)