Skip to content

Commit 15eb9a0

Browse files
committed
BUGFIX: Fixing issues with boolean indexing
1 parent 1030103 commit 15eb9a0

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

arrayfire/array.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def device_ptr(self):
488488
Note
489489
----
490490
- This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL.
491-
- No other arrays will share the same device pointer.
491+
- No other arrays will share the same device pointer.
492492
- A copy of the memory is done if multiple arrays share the same memory or the array is not the owner of the memory.
493493
- In case of a copy the return value points to the newly allocated memory which is now exclusively owned by the array.
494494
"""
@@ -985,6 +985,12 @@ def __getitem__(self, key):
985985
try:
986986
out = Array()
987987
n_dims = self.numdims()
988+
989+
if (isinstance(key, Array) and key.type() == Dtype.b8.value):
990+
n_dims = 1
991+
if (count(key) == 0):
992+
return out
993+
988994
inds = _get_indices(key)
989995

990996
safe_call(backend.get().af_index_gen(ct.pointer(out.arr),
@@ -1005,9 +1011,21 @@ def __setitem__(self, key, val):
10051011
try:
10061012
n_dims = self.numdims()
10071013

1014+
is_boolean_idx = isinstance(key, Array) and key.type() == Dtype.b8.value
1015+
1016+
if (is_boolean_idx):
1017+
n_dims = 1
1018+
num = count(key)
1019+
if (num == 0):
1020+
return
1021+
10081022
if (_is_number(val)):
10091023
tdims = _get_assign_dims(key, self.dims())
1010-
other_arr = constant_array(val, tdims[0], tdims[1], tdims[2], tdims[3], self.type())
1024+
if (is_boolean_idx):
1025+
n_dims = 1
1026+
other_arr = constant_array(val, int(num), dtype=self.type())
1027+
else:
1028+
other_arr = constant_array(val, tdims[0] , tdims[1], tdims[2], tdims[3], self.type())
10111029
del_other = True
10121030
else:
10131031
other_arr = val.arr
@@ -1017,8 +1035,8 @@ def __setitem__(self, key, val):
10171035
inds = _get_indices(key)
10181036

10191037
safe_call(backend.get().af_assign_gen(ct.pointer(out_arr),
1020-
self.arr, ct.c_longlong(n_dims), inds.pointer,
1021-
other_arr))
1038+
self.arr, ct.c_longlong(n_dims), inds.pointer,
1039+
other_arr))
10221040
safe_call(backend.get().af_release_array(self.arr))
10231041
if del_other:
10241042
safe_call(backend.get().af_release_array(other_arr))
@@ -1235,5 +1253,5 @@ def read_array(filename, index=None, key=None):
12351253

12361254
return out
12371255

1238-
from .algorithm import sum
1256+
from .algorithm import (sum, count)
12391257
from .arith import cast

0 commit comments

Comments
 (0)