Skip to content

Commit 74dc5bf

Browse files
committed
TST: add workarounds for funcs which do not support 1j<op>float_array
1 parent 06a5351 commit 74dc5bf

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

array_api_strict/tests/test_array_object.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,15 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT):
108108
# - a Python int or float for real floating-point array dtypes
109109
# - a Python int, float, or complex for complex floating-point array dtypes
110110

111+
# an exception: complex scalar <op> floating array
112+
scalar_types_for_float = [float, int]
113+
if not func_name.startswith("__i"):
114+
scalar_types_for_float += [complex]
115+
116+
# real // complex is not really defined
117+
if ("floordiv" in func_name or "mod" in func_name) and type(s) == complex:
118+
return False
119+
111120
if ((dtypes == "all"
112121
or dtypes == "numeric" and a.dtype in _numeric_dtypes
113122
or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes
@@ -121,7 +130,7 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT):
121130
# isinstance here.
122131
and (a.dtype in _boolean_dtypes and type(s) == bool
123132
or a.dtype in _integer_dtypes and type(s) == int
124-
or a.dtype in _real_floating_dtypes and type(s) in [float, int]
133+
or a.dtype in _real_floating_dtypes and type(s) in scalar_types_for_float
125134
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
126135
)):
127136
if a.dtype in _integer_dtypes and s == BIG_INT:

array_api_strict/tests/test_elementwise_functions.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,20 @@ def _array_vals():
236236
for s in [1, 1.0, 1j, BIG_INT, False]:
237237
for a in _array_vals():
238238
for func1 in [lambda s: func(a, s), lambda s: func(s, a)]:
239-
allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name)
239+
240+
if (func_name in ['atan2', 'copysign', 'floor_divide',
241+
'greater', 'less', 'greater_equal', 'less_equal', 'hypot',
242+
'logaddexp', 'maximum', 'minimum', 'nextafter', 'remainder'] and
243+
type(s) == complex
244+
):
245+
allowed = False
246+
else:
247+
allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name)
240248

241249
# only check `func(array, scalar) == `func(array, array)` if
242250
# the former is legal under the promotion rules
243251
if allowed:
244-
conv_scalar = asarray(s, dtype=a.dtype)
252+
conv_scalar = a._promote_scalar(s)
245253

246254
with suppress_warnings() as sup:
247255
# ignore warnings from pow(BIG_INT)

0 commit comments

Comments
 (0)