@@ -107,13 +107,14 @@ def test_indexing_arrays(device):
107
107
device = None if device is None else Device (device )
108
108
109
109
# 1D array
110
- a = arange (5 )
110
+ a = arange (5 , device = device )
111
111
idx = asarray ([1 , 0 , 1 , 2 , - 1 ], device = device )
112
112
a_idx = a [idx ]
113
113
114
114
a_idx_loop = stack ([a [idx [i ]] for i in range (idx .shape [0 ])])
115
115
assert all (a_idx == a_idx_loop )
116
116
assert a_idx .shape == idx .shape
117
+ assert a .device == idx .device == a_idx .device
117
118
118
119
# setitem with arrays is not allowed
119
120
with assert_raises (IndexError ):
@@ -126,20 +127,39 @@ def test_indexing_arrays(device):
126
127
a_idx_loop = stack ([a [idx [i ], 1 ] for i in range (idx .shape [0 ])])
127
128
assert all (a_idx == a_idx_loop )
128
129
assert a_idx .shape == idx .shape
130
+ assert a .device == idx .device == a_idx .device
129
131
130
132
# index with two arrays
131
133
a_idx = a [idx , idx ]
132
134
a_idx_loop = stack ([a [idx [i ], idx [i ]] for i in range (idx .shape [0 ])])
133
135
assert all (a_idx == a_idx_loop )
134
136
assert a_idx .shape == a_idx .shape
137
+ assert a .device == idx .device == a_idx .device
135
138
136
139
# setitem with arrays is not allowed
137
140
with assert_raises (IndexError ):
138
141
a [idx , idx ] = 42
139
142
140
143
# smoke test indexing with ndim > 1 arrays
141
144
idx = idx [..., None ]
142
- a [idx , idx ]
145
+ a_idx = a [idx , idx ]
146
+ assert a .device == idx .device == a_idx .device
147
+
148
+
149
+ def test_indexing_arrays_different_devices ():
150
+ # Ensure indexing via array on different device errors
151
+ device1 = Device ("CPU_DEVICE" )
152
+ device2 = Device ("device1" )
153
+
154
+ a = arange (5 , device = device1 )
155
+ idx1 = asarray ([1 , 0 , 1 , 2 , - 1 ], device = device2 )
156
+ idx2 = asarray ([1 , 0 , 1 , 2 , - 1 ], device = device1 )
157
+
158
+ with pytest .raises (ValueError , match = "Array indexing is only allowed when" ):
159
+ a [idx1 ]
160
+
161
+ with pytest .raises (ValueError , match = "Array indexing is only allowed when" ):
162
+ a [idx1 , idx2 ]
143
163
144
164
145
165
def test_promoted_scalar_inherits_device ():
0 commit comments