26
26
_integer_dtypes ,
27
27
_integer_or_boolean_dtypes ,
28
28
_floating_dtypes ,
29
+ _real_floating_dtypes ,
29
30
_complex_floating_dtypes ,
30
31
_numeric_dtypes ,
31
32
_result_type ,
32
33
_dtype_categories ,
34
+ _real_to_complex_map ,
33
35
)
34
36
from ._flags import get_array_api_strict_flags , set_array_api_strict_flags
35
37
@@ -243,6 +245,7 @@ def _promote_scalar(self, scalar):
243
245
"""
244
246
from ._data_type_functions import iinfo
245
247
248
+ target_dtype = self .dtype
246
249
# Note: Only Python scalar types that match the array dtype are
247
250
# allowed.
248
251
if isinstance (scalar , bool ):
@@ -268,10 +271,13 @@ def _promote_scalar(self, scalar):
268
271
"Python float scalars can only be promoted with floating-point arrays."
269
272
)
270
273
elif isinstance (scalar , complex ):
271
- if self .dtype not in _complex_floating_dtypes :
274
+ if self .dtype not in _floating_dtypes :
272
275
raise TypeError (
273
- "Python complex scalars can only be promoted with complex floating-point arrays."
276
+ "Python complex scalars can only be promoted with floating-point arrays."
274
277
)
278
+ # 1j * array(floating) is allowed
279
+ if self .dtype in _real_floating_dtypes :
280
+ target_dtype = _real_to_complex_map [self .dtype ]
275
281
else :
276
282
raise TypeError ("'scalar' must be a Python scalar" )
277
283
@@ -282,7 +288,7 @@ def _promote_scalar(self, scalar):
282
288
# behavior for integers within the bounds of the integer dtype.
283
289
# Outside of those bounds we use the default NumPy behavior (either
284
290
# cast or raise OverflowError).
285
- return Array ._new (np .array (scalar , dtype = self . dtype ._np_dtype ), device = self .device )
291
+ return Array ._new (np .array (scalar , dtype = target_dtype ._np_dtype ), device = self .device )
286
292
287
293
@staticmethod
288
294
def _normalize_two_args (x1 , x2 ) -> Tuple [Array , Array ]:
0 commit comments