@@ -625,7 +625,7 @@ def test_put_0d_val(data_dt):
625
625
skip_if_dtype_not_supported (data_dt , q )
626
626
627
627
x = dpt .arange (5 , dtype = data_dt , sycl_queue = q )
628
- ind = dpt .asarray ([0 ], dtype = np . intp , sycl_queue = q )
628
+ ind = dpt .asarray ([0 ], dtype = "i8" , sycl_queue = q )
629
629
val = dpt .asarray (2 , dtype = x .dtype , sycl_queue = q )
630
630
x [ind ] = val
631
631
assert_array_equal (np .asarray (2 , dtype = data_dt ), dpt .asnumpy (x [0 ]))
@@ -644,7 +644,7 @@ def test_take_0d_data(data_dt):
644
644
skip_if_dtype_not_supported (data_dt , q )
645
645
646
646
x = dpt .asarray (0 , dtype = data_dt , sycl_queue = q )
647
- ind = dpt .arange (5 , dtype = np . intp , sycl_queue = q )
647
+ ind = dpt .arange (5 , dtype = "i8" , sycl_queue = q )
648
648
649
649
y = dpt .take (x , ind )
650
650
assert (
@@ -662,7 +662,7 @@ def test_put_0d_data(data_dt):
662
662
skip_if_dtype_not_supported (data_dt , q )
663
663
664
664
x = dpt .asarray (0 , dtype = data_dt , sycl_queue = q )
665
- ind = dpt .arange (5 , dtype = np . intp , sycl_queue = q )
665
+ ind = dpt .arange (5 , dtype = "i8" , sycl_queue = q )
666
666
val = dpt .asarray (2 , dtype = data_dt , sycl_queue = q )
667
667
668
668
dpt .put (x , ind , val , axis = 0 )
@@ -710,7 +710,7 @@ def test_take_strided_1d_source(data_dt):
710
710
skip_if_dtype_not_supported (data_dt , q )
711
711
712
712
x = dpt .arange (27 , dtype = data_dt , sycl_queue = q )
713
- ind = dpt .arange (4 , 9 , dtype = np . intp , sycl_queue = q )
713
+ ind = dpt .arange (4 , 9 , dtype = "i8" , sycl_queue = q )
714
714
715
715
x_np = dpt .asnumpy (x )
716
716
ind_np = dpt .asnumpy (ind )
@@ -748,7 +748,7 @@ def test_take_strided(data_dt, order):
748
748
skip_if_dtype_not_supported (data_dt , q )
749
749
750
750
x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
751
- ind = dpt .arange (2 , dtype = np . intp , sycl_queue = q )
751
+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
752
752
753
753
x_np = dpt .asnumpy (x )
754
754
ind_np = dpt .asnumpy (ind )
@@ -781,7 +781,7 @@ def test_take_strided_1d_indices(ind_dt):
781
781
ind = dpt .arange (12 , 24 , dtype = ind_dt , sycl_queue = q )
782
782
783
783
x_np = dpt .asnumpy (x )
784
- ind_np = dpt .asnumpy (ind ).astype (np . intp )
784
+ ind_np = dpt .asnumpy (ind ).astype ("i8" )
785
785
786
786
for s in (
787
787
slice (None , None , 2 ),
@@ -820,7 +820,7 @@ def test_take_strided_indices(ind_dt, order):
820
820
)
821
821
822
822
x_np = dpt .asnumpy (x )
823
- ind_np = dpt .asnumpy (ind ).astype (np . intp )
823
+ ind_np = dpt .asnumpy (ind ).astype ("i8" )
824
824
825
825
for s in (
826
826
slice (None , None , 2 ),
@@ -845,7 +845,7 @@ def test_put_strided_1d_destination(data_dt, order):
845
845
skip_if_dtype_not_supported (data_dt , q )
846
846
847
847
x = dpt .arange (27 , dtype = data_dt , sycl_queue = q )
848
- ind = dpt .arange (4 , 9 , dtype = np . intp , sycl_queue = q )
848
+ ind = dpt .arange (4 , 9 , dtype = "i8" , sycl_queue = q )
849
849
val = dpt .asarray (9 , dtype = x .dtype , sycl_queue = q )
850
850
851
851
x_np = dpt .asnumpy (x )
@@ -875,7 +875,7 @@ def test_put_strided_destination(data_dt, order):
875
875
skip_if_dtype_not_supported (data_dt , q )
876
876
877
877
x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
878
- ind = dpt .arange (2 , dtype = np . intp , sycl_queue = q )
878
+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
879
879
val = dpt .asarray (9 , dtype = x .dtype , sycl_queue = q )
880
880
881
881
x_np = dpt .asnumpy (x )
@@ -924,7 +924,7 @@ def test_put_strided_1d_indices(ind_dt):
924
924
val = dpt .asarray (- 1 , dtype = x .dtype , sycl_queue = q )
925
925
926
926
x_np = dpt .asnumpy (x )
927
- ind_np = dpt .asnumpy (ind ).astype (np . intp )
927
+ ind_np = dpt .asnumpy (ind ).astype ("i8" )
928
928
val_np = dpt .asnumpy (val )
929
929
930
930
for s in (
@@ -955,7 +955,7 @@ def test_put_strided_indices(ind_dt, order):
955
955
val = dpt .asarray (- 1 , sycl_queue = q , dtype = x .dtype )
956
956
957
957
x_np = dpt .asnumpy (x )
958
- ind_np = dpt .asnumpy (ind ).astype (np . intp )
958
+ ind_np = dpt .asnumpy (ind ).astype ("i8" )
959
959
val_np = dpt .asnumpy (val )
960
960
961
961
for s in (
@@ -982,15 +982,15 @@ def test_integer_indexing_modes():
982
982
x_np = dpt .asnumpy (x )
983
983
984
984
# wrapping negative indices
985
- ind = dpt .asarray ([- 4 , - 3 , 0 , 2 , 4 ], dtype = np . intp , sycl_queue = q )
985
+ ind = dpt .asarray ([- 4 , - 3 , 0 , 2 , 4 ], dtype = "i8" , sycl_queue = q )
986
986
987
987
res = dpt .take (x , ind , mode = "wrap" )
988
988
expected_arr = np .take (x_np , dpt .asnumpy (ind ), mode = "raise" )
989
989
990
990
assert (dpt .asnumpy (res ) == expected_arr ).all ()
991
991
992
992
# clipping to 0 (disabling negative indices)
993
- ind = dpt .asarray ([- 6 , - 3 , 0 , 2 , 6 ], dtype = np . intp , sycl_queue = q )
993
+ ind = dpt .asarray ([- 6 , - 3 , 0 , 2 , 6 ], dtype = "i8" , sycl_queue = q )
994
994
995
995
res = dpt .take (x , ind , mode = "clip" )
996
996
expected_arr = np .take (x_np , dpt .asnumpy (ind ), mode = "clip" )
@@ -1002,7 +1002,7 @@ def test_take_arg_validation():
1002
1002
q = get_queue_or_skip ()
1003
1003
1004
1004
x = dpt .arange (4 , dtype = "i4" , sycl_queue = q )
1005
- ind0 = dpt .arange (4 , dtype = np . intp , sycl_queue = q )
1005
+ ind0 = dpt .arange (4 , dtype = "i8" , sycl_queue = q )
1006
1006
ind1 = dpt .arange (2.0 , dtype = "f" , sycl_queue = q )
1007
1007
1008
1008
with pytest .raises (TypeError ):
@@ -1034,7 +1034,7 @@ def test_put_arg_validation():
1034
1034
q = get_queue_or_skip ()
1035
1035
1036
1036
x = dpt .arange (4 , dtype = "i4" , sycl_queue = q )
1037
- ind0 = dpt .arange (4 , dtype = np . intp , sycl_queue = q )
1037
+ ind0 = dpt .arange (4 , dtype = "i8" , sycl_queue = q )
1038
1038
ind1 = dpt .arange (2.0 , dtype = "f" , sycl_queue = q )
1039
1039
val = dpt .asarray (2 , dtype = x .dtype , sycl_queue = q )
1040
1040
@@ -1890,3 +1890,69 @@ def test_put_along_axis_uint64_indices():
1890
1890
dpt .put_along_axis (x , inds , dpt .asarray (2 , dtype = x .dtype ), axis = 1 )
1891
1891
expected = dpt .tile (dpt .asarray ([0 , 2 ], dtype = "i4" ), (2 , 5 ))
1892
1892
assert dpt .all (expected == x )
1893
+
1894
+
1895
+ @pytest .mark .parametrize ("data_dt" , _all_dtypes )
1896
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
1897
+ def test_take_out (data_dt , order ):
1898
+ q = get_queue_or_skip ()
1899
+ skip_if_dtype_not_supported (data_dt , q )
1900
+
1901
+ axis = 0
1902
+ x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
1903
+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
1904
+ out_sh = x .shape [:axis ] + ind .shape + x .shape [axis + 1 :]
1905
+ out = dpt .empty (out_sh , dtype = data_dt , sycl_queue = q )
1906
+
1907
+ expected = dpt .take (x , ind , axis = axis )
1908
+
1909
+ dpt .take (x , ind , axis = axis , out = out )
1910
+
1911
+ assert dpt .all (out == expected )
1912
+
1913
+
1914
+ @pytest .mark .parametrize ("data_dt" , _all_dtypes )
1915
+ @pytest .mark .parametrize ("order" , ["C" , "F" ])
1916
+ def test_take_out_overlap (data_dt , order ):
1917
+ q = get_queue_or_skip ()
1918
+ skip_if_dtype_not_supported (data_dt , q )
1919
+
1920
+ axis = 0
1921
+ x = dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
1922
+ ind = dpt .arange (2 , dtype = "i8" , sycl_queue = q )
1923
+ out = x [x .shape [axis ] - ind .shape [axis ] : x .shape [axis ], :]
1924
+
1925
+ expected = dpt .take (x , ind , axis = axis )
1926
+
1927
+ dpt .take (x , ind , axis = axis , out = out )
1928
+
1929
+ assert dpt .all (out == expected )
1930
+ assert dpt .all (x [x .shape [0 ] - ind .shape [0 ] : x .shape [0 ], :] == out )
1931
+
1932
+
1933
+ def test_take_out_errors ():
1934
+ q1 = get_queue_or_skip ()
1935
+ q2 = get_queue_or_skip ()
1936
+
1937
+ x = dpt .arange (10 , dtype = "i4" , sycl_queue = q1 )
1938
+ ind = dpt .arange (2 , dtype = "i4" , sycl_queue = q1 )
1939
+
1940
+ with pytest .raises (TypeError ):
1941
+ dpt .take (x , ind , out = dict ())
1942
+
1943
+ out_read_only = dpt .empty (ind .shape , dtype = x .dtype , sycl_queue = q1 )
1944
+ out_read_only .flags ["W" ] = False
1945
+ with pytest .raises (ValueError ):
1946
+ dpt .take (x , ind , out = out_read_only )
1947
+
1948
+ out_bad_shape = dpt .empty (0 , dtype = x .dtype , sycl_queue = q1 )
1949
+ with pytest .raises (ValueError ):
1950
+ dpt .take (x , ind , out = out_bad_shape )
1951
+
1952
+ out_bad_dt = dpt .empty (ind .shape , dtype = "i8" , sycl_queue = q1 )
1953
+ with pytest .raises (ValueError ):
1954
+ dpt .take (x , ind , out = out_bad_dt )
1955
+
1956
+ out_bad_q = dpt .empty (ind .shape , dtype = x .dtype , sycl_queue = q2 )
1957
+ with pytest .raises (dpctl .utils .ExecutionPlacementError ):
1958
+ dpt .take (x , ind , out = out_bad_q )
0 commit comments