@@ -488,7 +488,7 @@ def device_ptr(self):
488
488
Note
489
489
----
490
490
- This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL.
491
- - No other arrays will share the same device pointer.
491
+ - No other arrays will share the same device pointer.
492
492
- A copy of the memory is done if multiple arrays share the same memory or the array is not the owner of the memory.
493
493
- In case of a copy the return value points to the newly allocated memory which is now exclusively owned by the array.
494
494
"""
@@ -985,6 +985,12 @@ def __getitem__(self, key):
985
985
try :
986
986
out = Array ()
987
987
n_dims = self .numdims ()
988
+
989
+ if (isinstance (key , Array ) and key .type () == Dtype .b8 .value ):
990
+ n_dims = 1
991
+ if (count (key ) == 0 ):
992
+ return out
993
+
988
994
inds = _get_indices (key )
989
995
990
996
safe_call (backend .get ().af_index_gen (ct .pointer (out .arr ),
@@ -1005,9 +1011,21 @@ def __setitem__(self, key, val):
1005
1011
try :
1006
1012
n_dims = self .numdims ()
1007
1013
1014
+ is_boolean_idx = isinstance (key , Array ) and key .type () == Dtype .b8 .value
1015
+
1016
+ if (is_boolean_idx ):
1017
+ n_dims = 1
1018
+ num = count (key )
1019
+ if (num == 0 ):
1020
+ return
1021
+
1008
1022
if (_is_number (val )):
1009
1023
tdims = _get_assign_dims (key , self .dims ())
1010
- other_arr = constant_array (val , tdims [0 ], tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
1024
+ if (is_boolean_idx ):
1025
+ n_dims = 1
1026
+ other_arr = constant_array (val , int (num ), dtype = self .type ())
1027
+ else :
1028
+ other_arr = constant_array (val , tdims [0 ] , tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
1011
1029
del_other = True
1012
1030
else :
1013
1031
other_arr = val .arr
@@ -1017,8 +1035,8 @@ def __setitem__(self, key, val):
1017
1035
inds = _get_indices (key )
1018
1036
1019
1037
safe_call (backend .get ().af_assign_gen (ct .pointer (out_arr ),
1020
- self .arr , ct .c_longlong (n_dims ), inds .pointer ,
1021
- other_arr ))
1038
+ self .arr , ct .c_longlong (n_dims ), inds .pointer ,
1039
+ other_arr ))
1022
1040
safe_call (backend .get ().af_release_array (self .arr ))
1023
1041
if del_other :
1024
1042
safe_call (backend .get ().af_release_array (other_arr ))
@@ -1235,5 +1253,5 @@ def read_array(filename, index=None, key=None):
1235
1253
1236
1254
return out
1237
1255
1238
- from .algorithm import sum
1256
+ from .algorithm import ( sum , count )
1239
1257
from .arith import cast
0 commit comments