5
5
import numpy as np
6
6
import pytest
7
7
8
- from .. import ones , asarray , result_type , all , equal
8
+ from .. import ones , arange , reshape , asarray , result_type , all , equal
9
9
from .._array_object import Array , CPU_DEVICE , Device
10
10
from .._dtypes import (
11
11
_all_dtypes ,
@@ -70,11 +70,25 @@ def test_validate_index():
70
70
assert_raises (IndexError , lambda : a [[True , True , True ]])
71
71
assert_raises (IndexError , lambda : a [(True , True , True ),])
72
72
73
- # Integer array indices are not allowed (except for 0-D)
74
- idx = asarray ([0 , 1 ])
73
+ # Integer array indices are not allowed (except for 0-D or 1D )
74
+ idx = asarray ([[ 0 , 1 ]]) # idx.ndim == 2
75
75
assert_raises (IndexError , lambda : a [idx , 0 ])
76
76
assert_raises (IndexError , lambda : a [0 , idx ])
77
77
78
+ # Mixing 1D integer array indices with slices, ellipsis or booleans is not allowed
79
+ idx = asarray ([0 , 1 ])
80
+ assert_raises (IndexError , lambda : a [..., idx ])
81
+ assert_raises (IndexError , lambda : a [:, idx ])
82
+ assert_raises (IndexError , lambda : a [asarray ([True , True ]), idx ])
83
+
84
+ # 1D integer array indices must have the same length
85
+ idx1 = asarray ([0 , 1 ])
86
+ idx2 = asarray ([0 , 1 , 1 ])
87
+ assert_raises (IndexError , lambda : a [idx1 , idx2 ])
88
+
89
+ # Non-integer array indices are not allowed
90
+ assert_raises (IndexError , lambda : a [ones (2 ), 0 ])
91
+
78
92
# Array-likes (lists, tuples) are not allowed as indices
79
93
assert_raises (IndexError , lambda : a [[0 , 1 ]])
80
94
assert_raises (IndexError , lambda : a [(0 , 1 ), (0 , 1 )])
@@ -91,6 +105,37 @@ def test_validate_index():
91
105
assert_raises (IndexError , lambda : a [:])
92
106
assert_raises (IndexError , lambda : a [idx ])
93
107
108
+
109
+ def test_indexing_arrays ():
110
+ # indexing with 1D integer arrays and mixes of integers and 1D integer are allowed
111
+
112
+ # 1D array
113
+ a = arange (5 )
114
+ idx = asarray ([1 , 0 , 1 , 2 , - 1 ])
115
+ a_idx = a [idx ]
116
+
117
+ a_idx_loop = asarray ([a [idx [i ]] for i in range (idx .shape [0 ])])
118
+ assert all (a_idx == a_idx_loop )
119
+
120
+ # setitem with arrays is not allowed # XXX
121
+ # with assert_raises(IndexError):
122
+ # a[idx] = 42
123
+
124
+ # mixed array and integer indexing
125
+ a = reshape (arange (3 * 4 ), (3 , 4 ))
126
+ idx = asarray ([1 , 0 , 1 , 2 , - 1 ])
127
+ a_idx = a [idx , 1 ]
128
+
129
+ a_idx_loop = asarray ([a [idx [i ], 1 ] for i in range (idx .shape [0 ])])
130
+ assert all (a_idx == a_idx_loop )
131
+
132
+
133
+ # index with two arrays
134
+ a_idx = a [idx , idx ]
135
+ a_idx_loop = asarray ([a [idx [i ], idx [i ]] for i in range (idx .shape [0 ])])
136
+ assert all (a_idx == a_idx_loop )
137
+
138
+
94
139
def test_promoted_scalar_inherits_device ():
95
140
device1 = Device ("device1" )
96
141
x = asarray ([1. , 2 , 3 ], device = device1 )
0 commit comments