@@ -502,7 +502,6 @@ def index_put_converter(
502
502
F = [i for i in range (rank ) if indices [i ] is None ] # Free dimensions
503
503
I = [i for i in range (rank ) if indices [i ] is not None ] # Indexed dimensions
504
504
K = len (I )
505
-
506
505
# Determine the maximum size 'N' among the index tensors
507
506
if K > 0 :
508
507
index_shapes = [tensor .shape [0 ] for tensor in indices if tensor is not None ]
@@ -685,16 +684,6 @@ def index_put_converter(
685
684
values_reshaped = impl .shuffle .reshape (
686
685
ctx , target , source_ir , f"{ name } _reshape_scalar" , values , (1 ,)
687
686
)
688
- num_dims = len (expected_shape )
689
- ones_shape = tuple ([1 ] * num_dims )
690
- values_reshaped = impl .shuffle .reshape (
691
- ctx ,
692
- target ,
693
- source_ir ,
694
- f"{ name } _reshape_to_ones" ,
695
- values_reshaped ,
696
- ones_shape ,
697
- )
698
687
values_expanded = impl .slice .expand (
699
688
ctx ,
700
689
target ,
@@ -705,40 +694,79 @@ def index_put_converter(
705
694
)
706
695
else : # Non-scalar case
707
696
values_shape = list (values .shape )
708
-
709
- # Pad dimensions if necessary
710
- if len (values_shape ) < len (expected_shape ):
711
- values_shape = [1 ] * (
712
- len (expected_shape ) - len (values_shape )
713
- ) + values_shape
714
-
715
- # Calculate a broadcastable shape
716
- broadcast_shape = []
717
- for exp_dim , val_dim in zip (expected_shape , values_shape ):
718
- if val_dim == 1 :
719
- broadcast_shape .append (exp_dim )
720
- elif val_dim == exp_dim :
721
- broadcast_shape .append (val_dim )
697
+ if K > 0 and N in values_shape :
698
+ n_idx = values_shape .index (N )
699
+ permute_order = [n_idx ] + [
700
+ i for i in range (len (values_shape )) if i != n_idx
701
+ ]
702
+ values_permuted = impl .permutation .permute (
703
+ ctx , target , source_ir , f"{ name } _permute_values" , values , permute_order
704
+ )
705
+ remaining_shape = [
706
+ values_shape [i ] for i in range (len (values_shape )) if i != n_idx
707
+ ]
708
+ target_f_dims = len (F )
709
+ current_f_dims = len (remaining_shape )
710
+ if current_f_dims < target_f_dims :
711
+ values_expanded_shape = (
712
+ [N ] + [1 ] * (target_f_dims - current_f_dims ) + remaining_shape
713
+ )
722
714
else :
723
- raise ValueError (f"Cannot broadcast { values_shape } to { expected_shape } " )
724
-
725
- # Reshape and then expand
726
- values_reshaped = impl .shuffle .reshape (
727
- ctx ,
728
- target ,
729
- source_ir ,
730
- f"{ name } _reshape_values" ,
731
- values ,
732
- tuple (broadcast_shape ),
733
- )
734
- values_expanded = impl .slice .expand (
735
- ctx ,
736
- target ,
737
- source_ir ,
738
- f"{ name } _expand_values" ,
739
- values_reshaped ,
740
- expected_shape ,
741
- )
715
+ values_expanded_shape = [N ] + remaining_shape [:target_f_dims ]
716
+ values_expanded = impl .shuffle .reshape (
717
+ ctx ,
718
+ target ,
719
+ source_ir ,
720
+ f"{ name } _unsqueeze_values" ,
721
+ values_permuted ,
722
+ tuple (values_expanded_shape ),
723
+ )
724
+ broadcast_shape = []
725
+ for exp_dim , val_dim in zip (expected_shape , values_expanded_shape ):
726
+ if val_dim == 1 :
727
+ broadcast_shape .append (exp_dim )
728
+ elif val_dim == exp_dim :
729
+ broadcast_shape .append (val_dim )
730
+ else :
731
+ raise ValueError (
732
+ f"Cannot broadcast { values_expanded_shape } to { expected_shape } "
733
+ )
734
+ values_expanded = impl .slice .expand (
735
+ ctx ,
736
+ target ,
737
+ source_ir ,
738
+ f"{ name } _expand_values" ,
739
+ values_expanded ,
740
+ tuple (broadcast_shape ),
741
+ )
742
+ else :
743
+ values_shape_padded = [1 ] * (
744
+ len (expected_shape ) - len (values .shape )
745
+ ) + list (values .shape )
746
+ broadcast_shape = []
747
+ for exp_dim , val_dim in zip (expected_shape , values_shape_padded ):
748
+ if val_dim == 1 or exp_dim == val_dim :
749
+ broadcast_shape .append (exp_dim )
750
+ else :
751
+ raise ValueError (
752
+ f"Cannot broadcast { values .shape } to { expected_shape } "
753
+ )
754
+ values_reshaped = impl .shuffle .reshape (
755
+ ctx ,
756
+ target ,
757
+ source_ir ,
758
+ f"{ name } _reshape_values" ,
759
+ values ,
760
+ tuple (broadcast_shape ),
761
+ )
762
+ values_expanded = impl .slice .expand (
763
+ ctx ,
764
+ target ,
765
+ source_ir ,
766
+ f"{ name } _expand_values" ,
767
+ values_reshaped ,
768
+ expected_shape ,
769
+ )
742
770
743
771
# Flatten values to (N * F_volume,)
744
772
flattened_values = impl .shuffle .reshape (
@@ -750,6 +778,7 @@ def index_put_converter(
750
778
(N * F_volume ,),
751
779
)
752
780
781
+ indices_cat = cast_trt_tensor (ctx , indices_cat , trt .int32 , f"{ name } _idx_int32" )
753
782
# Perform Scatter ND operation
754
783
scatter_layer = ctx .net .add_scatter (
755
784
input_tensor ,
0 commit comments