@@ -386,7 +386,7 @@ def place(arr, mask, vals):
386
386
return call_origin (numpy .place , arr , mask , vals )
387
387
388
388
389
- def put (input , ind , v , mode = 'raise' ):
389
+ def put (x1 , ind , v , mode = 'raise' ):
390
390
"""
391
391
Replaces specified elements of an array with given values.
392
392
For full documentation refer to :obj:`numpy.put`.
@@ -397,22 +397,21 @@ def put(input, ind, v, mode='raise'):
397
397
Not supported parameter mode.
398
398
"""
399
399
400
- if not use_origin_backend (input ):
401
- if not isinstance (input , dparray ):
402
- pass
403
- elif mode != 'raise' :
400
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
401
+ if x1_desc :
402
+ if mode != 'raise' :
404
403
pass
405
404
elif type (ind ) != type (v ):
406
405
pass
407
- elif numpy .max (ind ) >= input .size or numpy .min (ind ) + input .size < 0 :
406
+ elif numpy .max (ind ) >= x1_desc .size or numpy .min (ind ) + x1_desc .size < 0 :
408
407
pass
409
408
else :
410
- return dpnp_put (input , ind , v )
409
+ return dpnp_put (x1_desc , ind , v )
411
410
412
- return call_origin (numpy .put , input , ind , v , mode )
411
+ return call_origin (numpy .put , x1 , ind , v , mode )
413
412
414
413
415
- def put_along_axis (arr , indices , values , axis ):
414
+ def put_along_axis (x1 , indices , values , axis ):
416
415
"""
417
416
Put values into the destination array by matching 1d index and data slices.
418
417
For full documentation refer to :obj:`numpy.put_along_axis`.
@@ -422,62 +421,25 @@ def put_along_axis(arr, indices, values, axis):
422
421
:obj:`take_along_axis` : Take values from the input array by matching 1d index and data slices.
423
422
"""
424
423
425
- if not use_origin_backend (arr ):
426
- if not isinstance (arr , dparray ):
427
- pass
428
- elif not isinstance (indices , dparray ):
429
- pass
430
- elif arr .ndim != indices .ndim :
424
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
425
+ indices_desc = dpnp .get_dpnp_descriptor (indices )
426
+ values_desc = dpnp .get_dpnp_descriptor (values )
427
+ if x1_desc and indices_desc and values_desc :
428
+ if x1_desc .ndim != indices_desc .ndim :
431
429
pass
432
430
elif not isinstance (axis , int ):
433
431
pass
434
- elif axis >= arr .ndim :
432
+ elif axis >= x1_desc .ndim :
435
433
pass
436
- elif not isinstance ( values , ( dparray , tuple , list )) and not dpnp . isscalar ( values ) :
434
+ elif indices_desc . size != values_desc . size :
437
435
pass
438
- elif not dpnp .isscalar (values ) and ((isinstance (values , dparray ) and indices .size != values .size ) or
439
- ((isinstance (values , (tuple , list )) and indices .size != len (values )))):
440
- pass
441
- elif arr .ndim == indices .ndim :
442
- val_list = []
443
- for i in list (indices .shape )[:- 1 ]:
444
- if i == 1 :
445
- val_list .append (True )
446
- else :
447
- val_list .append (False )
448
- if not all (val_list ):
449
- pass
450
- else :
451
- if dpnp .isscalar (values ):
452
- values_size = 1
453
- values_ = dparray (values_size , dtype = arr .dtype )
454
- values_ [0 ] = values
455
- elif isinstance (values , dparray ):
456
- values_ = values
457
- else :
458
- values_size = len (values )
459
- values_ = dparray (values_size , dtype = arr .dtype )
460
- for i in range (values_size ):
461
- values_ [i ] = values [i ]
462
- return dpnp_put_along_axis (arr , indices , values_ , axis )
463
436
else :
464
- if dpnp .isscalar (values ):
465
- values_size = 1
466
- values_ = dparray (values_size , dtype = arr .dtype )
467
- values_ [0 ] = values
468
- elif isinstance (values , dparray ):
469
- values_ = values
470
- else :
471
- values_size = len (values )
472
- values_ = dparray (values_size , dtype = arr .dtype )
473
- for i in range (values_size ):
474
- values_ [i ] = values [i ]
475
- return dpnp_put_along_axis (arr , indices , values_ , axis )
437
+ return dpnp_put_along_axis (x1_desc , indices_desc , values_desc , axis )
476
438
477
- return call_origin (numpy .put_along_axis , arr , indices , values , axis )
439
+ return call_origin (numpy .put_along_axis , x1 , indices , values , axis )
478
440
479
441
480
- def putmask (arr , mask , values ):
442
+ def putmask (x1 , mask , values ):
481
443
"""
482
444
Changes elements of an array based on conditional and input values.
483
445
For full documentation refer to :obj:`numpy.putmask`.
@@ -487,17 +449,13 @@ def putmask(arr, mask, values):
487
449
Input arrays ``arr``, ``mask`` and ``values`` are supported as :obj:`dpnp.ndarray`.
488
450
"""
489
451
490
- if not use_origin_backend (arr ):
491
- if not isinstance (arr , dparray ):
492
- pass
493
- elif not isinstance (mask , dparray ):
494
- pass
495
- elif not isinstance (values , dparray ):
496
- pass
497
- else :
498
- return dpnp_putmask (arr , mask , values )
452
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
453
+ mask_desc = dpnp .get_dpnp_descriptor (mask )
454
+ values_desc = dpnp .get_dpnp_descriptor (values )
455
+ if x1_desc and mask_desc and values_desc :
456
+ return dpnp_putmask (x1 , mask , values )
499
457
500
- return call_origin (numpy .putmask , arr , mask , values )
458
+ return call_origin (numpy .putmask , x1 , mask , values )
501
459
502
460
503
461
def select (condlist , choicelist , default = 0 ):
@@ -510,6 +468,7 @@ def select(condlist, choicelist, default=0):
510
468
Arrays of input lists are supported as :obj:`dpnp.ndarray`.
511
469
Parameter ``default`` are supported only with default values.
512
470
"""
471
+
513
472
if not use_origin_backend ():
514
473
if not isinstance (condlist , list ):
515
474
pass
@@ -537,7 +496,7 @@ def select(condlist, choicelist, default=0):
537
496
return call_origin (numpy .select , condlist , choicelist , default )
538
497
539
498
540
- def take (input , indices , axis = None , out = None , mode = 'raise' ):
499
+ def take (x1 , indices , axis = None , out = None , mode = 'raise' ):
541
500
"""
542
501
Take elements from an array.
543
502
For full documentation refer to :obj:`numpy.take`.
@@ -554,24 +513,22 @@ def take(input, indices, axis=None, out=None, mode='raise'):
554
513
:obj:`take_along_axis` : Take elements by matching the array and the index arrays.
555
514
"""
556
515
557
- if not use_origin_backend (input ):
558
- if not isinstance (input , dparray ):
559
- pass
560
- elif not isinstance (indices , dparray ):
561
- pass
562
- elif axis is not None :
516
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
517
+ indices_desc = dpnp .get_dpnp_descriptor (indices )
518
+ if x1_desc and indices_desc :
519
+ if axis is not None :
563
520
pass
564
521
elif out is not None :
565
522
pass
566
523
elif mode != 'raise' :
567
524
pass
568
525
else :
569
- return dpnp_take (input , indices )
526
+ return dpnp_take (x1_desc , indices_desc )
570
527
571
- return call_origin (numpy .take , input , indices , axis , out , mode )
528
+ return call_origin (numpy .take , x1 , indices , axis , out , mode )
572
529
573
530
574
- def take_along_axis (arr , indices , axis ):
531
+ def take_along_axis (x1 , indices , axis ):
575
532
"""
576
533
Take values from the input array by matching 1d index and data slices.
577
534
For full documentation refer to :obj:`numpy.take_along_axis`.
@@ -582,32 +539,30 @@ def take_along_axis(arr, indices, axis):
582
539
:obj:`put_along_axis` : Put values into the destination array by matching 1d index and data slices.
583
540
"""
584
541
585
- if not use_origin_backend (arr ):
586
- if not isinstance (arr , dparray ):
587
- pass
588
- elif not isinstance (indices , dparray ):
589
- pass
590
- elif arr .ndim != indices .ndim :
542
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
543
+ indices_desc = dpnp .get_dpnp_descriptor (indices )
544
+ if x1_desc and indices_desc :
545
+ if x1_desc .ndim != indices_desc .ndim :
591
546
pass
592
547
elif not isinstance (axis , int ):
593
548
pass
594
- elif axis >= arr .ndim :
549
+ elif axis >= x1_desc .ndim :
595
550
pass
596
- elif arr .ndim == indices .ndim :
551
+ elif x1_desc .ndim == indices_desc .ndim :
597
552
val_list = []
598
- for i in list (indices .shape )[:- 1 ]:
553
+ for i in list (indices_desc .shape )[:- 1 ]:
599
554
if i == 1 :
600
555
val_list .append (True )
601
556
else :
602
557
val_list .append (False )
603
558
if not all (val_list ):
604
559
pass
605
560
else :
606
- return dpnp_take_along_axis (arr , indices , axis )
561
+ return dpnp_take_along_axis (x1 , indices , axis )
607
562
else :
608
- return dpnp_take_along_axis (arr , indices , axis )
563
+ return dpnp_take_along_axis (x1 , indices , axis )
609
564
610
- return call_origin (numpy .take_along_axis , arr , indices , axis )
565
+ return call_origin (numpy .take_along_axis , x1 , indices , axis )
611
566
612
567
613
568
def tril_indices (n , k = 0 , m = None ):
@@ -644,7 +599,7 @@ def tril_indices(n, k=0, m=None):
644
599
return call_origin (numpy .tril_indices , n , k , m )
645
600
646
601
647
- def tril_indices_from (arr , k = 0 ):
602
+ def tril_indices_from (x1 , k = 0 ):
648
603
"""
649
604
Return the indices for the lower-triangle of arr.
650
605
See `tril_indices` for full details.
@@ -659,13 +614,12 @@ def tril_indices_from(arr, k=0):
659
614
Diagonal offset (see `tril` for details).
660
615
"""
661
616
662
- is_arr_dparray = isinstance (arr , dparray )
663
-
664
- if (not use_origin_backend (arr ) and is_arr_dparray ):
617
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
618
+ if x1_desc :
665
619
if isinstance (k , int ):
666
- return dpnp_tril_indices_from (arr , k )
620
+ return dpnp_tril_indices_from (x1_desc , k )
667
621
668
- return call_origin (numpy .tril_indices_from , arr , k )
622
+ return call_origin (numpy .tril_indices_from , x1 , k )
669
623
670
624
671
625
def triu_indices (n , k = 0 , m = None ):
@@ -702,7 +656,7 @@ def triu_indices(n, k=0, m=None):
702
656
return call_origin (numpy .triu_indices , n , k , m )
703
657
704
658
705
- def triu_indices_from (arr , k = 0 ):
659
+ def triu_indices_from (x1 , k = 0 ):
706
660
"""
707
661
Return the indices for the lower-triangle of arr.
708
662
See `tril_indices` for full details.
@@ -717,10 +671,9 @@ def triu_indices_from(arr, k=0):
717
671
Diagonal offset (see `tril` for details).
718
672
"""
719
673
720
- is_arr_dparray = isinstance (arr , dparray )
721
-
722
- if (not use_origin_backend (arr ) and is_arr_dparray ):
674
+ x1_desc = dpnp .get_dpnp_descriptor (x1 )
675
+ if x1_desc :
723
676
if isinstance (k , int ):
724
- return dpnp_triu_indices_from (arr , k )
677
+ return dpnp_triu_indices_from (x1_desc , k )
725
678
726
- return call_origin (numpy .triu_indices_from , arr , k )
679
+ return call_origin (numpy .triu_indices_from , x1 , k )
0 commit comments