Skip to content

Commit ff83dba

Browse files
authored
Intexing mod to desc part1 (#778)
1 parent c16e98a commit ff83dba

File tree

2 files changed

+62
-108
lines changed

2 files changed

+62
-108
lines changed

dpnp/dpnp_algo/dpnp_algo_indexing.pyx

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ cpdef dpnp_place(dparray arr, dparray mask, dparray vals):
202202
func(arr.get_data(), mask_.get_data(), vals.get_data(), arr.size, vals.size)
203203

204204

205-
cpdef dpnp_put(dparray input, ind, v):
205+
cpdef dpnp_put(dpnp_descriptor input, object ind, v):
206206
ind_is_list = isinstance(ind, list)
207207

208208
if dpnp.isscalar(ind):
@@ -236,17 +236,18 @@ cpdef dpnp_put(dparray input, ind, v):
236236
func(input.get_data(), ind_array.get_data(), v_array.get_data(), input.size, ind_array.size, v_array.size)
237237

238238

239-
cpdef dpnp_put_along_axis(dparray arr, dparray indices, dparray values, int axis):
239+
cpdef dpnp_put_along_axis(dpnp_descriptor arr, dpnp_descriptor indices, dpnp_descriptor values, int axis):
240+
cdef dparray_shape_type arr_shape = arr.shape
240241
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
241242

242243
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_PUT_ALONG_AXIS, param1_type, param1_type)
243244

244245
cdef custom_indexing_3in_with_axis_func_ptr_t func = <custom_indexing_3in_with_axis_func_ptr_t > kernel_data.ptr
245246

246-
func(arr.get_data(), indices.get_data(), values.get_data(), axis, < size_t * > arr._dparray_shape.data(), arr.ndim, indices.size, values.size)
247+
func(arr.get_data(), indices.get_data(), values.get_data(), axis, < size_t * > arr_shape.data(), arr.ndim, indices.size, values.size)
247248

248249

249-
cpdef dpnp_putmask(dparray arr, dparray mask, dparray values):
250+
cpdef dpnp_putmask(object arr, object mask, object values):
250251
cdef int values_size = values.size
251252
for i in range(arr.size):
252253
if mask[i]:
@@ -269,7 +270,7 @@ cpdef dparray dpnp_select(condlist, choicelist, default):
269270
return res_array.reshape(condlist[0].shape)
270271

271272

272-
cpdef dparray dpnp_take(dparray input, dparray indices):
273+
cpdef dparray dpnp_take(dpnp_descriptor input, dpnp_descriptor indices):
273274
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
274275

275276
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TAKE, param1_type, param1_type)
@@ -284,7 +285,7 @@ cpdef dparray dpnp_take(dparray input, dparray indices):
284285
return result
285286

286287

287-
cpdef dparray dpnp_take_along_axis(dparray arr, dparray indices, int axis):
288+
cpdef dparray dpnp_take_along_axis(object arr, object indices, int axis):
288289
cdef long size_arr = arr.size
289290
cdef dparray_shape_type shape_arr = arr.shape
290291
cdef long size_indices = indices.size
@@ -391,7 +392,7 @@ cpdef tuple dpnp_tril_indices(n, k=0, m=None):
391392
return (dparray1, dparray2)
392393

393394

394-
cpdef tuple dpnp_tril_indices_from(arr, k=0):
395+
cpdef tuple dpnp_tril_indices_from(dpnp_descriptor arr, k=0):
395396
m = arr.shape[0]
396397
n = arr.shape[1]
397398
array1 = []
@@ -435,7 +436,7 @@ cpdef tuple dpnp_triu_indices(n, k=0, m=None):
435436
return (dparray1, dparray2)
436437

437438

438-
cpdef tuple dpnp_triu_indices_from(arr, k=0):
439+
cpdef tuple dpnp_triu_indices_from(dpnp_descriptor arr, k=0):
439440
m = arr.shape[0]
440441
n = arr.shape[1]
441442
array1 = []

