@@ -31,6 +31,27 @@ def _create_array(buf, numdims, idims, dtype, is_device):
31
31
numdims , ct .pointer (c_dims ), dtype .value ))
32
32
return out_arr
33
33
34
def _create_strided_array(buf, numdims, idims, dtype, is_device, offset, strides):
    """
    Create an array handle over an existing buffer using an explicit offset and strides.

    Parameters
    ----------
    buf : int
        Address of the source buffer; it is wrapped in `ct.c_void_p` before
        being handed to the backend.
    numdims : int
        Number of dimensions, forwarded to `af_create_strided_array`.
    idims : sequence of int
        Dimensions of the array. The first four entries are read.
    dtype : Dtype
        Element type; `dtype.value` is forwarded to the backend.
    is_device : bool
        When true the buffer is tagged as `Source.device`, otherwise as
        `Source.host`.
    offset : int or None
        Offset of the first element (presumably in elements, matching
        `Array.offset` -- confirm against the ArrayFire documentation).
        `None` is treated as 0.
    strides : sequence of int or None
        Stride for each dimension. `None` selects the strides
        `(1, d0, d0*d1, d0*d1*d2)` derived from `idims`. A sequence shorter
        than four entries is padded by repeating its last entry.

    Returns
    -------
    out_arr : ct.c_void_p
        Handle filled in by `af_create_strided_array`.
    """
    out_arr = ct.c_void_p(0)
    c_dims = dim4(idims[0], idims[1], idims[2], idims[3])

    c_offset = ct.c_ulonglong(0 if offset is None else offset)

    if strides is None:
        strides = (1, idims[0], idims[0]*idims[1], idims[0]*idims[1]*idims[2])
    else:
        # Accept any sequence (e.g. a list): the padding below concatenates
        # tuples, which would raise TypeError for a non-tuple input.
        strides = tuple(strides)
    while len(strides) < 4:
        strides = strides + (strides[-1],)
    c_strides = dim4(strides[0], strides[1], strides[2], strides[3])

    location = Source.device if is_device else Source.host

    safe_call(backend.get().af_create_strided_array(ct.pointer(out_arr), ct.c_void_p(buf),
                                                    c_offset, numdims, ct.pointer(c_dims),
                                                    ct.pointer(c_strides), dtype.value,
                                                    location.value))
    return out_arr
54
+
34
55
def _create_empty_array (numdims , idims , dtype ):
35
56
out_arr = ct .c_void_p (0 )
36
57
c_dims = dim4 (idims [0 ], idims [1 ], idims [2 ], idims [3 ])
@@ -352,7 +373,7 @@ class Array(BaseArray):
352
373
353
374
"""
354
375
355
- def __init__ (self , src = None , dims = (0 ,), dtype = None , is_device = False ):
376
+ def __init__ (self , src = None , dims = (0 ,), dtype = None , is_device = False , offset = None , strides = None ):
356
377
357
378
super (Array , self ).__init__ ()
358
379
@@ -409,8 +430,10 @@ def __init__(self, src=None, dims=(0,), dtype=None, is_device=False):
409
430
if (type_char is not None and
410
431
type_char != _type_char ):
411
432
raise TypeError ("Can not create array of requested type from input data type" )
412
-
413
- self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ], is_device )
433
+ if (offset is None and strides is None ):
434
+ self .arr = _create_array (buf , numdims , idims , to_dtype [_type_char ], is_device )
435
+ else :
436
+ self .arr = _create_strided_array (buf , numdims , idims , to_dtype [_type_char ], is_device , offset , strides )
414
437
415
438
else :
416
439
@@ -454,6 +477,26 @@ def __del__(self):
454
477
backend .get ().af_release_array (self .arr )
455
478
456
479
def device_ptr(self):
    """
    Return the device pointer exclusively held by the array.

    Returns
    -------
    ptr : int
          Contains location of the device pointer

    Note
    ----
    - This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL.
    - No other arrays will share the same device pointer.
    - A copy of the memory is done if multiple arrays share the same memory or the array is not the owner of the memory.
    - In case of a copy the return value points to the newly allocated memory which is now exclusively owned by the array.
    - The backend status code is checked through `safe_call`, like the other getters of this class, instead of being discarded.
    """
    ptr = ct.c_void_p(0)
    # Route the status code through safe_call so a failed lookup is reported
    # rather than silently yielding an unset pointer.
    safe_call(backend.get().af_get_device_ptr(ct.pointer(ptr), self.arr))
    return ptr.value
498
+
499
+ def raw_ptr (self ):
457
500
"""
458
501
Return the device pointer held by the array.
459
502
@@ -466,11 +509,45 @@ def device_ptr(self):
466
509
----
467
510
- This can be used to integrate with custom C code and / or PyCUDA or PyOpenCL.
468
511
- No mem copy is performed, this function returns the raw device pointer.
512
+ - This pointer may be shared with other arrays. Use this function with caution.
513
+ - In particular the JIT compiler will not be aware of the shared arrays.
514
+ - This results in JITed operations not being immediately visible through the other array.
469
515
"""
470
516
ptr = ct .c_void_p (0 )
471
- backend .get ().af_get_device_ptr (ct .pointer (ptr ), self .arr )
517
+ backend .get ().af_get_raw_ptr (ct .pointer (ptr ), self .arr )
472
518
return ptr .value
473
519
520
def offset(self):
    """
    Return the offset of the first element relative to the raw pointer.

    Returns
    -------
    offset : int
             The offset in number of elements
    """
    first_elem = ct.c_longlong(0)
    get_offset = backend.get().af_get_offset
    # The backend writes the result through the pointer; safe_call checks
    # the returned status code.
    safe_call(get_offset(ct.pointer(first_elem), self.arr))
    return first_elem.value
532
+
533
def strides(self):
    """
    Return the distance between consecutive elements for each dimension.

    The values are counted in elements, not bytes; this matches the default
    strides `(1, d0, d0*d1, d0*d1*d2)` used when a strided array is created.

    Returns
    -------
    strides : tuple
              The strides for each dimension, truncated to `self.numdims()` entries
    """
    get_strides = backend.get().af_get_strides

    # The backend fills one value per dimension, always four of them.
    steps = [ct.c_longlong(0) for _ in range(4)]
    args = [ct.pointer(step) for step in steps]
    args.append(self.arr)
    safe_call(get_strides(*args))

    all_strides = tuple(step.value for step in steps)
    return all_strides[:self.numdims()]
550
+
474
551
def elements (self ):
475
552
"""
476
553
Return the number of elements in the array.
@@ -622,6 +699,22 @@ def is_bool(self):
622
699
safe_call (backend .get ().af_is_bool (ct .pointer (res ), self .arr ))
623
700
return res .value
624
701
702
def is_linear(self):
    """
    Check if all elements of the array are contiguous.

    Returns
    -------
    res : bool
          Value reported by the backend's `af_is_linear` for this array
    """
    linear = ct.c_bool(False)
    query = backend.get().af_is_linear
    safe_call(query(ct.pointer(linear), self.arr))
    return linear.value
