Skip to content

Commit e775240

Browse files
authored
Merge pull request #73 from betatim/inherit-device-scalar-promotion
FIX Use array's device when promoting scalars
2 parents f34fa52 + f7152ff commit e775240

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

array_api_strict/_array_object.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def _promote_scalar(self, scalar):
274274
# behavior for integers within the bounds of the integer dtype.
275275
# Outside of those bounds we use the default NumPy behavior (either
276276
# cast or raise OverflowError).
277-
return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=CPU_DEVICE)
277+
return Array._new(np.array(scalar, dtype=self.dtype._np_dtype), device=self.device)
278278

279279
@staticmethod
280280
def _normalize_two_args(x1, x2) -> Tuple[Array, Array]:

array_api_strict/tests/test_array_object.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88
from .. import ones, asarray, result_type, all, equal
9-
from .._array_object import Array, CPU_DEVICE
9+
from .._array_object import Array, CPU_DEVICE, Device
1010
from .._dtypes import (
1111
_all_dtypes,
1212
_boolean_dtypes,
@@ -88,6 +88,14 @@ def test_validate_index():
8888
assert_raises(IndexError, lambda: a[0])
8989
assert_raises(IndexError, lambda: a[:])
9090

91+
def test_promoted_scalar_inherits_device():
92+
device1 = Device("device1")
93+
x = asarray([1., 2, 3], device=device1)
94+
95+
y = x ** 2
96+
97+
assert y.device == device1
98+
9199
def test_operators():
92100
# For every operator, we test that it works for the required type
93101
# combinations and raises TypeError otherwise

0 commit comments

Comments
 (0)