dpnp/dpnp_iface_indexing.py

Lines changed: 53 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def place(arr, mask, vals):
386386
return call_origin(numpy.place, arr, mask, vals)
387387

388388

389-
def put(input, ind, v, mode='raise'):
389+
def put(x1, ind, v, mode='raise'):
390390
"""
391391
Replaces specified elements of an array with given values.
392392
For full documentation refer to :obj:`numpy.put`.
@@ -397,22 +397,21 @@ def put(input, ind, v, mode='raise'):
397397
Not supported parameter mode.
398398
"""
399399

400-
if not use_origin_backend(input):
401-
if not isinstance(input, dparray):
402-
pass
403-
elif mode != 'raise':
400+
x1_desc = dpnp.get_dpnp_descriptor(x1)
401+
if x1_desc:
402+
if mode != 'raise':
404403
pass
405404
elif type(ind) != type(v):
406405
pass
407-
elif numpy.max(ind) >= input.size or numpy.min(ind) + input.size < 0:
406+
elif numpy.max(ind) >= x1_desc.size or numpy.min(ind) + x1_desc.size < 0:
408407
pass
409408
else:
410-
return dpnp_put(input, ind, v)
409+
return dpnp_put(x1_desc, ind, v)
411410

412-
return call_origin(numpy.put, input, ind, v, mode)
411+
return call_origin(numpy.put, x1, ind, v, mode)
413412

414413

415-
def put_along_axis(arr, indices, values, axis):
414+
def put_along_axis(x1, indices, values, axis):
416415
"""
417416
Put values into the destination array by matching 1d index and data slices.
418417
For full documentation refer to :obj:`numpy.put_along_axis`.
@@ -422,62 +421,25 @@ def put_along_axis(arr, indices, values, axis):
422421
:obj:`take_along_axis` : Take values from the input array by matching 1d index and data slices.
423422
"""
424423

425-
if not use_origin_backend(arr):
426-
if not isinstance(arr, dparray):
427-
pass
428-
elif not isinstance(indices, dparray):
429-
pass
430-
elif arr.ndim != indices.ndim:
424+
x1_desc = dpnp.get_dpnp_descriptor(x1)
425+
indices_desc = dpnp.get_dpnp_descriptor(indices)
426+
values_desc = dpnp.get_dpnp_descriptor(values)
427+
if x1_desc and indices_desc and values_desc:
428+
if x1_desc.ndim != indices_desc.ndim:
431429
pass
432430
elif not isinstance(axis, int):
433431
pass
434-
elif axis >= arr.ndim:
432+
elif axis >= x1_desc.ndim:
435433
pass
436-
elif not isinstance(values, (dparray, tuple, list)) and not dpnp.isscalar(values):
434+
elif indices_desc.size != values_desc.size:
437435
pass
438-
elif not dpnp.isscalar(values) and ((isinstance(values, dparray) and indices.size != values.size) or
439-
((isinstance(values, (tuple, list)) and indices.size != len(values)))):
440-
pass
441-
elif arr.ndim == indices.ndim:
442-
val_list = []
443-
for i in list(indices.shape)[:-1]:
444-
if i == 1:
445-
val_list.append(True)
446-
else:
447-
val_list.append(False)
448-
if not all(val_list):
449-
pass
450-
else:
451-
if dpnp.isscalar(values):
452-
values_size = 1
453-
values_ = dparray(values_size, dtype=arr.dtype)
454-
values_[0] = values
455-
elif isinstance(values, dparray):
456-
values_ = values
457-
else:
458-
values_size = len(values)
459-
values_ = dparray(values_size, dtype=arr.dtype)
460-
for i in range(values_size):
461-
values_[i] = values[i]
462-
return dpnp_put_along_axis(arr, indices, values_, axis)
463436
else:
464-
if dpnp.isscalar(values):
465-
values_size = 1
466-
values_ = dparray(values_size, dtype=arr.dtype)
467-
values_[0] = values
468-
elif isinstance(values, dparray):
469-
values_ = values
470-
else:
471-
values_size = len(values)
472-
values_ = dparray(values_size, dtype=arr.dtype)
473-
for i in range(values_size):
474-
values_[i] = values[i]
475-
return dpnp_put_along_axis(arr, indices, values_, axis)
437+
return dpnp_put_along_axis(x1_desc, indices_desc, values_desc, axis)
476438