709
+
710
def is_owner(self):
    """
    Check if the array owns the raw pointer or is a derived array.

    Returns
    -------
    res : bool
          Value reported by the backend's `af_is_owner` for this array
    """
    owner = ct.c_bool(False)
    query = backend.get().af_is_owner
    safe_call(query(ct.pointer(owner), self.arr))
    return owner.value
717
+
625
718
def __add__ (self , other ):
626
719
"""
627
720
Return self + other.
@@ -892,6 +985,12 @@ def __getitem__(self, key):
892
985
try :
893
986
out = Array ()
894
987
n_dims = self .numdims ()
988
+
989
+ if (isinstance (key , Array ) and key .type () == Dtype .b8 .value ):
990
+ n_dims = 1
991
+ if (count (key ) == 0 ):
992
+ return out
993
+
895
994
inds = _get_indices (key )
896
995
897
996
safe_call (backend .get ().af_index_gen (ct .pointer (out .arr ),
@@ -912,9 +1011,21 @@ def __setitem__(self, key, val):
912
1011
try :
913
1012
n_dims = self .numdims ()
914
1013
1014
+ is_boolean_idx = isinstance (key , Array ) and key .type () == Dtype .b8 .value
1015
+
1016
+ if (is_boolean_idx ):
1017
+ n_dims = 1
1018
+ num = count (key )
1019
+ if (num == 0 ):
1020
+ return
1021
+
915
1022
if (_is_number (val )):
916
1023
tdims = _get_assign_dims (key , self .dims ())
917
- other_arr = constant_array (val , tdims [0 ], tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
1024
+ if (is_boolean_idx ):
1025
+ n_dims = 1
1026
+ other_arr = constant_array (val , int (num ), dtype = self .type ())
1027
+ else :
1028
+ other_arr = constant_array (val , tdims [0 ] , tdims [1 ], tdims [2 ], tdims [3 ], self .type ())
918
1029
del_other = True
919
1030
else :
920
1031
other_arr = val .arr
@@ -924,8 +1035,8 @@ def __setitem__(self, key, val):
924
1035
inds = _get_indices (key )
925
1036
926
1037
safe_call (backend .get ().af_assign_gen (ct .pointer (out_arr ),
927
- self .arr , ct .c_longlong (n_dims ), inds .pointer ,
928
- other_arr ))
1038
+ self .arr , ct .c_longlong (n_dims ), inds .pointer ,
1039
+ other_arr ))
929
1040
safe_call (backend .get ().af_release_array (self .arr ))
930
1041
if del_other :
931
1042
safe_call (backend .get ().af_release_array (other_arr ))
@@ -1142,5 +1253,5 @@ def read_array(filename, index=None, key=None):
1142
1253
1143
1254
return out
1144
1255
1145
- from .algorithm import sum
1256
+ from .algorithm import ( sum , count )
1146
1257
from .arith import cast
0 commit comments