Skip to content

Commit f88f7a8

Browse files
authoredFeb 18, 2025
Merge pull request #121 from ev-br/cmplx_scalars
Define complex scalar <op> float array
2 parents 42dd9ce + 2d9254d commit f88f7a8

File tree

4 files changed

+31
-6
lines changed

4 files changed

+31
-6
lines changed
 

‎array_api_strict/_array_object.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@
2626
_integer_dtypes,
2727
_integer_or_boolean_dtypes,
2828
_floating_dtypes,
29+
_real_floating_dtypes,
2930
_complex_floating_dtypes,
3031
_numeric_dtypes,
3132
_result_type,
3233
_dtype_categories,
34+
_real_to_complex_map,
3335
)
3436
from ._flags import get_array_api_strict_flags, set_array_api_strict_flags
3537

@@ -243,6 +245,7 @@ def _promote_scalar(self, scalar):
243245
"""
244246
from ._data_type_functions import iinfo
245247

248+
target_dtype = self.dtype
246249
# Note: Only Python scalar types that match the array dtype are
247250
# allowed.
248251
if isinstance(scalar, bool):
@@ -268,10 +271,13 @@ def _promote_scalar(self, scalar):
268271
"Python float scalars can only be promoted with floating-point arrays."
269272
)
270273
elif isinstance(scalar, complex):
271-
if self.dtype not in _complex_floating_dtypes:
274+
if self.dtype not in _floating_dtypes:
272275
raise TypeError(
273-
"Python complex scalars can only be promoted with complex floating-point arrays."
276+
"Python complex scalars can only be promoted with floating-point arrays."
274277
)
278+
# 1j * array(floating) is allowed
279+
if self.dtype in _real_floating_dtypes:
280+
target_dtype = _real_to_complex_map[self.dtype]
275281
else:
276282
raise TypeError("'scalar' must be a Python scalar")
277283

@@ -282,7 +288,7 @@ def _promote_scalar(self, scalar):
282288
# behavior for integers within the bounds of the integer dtype.
283289
# Outside of those bounds we use the default NumPy behavior (either
284290
# cast or raise OverflowError).
285-
return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=self.device)
291+
return Array._new(np.array(scalar, dtype=target_dtype._np_dtype), device=self.device)
286292

287293
@staticmethod
288294
def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:

‎array_api_strict/_dtypes.py

+1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def __hash__(self):
126126
"floating-point": _floating_dtypes,
127127
}
128128

129+
_real_to_complex_map = {float32: complex64, float64: complex128}
129130

130131
# Note: the spec defines a restricted type promotion table compared to NumPy.
131132
# In particular, cross-kind promotions like integer + float or boolean +

‎array_api_strict/tests/test_array_object.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT):
108108
# - a Python int or float for real floating-point array dtypes
109109
# - a Python int, float, or complex for complex floating-point array dtypes
110110

111+
# an exception: complex scalar <op> floating array
112+
scalar_types_for_float = [float, int]
113+
if not (func_name.startswith("__i")
114+
or (func_name in ["__floordiv__", "__rfloordiv__", "__mod__", "__rmod__"]
115+
and type(s) == complex)
116+
):
117+
scalar_types_for_float += [complex]
118+
111119
if ((dtypes == "all"
112120
or dtypes == "numeric" and a.dtype in _numeric_dtypes
113121
or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes
@@ -121,7 +129,7 @@ def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT):
121129
# isinstance here.
122130
and (a.dtype in _boolean_dtypes and type(s) == bool
123131
or a.dtype in _integer_dtypes and type(s) == int
124-
or a.dtype in _real_floating_dtypes and type(s) in [float, int]
132+
or a.dtype in _real_floating_dtypes and type(s) in scalar_types_for_float
125133
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
126134
)):
127135
if a.dtype in _integer_dtypes and s == BIG_INT:

‎array_api_strict/tests/test_elementwise_functions.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,25 @@ def _array_vals():
233233
if nargs(func) != 2:
234234
continue
235235

236+
nocomplex = [
237+
'atan2', 'copysign', 'floor_divide', 'hypot', 'logaddexp', 'nextafter',
238+
'remainder',
239+
'greater', 'less', 'greater_equal', 'less_equal', 'maximum', 'minimum',
240+
]
241+
236242
for s in [1, 1.0, 1j, BIG_INT, False]:
237243
for a in _array_vals():
238244
for func1 in [lambda s: func(a, s), lambda s: func(s, a)]:
239-
allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name)
245+
246+
if func_name in nocomplex and type(s) == complex:
247+
allowed = False
248+
else:
249+
allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name)
240250

241251
# only check `func(array, scalar) == `func(array, array)` if
242252
# the former is legal under the promotion rules
243253
if allowed:
244-
conv_scalar = asarray(s, dtype=a.dtype)
254+
conv_scalar = a._promote_scalar(s)
245255

246256
with suppress_warnings() as sup:
247257
# ignore warnings from pow(BIG_INT)

0 commit comments

Comments
 (0)