477-
return call_origin(numpy.put_along_axis, arr, indices, values, axis)
439+
return call_origin(numpy.put_along_axis, x1, indices, values, axis)
478440

479441

480-
def putmask(arr, mask, values):
442+
def putmask(x1, mask, values):
481443
"""
482444
Changes elements of an array based on conditional and input values.
483445
For full documentation refer to :obj:`numpy.putmask`.
@@ -487,17 +449,13 @@ def putmask(arr, mask, values):
487449
Input arrays ``arr``, ``mask`` and ``values`` are supported as :obj:`dpnp.ndarray`.
488450
"""
489451

490-
if not use_origin_backend(arr):
491-
if not isinstance(arr, dparray):
492-
pass
493-
elif not isinstance(mask, dparray):
494-
pass
495-
elif not isinstance(values, dparray):
496-
pass
497-
else:
498-
return dpnp_putmask(arr, mask, values)
452+
x1_desc = dpnp.get_dpnp_descriptor(x1)
453+
mask_desc = dpnp.get_dpnp_descriptor(mask)
454+
values_desc = dpnp.get_dpnp_descriptor(values)
455+
if x1_desc and mask_desc and values_desc:
456+
return dpnp_putmask(x1, mask, values)
499457

500-
return call_origin(numpy.putmask, arr, mask, values)
458+
return call_origin(numpy.putmask, x1, mask, values)
501459

502460

