Skip to content

Commit 06a5351

Browse files
committed
ENH: allow 1j * float_array
1 parent 1a288de commit 06a5351

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

array_api_strict/_array_object.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@
2626
_integer_dtypes,
2727
_integer_or_boolean_dtypes,
2828
_floating_dtypes,
29+
_real_floating_dtypes,
2930
_complex_floating_dtypes,
3031
_numeric_dtypes,
3132
_result_type,
3233
_dtype_categories,
34+
_real_to_complex_map,
3335
)
3436
from ._flags import get_array_api_strict_flags, set_array_api_strict_flags
3537

@@ -243,6 +245,7 @@ def _promote_scalar(self, scalar):
243245
"""
244246
from ._data_type_functions import iinfo
245247

248+
target_dtype = self.dtype
246249
# Note: Only Python scalar types that match the array dtype are
247250
# allowed.
248251
if isinstance(scalar, bool):
@@ -268,10 +271,13 @@ def _promote_scalar(self, scalar):
268271
"Python float scalars can only be promoted with floating-point arrays."
269272
)
270273
elif isinstance(scalar, complex):
271-
if self.dtype not in _complex_floating_dtypes:
274+
if self.dtype not in _floating_dtypes:
272275
raise TypeError(
273-
"Python complex scalars can only be promoted with complex floating-point arrays."
276+
"Python complex scalars can only be promoted with floating-point arrays."
274277
)
278+
# 1j * array(floating) is allowed
279+
if self.dtype in _real_floating_dtypes:
280+
target_dtype = _real_to_complex_map[self.dtype]
275281
else:
276282
raise TypeError("'scalar' must be a Python scalar")
277283

@@ -282,7 +288,7 @@ def _promote_scalar(self, scalar):
282288
# behavior for integers within the bounds of the integer dtype.
283289
# Outside of those bounds we use the default NumPy behavior (either
284290
# cast or raise OverflowError).
285-
return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=self.device)
291+
return Array._new(np.array(scalar, dtype=target_dtype._np_dtype), device=self.device)
286292

287293
@staticmethod
288294
def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:

array_api_strict/_dtypes.py

+1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def __hash__(self):
126126
"floating-point": _floating_dtypes,
127127
}
128128

129+
_real_to_complex_map = {float32: complex64, float64: complex128}
129130

130131
# Note: the spec defines a restricted type promotion table compared to NumPy.
131132
# In particular, cross-kind promotions like integer + float or boolean +

0 commit comments

Comments
 (0)