@@ -763,44 +763,66 @@ def _nonzero_impl(ary):
763
763
764
764
def _take_multi_index (ary , inds , p ):
765
765
if not isinstance (ary , dpt .usm_ndarray ):
766
- raise TypeError
766
+ raise TypeError (
767
+ f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
768
+ )
769
+ ary_nd = ary .ndim
770
+ p = normalize_axis_index (operator .index (p ), ary_nd )
767
771
queues_ = [
768
772
ary .sycl_queue ,
769
773
]
770
774
usm_types_ = [
771
775
ary .usm_type ,
772
776
]
773
- if not isinstance (inds , list ) and not isinstance ( inds , tuple ):
777
+ if not isinstance (inds , ( list , tuple ) ):
774
778
inds = (inds ,)
775
- all_integers = True
776
779
for ind in inds :
780
+ if not isinstance (ind , dpt .usm_ndarray ):
781
+ raise TypeError ("all elements of `ind` expected to be usm_ndarrays" )
777
782
queues_ .append (ind .sycl_queue )
778
783
usm_types_ .append (ind .usm_type )
779
- if all_integers :
780
- all_integers = ind .dtype .kind in "ui"
784
+ if ind .dtype .kind not in "ui" :
785
+ raise IndexError (
786
+ "arrays used as indices must be of integer (or boolean) type"
787
+ )
788
+ res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
781
789
exec_q = dpctl .utils .get_execution_queue (queues_ )
782
790
if exec_q is None :
783
- raise dpctl .utils .ExecutionPlacementError ("" )
784
- if not all_integers :
785
- raise IndexError (
786
- "arrays used as indices must be of integer (or boolean) type"
791
+ raise dpctl .utils .ExecutionPlacementError (
792
+ "Can not automatically determine where to allocate the "
793
+ "result or performance execution. "
794
+ "Use `usm_ndarray.to_device` method to migrate data to "
795
+ "be associated with the same queue."
787
796
)
788
797
if len (inds ) > 1 :
798
+ ind_dt = dpt .result_type (* inds )
799
+ # ind arrays have been checked to be of integer dtype
800
+ if ind_dt .kind not in "ui" :
801
+ raise ValueError (
802
+ "cannot safely promote indices to an integer data type"
803
+ )
804
+ inds = tuple (
805
+ map (
806
+ lambda ind : ind
807
+ if ind .dtype == ind_dt
808
+ else dpt .astype (ind , ind_dt ),
809
+ inds ,
810
+ )
811
+ )
789
812
inds = dpt .broadcast_arrays (* inds )
790
- ary_ndim = ary .ndim
791
- p = normalize_axis_index (operator .index (p ), ary_ndim )
792
-
793
- res_shape = ary .shape [:p ] + inds [0 ].shape + ary .shape [p + len (inds ) :]
794
- res_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
813
+ ind0 = inds [0 ]
814
+ ary_sh = ary .shape
815
+ p_end = p + len (inds )
816
+ if 0 in ary_sh [p :p_end ] and ind0 .size != 0 :
817
+ raise IndexError ("cannot take non-empty indices from an empty axis" )
818
+ res_shape = ary_sh [:p ] + ind0 .shape + ary_sh [p_end :]
795
819
res = dpt .empty (
796
820
res_shape , dtype = ary .dtype , usm_type = res_usm_type , sycl_queue = exec_q
797
821
)
798
-
799
822
hev , _ = ti ._take (
800
823
src = ary , ind = inds , dst = res , axis_start = p , mode = 0 , sycl_queue = exec_q
801
824
)
802
825
hev .wait ()
803
-
804
826
return res
805
827
806
828
@@ -864,6 +886,12 @@ def _place_impl(ary, ary_mask, vals, axis=0):
864
886
865
887
866
888
def _put_multi_index (ary , inds , p , vals ):
889
+ if not isinstance (ary , dpt .usm_ndarray ):
890
+ raise TypeError (
891
+ f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
892
+ )
893
+ ary_nd = ary .ndim
894
+ p = normalize_axis_index (operator .index (p ), ary_nd )
867
895
if isinstance (vals , dpt .usm_ndarray ):
868
896
queues_ = [ary .sycl_queue , vals .sycl_queue ]
869
897
usm_types_ = [ary .usm_type , vals .usm_type ]
@@ -874,46 +902,64 @@ def _put_multi_index(ary, inds, p, vals):
874
902
usm_types_ = [
875
903
ary .usm_type ,
876
904
]
877
- if not isinstance (inds , list ) and not isinstance ( inds , tuple ):
905
+ if not isinstance (inds , ( list , tuple ) ):
878
906
inds = (inds ,)
879
- all_integers = True
880
907
for ind in inds :
881
908
if not isinstance (ind , dpt .usm_ndarray ):
882
- raise TypeError
909
+ raise TypeError ( "all elements of `ind` expected to be usm_ndarrays" )
883
910
queues_ .append (ind .sycl_queue )
884
911
usm_types_ .append (ind .usm_type )
885
- if all_integers :
886
- all_integers = ind .dtype .kind in "ui"
912
+ if ind .dtype .kind not in "ui" :
913
+ raise IndexError (
914
+ "arrays used as indices must be of integer (or boolean) type"
915
+ )
916
+ vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
887
917
exec_q = dpctl .utils .get_execution_queue (queues_ )
918
+ if exec_q is not None :
919
+ if not isinstance (vals , dpt .usm_ndarray ):
920
+ vals = dpt .asarray (
921
+ vals , dtype = ary .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
922
+ )
923
+ else :
924
+ exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
888
925
if exec_q is None :
889
926
raise dpctl .utils .ExecutionPlacementError (
890
927
"Can not automatically determine where to allocate the "
891
928
"result or performance execution. "
892
929
"Use `usm_ndarray.to_device` method to migrate data to "
893
930
"be associated with the same queue."
894
931
)
895
- if not all_integers :
896
- raise IndexError (
897
- "arrays used as indices must be of integer (or boolean) type"
898
- )
899
932
if len (inds ) > 1 :
933
+ ind_dt = dpt .result_type (* inds )
934
+ # ind arrays have been checked to be of integer dtype
935
+ if ind_dt .kind not in "ui" :
936
+ raise ValueError (
937
+ "cannot safely promote indices to an integer data type"
938
+ )
939
+ inds = tuple (
940
+ map (
941
+ lambda ind : ind
942
+ if ind .dtype == ind_dt
943
+ else dpt .astype (ind , ind_dt ),
944
+ inds ,
945
+ )
946
+ )
900
947
inds = dpt .broadcast_arrays (* inds )
901
- ary_ndim = ary .ndim
902
-
903
- p = normalize_axis_index (operator .index (p ), ary_ndim )
904
- vals_shape = ary .shape [:p ] + inds [0 ].shape + ary .shape [p + len (inds ) :]
905
-
906
- vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
907
- if not isinstance (vals , dpt .usm_ndarray ):
908
- vals = dpt .asarray (
909
- vals , dtype = ary .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
948
+ ind0 = inds [0 ]
949
+ ary_sh = ary .shape
950
+ p_end = p + len (inds )
951
+ if 0 in ary_sh [p :p_end ] and ind0 .size != 0 :
952
+ raise IndexError (
953
+ "cannot put into non-empty indices along an empty axis"
910
954
)
911
-
912
- vals = dpt .broadcast_to (vals , vals_shape )
913
-
955
+ expected_vals_shape = ary_sh [:p ] + ind0 .shape + ary_sh [p_end :]
956
+ if vals .dtype == ary .dtype :
957
+ rhs = vals
958
+ else :
959
+ rhs = dpt .astype (vals , ary .dtype )
960
+ rhs = dpt .broadcast_to (rhs , expected_vals_shape )
914
961
hev , _ = ti ._put (
915
- dst = ary , ind = inds , val = vals , axis_start = p , mode = 0 , sycl_queue = exec_q
962
+ dst = ary , ind = inds , val = rhs , axis_start = p , mode = 0 , sycl_queue = exec_q
916
963
)
917
964
hev .wait ()
918
-
919
965
return
0 commit comments