503461
def select(condlist, choicelist, default=0):
@@ -510,6 +468,7 @@ def select(condlist, choicelist, default=0):
510468
Arrays of input lists are supported as :obj:`dpnp.ndarray`.
511469
Parameter ``default`` are supported only with default values.
512470
"""
471+
513472
if not use_origin_backend():
514473
if not isinstance(condlist, list):
515474
pass
@@ -537,7 +496,7 @@ def select(condlist, choicelist, default=0):
537496
return call_origin(numpy.select, condlist, choicelist, default)
538497

539498

540-
def take(input, indices, axis=None, out=None, mode='raise'):
499+
def take(x1, indices, axis=None, out=None, mode='raise'):
541500
"""
542501
Take elements from an array.
543502
For full documentation refer to :obj:`numpy.take`.
@@ -554,24 +513,22 @@ def take(input, indices, axis=None, out=None, mode='raise'):
554513
:obj:`take_along_axis` : Take elements by matching the array and the index arrays.
555514
"""
556515

557-
if not use_origin_backend(input):
558-
if not isinstance(input, dparray):
559-
pass
560-
elif not isinstance(indices, dparray):
561-
pass
562-
elif axis is not None:
516+
x1_desc = dpnp.get_dpnp_descriptor(x1)
517+
indices_desc = dpnp.get_dpnp_descriptor(indices)
518+
if x1_desc and indices_desc:
519+
if axis is not None:
563520
pass
564521
elif out is not None:
565522
pass
566523
elif mode != 'raise':
567524
pass
568525
else:
569-
return dpnp_take(input, indices)
526+
return dpnp_take(x1_desc, indices_desc)
570527

571-
return call_origin(numpy.take, input, indices, axis, out, mode)
528+
return call_origin(numpy.take, x1, indices, axis, out, mode)
572529

573530

574-
def take_along_axis(arr, indices, axis):
531+
def take_along_axis(x1, indices, axis):
575532
"""
576533
Take values from the input array by matching 1d index and data slices.
577534
For full documentation refer to :obj:`numpy.take_along_axis`.
@@ -582,32 +539,30 @@ def take_along_axis(arr, indices, axis):
582539
:obj:`put_along_axis` : Put values into the destination array by matching 1d index and data slices.
583540
"""
584541

585-
if not use_origin_backend(arr):
586-
if not isinstance(arr, dparray):
587-
pass
588-
elif not isinstance(indices, dparray):
589-
pass
590-
elif arr.ndim != indices.ndim:
542+
x1_desc = dpnp.get_dpnp_descriptor(x1)
543+
indices_desc = dpnp.get_dpnp_descriptor(indices)
544+
if x1_desc and indices_desc:
545+
if x1_desc.ndim != indices_desc.ndim:
591546
pass
592547
elif not isinstance(axis, int):
593548
pass
594-
elif axis >= arr.ndim:
549+
elif axis >= x1_desc.ndim:
595550
pass
596-
elif arr.ndim == indices.ndim:
551+
elif x1_desc.ndim == indices_desc.ndim:
597552
val_list = []
598-
for i in list(indices.shape)[:-1]:
553+
for i in list(indices_desc.shape)[:-1]:
599554
if i == 1:
600555
val_list.append(True)
601556
else:
602557
val_list.append(False)
603558
if not all(val_list):
604559
pass
605560
else:
606-
return dpnp_take_along_axis(arr, indices, axis)
561+
return dpnp_take_along_axis(x1, indices, axis)
607562
else:
608-
return dpnp_take_along_axis(arr, indices, axis)
563+
return dpnp_take_along_axis(x1, indices, axis)
609564

610-
return call_origin(numpy.take_along_axis, arr, indices, axis)
565+
return call_origin(numpy.take_along_axis, x1, indices, axis)
611566

612567

613568
def tril_indices(n, k=0, m=None):
@@ -644,7 +599,7 @@ def tril_indices(n, k=0, m=None):
644599
return call_origin(numpy.tril_indices, n, k, m)
645600

646601

647-
def tril_indices_from(arr, k=0):
602+
def tril_indices_from(x1, k=0):
648603
"""
649604
Return the indices for the lower-triangle of arr.
650605
See `tril_indices` for full details.
@@ -659,13 +614,12 @@ def tril_indices_from(arr, k=0):
659614
Diagonal offset (see `tril` for details).
660615
"""
661616

662-
is_arr_dparray = isinstance(arr, dparray)
663-
664-
if (not use_origin_backend(arr) and is_arr_dparray):
617+
x1_desc = dpnp.get_dpnp_descriptor(x1)
618+
if x1_desc:
665619
if isinstance(k, int):
666-
return dpnp_tril_indices_from(arr, k)
620+
return dpnp_tril_indices_from(x1_desc, k)
667621

668-
return call_origin(numpy.tril_indices_from, arr, k)
622+
return call_origin(numpy.tril_indices_from, x1, k)
669623

670624

671625
def triu_indices(n, k=0, m=None):
@@ -702,7 +656,7 @@ def triu_indices(n, k=0, m=None):
702656
return call_origin(numpy.triu_indices, n, k, m)
703657

704658

705-
def triu_indices_from(arr, k=0):
659+
def triu_indices_from(x1, k=0):
706660
"""
707661
Return the indices for the lower-triangle of arr.
708662
See `tril_indices` for full details.
@@ -717,10 +671,9 @@ def triu_indices_from(arr, k=0):
717671
Diagonal offset (see `tril` for details).
718672
"""
719673

720-
is_arr_dparray = isinstance(arr, dparray)
721-
722-
if (not use_origin_backend(arr) and is_arr_dparray):
674+
x1_desc = dpnp.get_dpnp_descriptor(x1)
675+
if x1_desc:
723676
if isinstance(k, int):
724-
return dpnp_triu_indices_from(arr, k)
677+
return dpnp_triu_indices_from(x1_desc, k)
725678

726-
return call_origin(numpy.triu_indices_from, arr, k)
679+
return call_origin(numpy.triu_indices_from, x1, k)

0 commit comments

Comments
 (0)