@@ -625,7 +625,7 @@ def test_put_0d_val(data_dt):
625625    skip_if_dtype_not_supported (data_dt , q )
626626
627627    x  =  dpt .arange (5 , dtype = data_dt , sycl_queue = q )
628-     ind  =  dpt .asarray ([0 ], dtype = np . intp , sycl_queue = q )
628+     ind  =  dpt .asarray ([0 ], dtype = "i8" , sycl_queue = q )
629629    val  =  dpt .asarray (2 , dtype = x .dtype , sycl_queue = q )
630630    x [ind ] =  val 
631631    assert_array_equal (np .asarray (2 , dtype = data_dt ), dpt .asnumpy (x [0 ]))
@@ -644,7 +644,7 @@ def test_take_0d_data(data_dt):
644644    skip_if_dtype_not_supported (data_dt , q )
645645
646646    x  =  dpt .asarray (0 , dtype = data_dt , sycl_queue = q )
647-     ind  =  dpt .arange (5 , dtype = np . intp , sycl_queue = q )
647+     ind  =  dpt .arange (5 , dtype = "i8" , sycl_queue = q )
648648
649649    y  =  dpt .take (x , ind )
650650    assert  (
@@ -662,7 +662,7 @@ def test_put_0d_data(data_dt):
662662    skip_if_dtype_not_supported (data_dt , q )
663663
664664    x  =  dpt .asarray (0 , dtype = data_dt , sycl_queue = q )
665-     ind  =  dpt .arange (5 , dtype = np . intp , sycl_queue = q )
665+     ind  =  dpt .arange (5 , dtype = "i8" , sycl_queue = q )
666666    val  =  dpt .asarray (2 , dtype = data_dt , sycl_queue = q )
667667
668668    dpt .put (x , ind , val , axis = 0 )
@@ -710,7 +710,7 @@ def test_take_strided_1d_source(data_dt):
710710    skip_if_dtype_not_supported (data_dt , q )
711711
712712    x  =  dpt .arange (27 , dtype = data_dt , sycl_queue = q )
713-     ind  =  dpt .arange (4 , 9 , dtype = np . intp , sycl_queue = q )
713+     ind  =  dpt .arange (4 , 9 , dtype = "i8" , sycl_queue = q )
714714
715715    x_np  =  dpt .asnumpy (x )
716716    ind_np  =  dpt .asnumpy (ind )
@@ -748,7 +748,7 @@ def test_take_strided(data_dt, order):
748748    skip_if_dtype_not_supported (data_dt , q )
749749
750750    x  =  dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
751-     ind  =  dpt .arange (2 , dtype = np . intp , sycl_queue = q )
751+     ind  =  dpt .arange (2 , dtype = "i8" , sycl_queue = q )
752752
753753    x_np  =  dpt .asnumpy (x )
754754    ind_np  =  dpt .asnumpy (ind )
@@ -781,7 +781,7 @@ def test_take_strided_1d_indices(ind_dt):
781781    ind  =  dpt .arange (12 , 24 , dtype = ind_dt , sycl_queue = q )
782782
783783    x_np  =  dpt .asnumpy (x )
784-     ind_np  =  dpt .asnumpy (ind ).astype (np . intp )
784+     ind_np  =  dpt .asnumpy (ind ).astype ("i8" )
785785
786786    for  s  in  (
787787        slice (None , None , 2 ),
@@ -820,7 +820,7 @@ def test_take_strided_indices(ind_dt, order):
820820    )
821821
822822    x_np  =  dpt .asnumpy (x )
823-     ind_np  =  dpt .asnumpy (ind ).astype (np . intp )
823+     ind_np  =  dpt .asnumpy (ind ).astype ("i8" )
824824
825825    for  s  in  (
826826        slice (None , None , 2 ),
@@ -845,7 +845,7 @@ def test_put_strided_1d_destination(data_dt, order):
845845    skip_if_dtype_not_supported (data_dt , q )
846846
847847    x  =  dpt .arange (27 , dtype = data_dt , sycl_queue = q )
848-     ind  =  dpt .arange (4 , 9 , dtype = np . intp , sycl_queue = q )
848+     ind  =  dpt .arange (4 , 9 , dtype = "i8" , sycl_queue = q )
849849    val  =  dpt .asarray (9 , dtype = x .dtype , sycl_queue = q )
850850
851851    x_np  =  dpt .asnumpy (x )
@@ -875,7 +875,7 @@ def test_put_strided_destination(data_dt, order):
875875    skip_if_dtype_not_supported (data_dt , q )
876876
877877    x  =  dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
878-     ind  =  dpt .arange (2 , dtype = np . intp , sycl_queue = q )
878+     ind  =  dpt .arange (2 , dtype = "i8" , sycl_queue = q )
879879    val  =  dpt .asarray (9 , dtype = x .dtype , sycl_queue = q )
880880
881881    x_np  =  dpt .asnumpy (x )
@@ -924,7 +924,7 @@ def test_put_strided_1d_indices(ind_dt):
924924    val  =  dpt .asarray (- 1 , dtype = x .dtype , sycl_queue = q )
925925
926926    x_np  =  dpt .asnumpy (x )
927-     ind_np  =  dpt .asnumpy (ind ).astype (np . intp )
927+     ind_np  =  dpt .asnumpy (ind ).astype ("i8" )
928928    val_np  =  dpt .asnumpy (val )
929929
930930    for  s  in  (
@@ -955,7 +955,7 @@ def test_put_strided_indices(ind_dt, order):
955955    val  =  dpt .asarray (- 1 , sycl_queue = q , dtype = x .dtype )
956956
957957    x_np  =  dpt .asnumpy (x )
958-     ind_np  =  dpt .asnumpy (ind ).astype (np . intp )
958+     ind_np  =  dpt .asnumpy (ind ).astype ("i8" )
959959    val_np  =  dpt .asnumpy (val )
960960
961961    for  s  in  (
@@ -982,15 +982,15 @@ def test_integer_indexing_modes():
982982    x_np  =  dpt .asnumpy (x )
983983
984984    # wrapping negative indices 
985-     ind  =  dpt .asarray ([- 4 , - 3 , 0 , 2 , 4 ], dtype = np . intp , sycl_queue = q )
985+     ind  =  dpt .asarray ([- 4 , - 3 , 0 , 2 , 4 ], dtype = "i8" , sycl_queue = q )
986986
987987    res  =  dpt .take (x , ind , mode = "wrap" )
988988    expected_arr  =  np .take (x_np , dpt .asnumpy (ind ), mode = "raise" )
989989
990990    assert  (dpt .asnumpy (res ) ==  expected_arr ).all ()
991991
992992    # clipping to 0 (disabling negative indices) 
993-     ind  =  dpt .asarray ([- 6 , - 3 , 0 , 2 , 6 ], dtype = np . intp , sycl_queue = q )
993+     ind  =  dpt .asarray ([- 6 , - 3 , 0 , 2 , 6 ], dtype = "i8" , sycl_queue = q )
994994
995995    res  =  dpt .take (x , ind , mode = "clip" )
996996    expected_arr  =  np .take (x_np , dpt .asnumpy (ind ), mode = "clip" )
@@ -1002,7 +1002,7 @@ def test_take_arg_validation():
10021002    q  =  get_queue_or_skip ()
10031003
10041004    x  =  dpt .arange (4 , dtype = "i4" , sycl_queue = q )
1005-     ind0  =  dpt .arange (4 , dtype = np . intp , sycl_queue = q )
1005+     ind0  =  dpt .arange (4 , dtype = "i8" , sycl_queue = q )
10061006    ind1  =  dpt .arange (2.0 , dtype = "f" , sycl_queue = q )
10071007
10081008    with  pytest .raises (TypeError ):
@@ -1034,7 +1034,7 @@ def test_put_arg_validation():
10341034    q  =  get_queue_or_skip ()
10351035
10361036    x  =  dpt .arange (4 , dtype = "i4" , sycl_queue = q )
1037-     ind0  =  dpt .arange (4 , dtype = np . intp , sycl_queue = q )
1037+     ind0  =  dpt .arange (4 , dtype = "i8" , sycl_queue = q )
10381038    ind1  =  dpt .arange (2.0 , dtype = "f" , sycl_queue = q )
10391039    val  =  dpt .asarray (2 , dtype = x .dtype , sycl_queue = q )
10401040
@@ -1890,3 +1890,69 @@ def test_put_along_axis_uint64_indices():
18901890    dpt .put_along_axis (x , inds , dpt .asarray (2 , dtype = x .dtype ), axis = 1 )
18911891    expected  =  dpt .tile (dpt .asarray ([0 , 2 ], dtype = "i4" ), (2 , 5 ))
18921892    assert  dpt .all (expected  ==  x )
1893+ 
1894+ 
1895+ @pytest .mark .parametrize ("data_dt" , _all_dtypes ) 
1896+ @pytest .mark .parametrize ("order" , ["C" , "F" ]) 
1897+ def  test_take_out (data_dt , order ):
1898+     q  =  get_queue_or_skip ()
1899+     skip_if_dtype_not_supported (data_dt , q )
1900+ 
1901+     axis  =  0 
1902+     x  =  dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
1903+     ind  =  dpt .arange (2 , dtype = "i8" , sycl_queue = q )
1904+     out_sh  =  x .shape [:axis ] +  ind .shape  +  x .shape [axis  +  1  :]
1905+     out  =  dpt .empty (out_sh , dtype = data_dt , sycl_queue = q )
1906+ 
1907+     expected  =  dpt .take (x , ind , axis = axis )
1908+ 
1909+     dpt .take (x , ind , axis = axis , out = out )
1910+ 
1911+     assert  dpt .all (out  ==  expected )
1912+ 
1913+ 
1914+ @pytest .mark .parametrize ("data_dt" , _all_dtypes ) 
1915+ @pytest .mark .parametrize ("order" , ["C" , "F" ]) 
1916+ def  test_take_out_overlap (data_dt , order ):
1917+     q  =  get_queue_or_skip ()
1918+     skip_if_dtype_not_supported (data_dt , q )
1919+ 
1920+     axis  =  0 
1921+     x  =  dpt .reshape (_make_3d (data_dt , q ), (9 , 3 ), order = order )
1922+     ind  =  dpt .arange (2 , dtype = "i8" , sycl_queue = q )
1923+     out  =  x [x .shape [axis ] -  ind .shape [axis ] : x .shape [axis ], :]
1924+ 
1925+     expected  =  dpt .take (x , ind , axis = axis )
1926+ 
1927+     dpt .take (x , ind , axis = axis , out = out )
1928+ 
1929+     assert  dpt .all (out  ==  expected )
1930+     assert  dpt .all (x [x .shape [0 ] -  ind .shape [0 ] : x .shape [0 ], :] ==  out )
1931+ 
1932+ 
1933+ def  test_take_out_errors ():
1934+     q1  =  get_queue_or_skip ()
1935+     q2  =  get_queue_or_skip ()
1936+ 
1937+     x  =  dpt .arange (10 , dtype = "i4" , sycl_queue = q1 )
1938+     ind  =  dpt .arange (2 , dtype = "i4" , sycl_queue = q1 )
1939+ 
1940+     with  pytest .raises (TypeError ):
1941+         dpt .take (x , ind , out = dict ())
1942+ 
1943+     out_read_only  =  dpt .empty (ind .shape , dtype = x .dtype , sycl_queue = q1 )
1944+     out_read_only .flags ["W" ] =  False 
1945+     with  pytest .raises (ValueError ):
1946+         dpt .take (x , ind , out = out_read_only )
1947+ 
1948+     out_bad_shape  =  dpt .empty (0 , dtype = x .dtype , sycl_queue = q1 )
1949+     with  pytest .raises (ValueError ):
1950+         dpt .take (x , ind , out = out_bad_shape )
1951+ 
1952+     out_bad_dt  =  dpt .empty (ind .shape , dtype = "i8" , sycl_queue = q1 )
1953+     with  pytest .raises (ValueError ):
1954+         dpt .take (x , ind , out = out_bad_dt )
1955+ 
1956+     out_bad_q  =  dpt .empty (ind .shape , dtype = x .dtype , sycl_queue = q2 )
1957+     with  pytest .raises (dpctl .utils .ExecutionPlacementError ):
1958+         dpt .take (x , ind , out = out_bad_q )
0 commit comments