@@ -502,7 +502,6 @@ def index_put_converter(
502502 F = [i for i in range (rank ) if indices [i ] is None ] # Free dimensions
503503 I = [i for i in range (rank ) if indices [i ] is not None ] # Indexed dimensions
504504 K = len (I )
505-
506505 # Determine the maximum size 'N' among the index tensors
507506 if K > 0 :
508507 index_shapes = [tensor .shape [0 ] for tensor in indices if tensor is not None ]
@@ -685,16 +684,6 @@ def index_put_converter(
685684 values_reshaped = impl .shuffle .reshape (
686685 ctx , target , source_ir , f"{ name } _reshape_scalar" , values , (1 ,)
687686 )
688- num_dims = len (expected_shape )
689- ones_shape = tuple ([1 ] * num_dims )
690- values_reshaped = impl .shuffle .reshape (
691- ctx ,
692- target ,
693- source_ir ,
694- f"{ name } _reshape_to_ones" ,
695- values_reshaped ,
696- ones_shape ,
697- )
698687 values_expanded = impl .slice .expand (
699688 ctx ,
700689 target ,
@@ -705,40 +694,79 @@ def index_put_converter(
705694 )
706695 else : # Non-scalar case
707696 values_shape = list (values .shape )
708-
709- # Pad dimensions if necessary
710- if len (values_shape ) < len (expected_shape ):
711- values_shape = [1 ] * (
712- len (expected_shape ) - len (values_shape )
713- ) + values_shape
714-
715- # Calculate a broadcastable shape
716- broadcast_shape = []
717- for exp_dim , val_dim in zip (expected_shape , values_shape ):
718- if val_dim == 1 :
719- broadcast_shape .append (exp_dim )
720- elif val_dim == exp_dim :
721- broadcast_shape .append (val_dim )
697+ if K > 0 and N in values_shape :
698+ n_idx = values_shape .index (N )
699+ permute_order = [n_idx ] + [
700+ i for i in range (len (values_shape )) if i != n_idx
701+ ]
702+ values_permuted = impl .permutation .permute (
703+ ctx , target , source_ir , f"{ name } _permute_values" , values , permute_order
704+ )
705+ remaining_shape = [
706+ values_shape [i ] for i in range (len (values_shape )) if i != n_idx
707+ ]
708+ target_f_dims = len (F )
709+ current_f_dims = len (remaining_shape )
710+ if current_f_dims < target_f_dims :
711+ values_expanded_shape = (
712+ [N ] + [1 ] * (target_f_dims - current_f_dims ) + remaining_shape
713+ )
722714 else :
723- raise ValueError (f"Cannot broadcast { values_shape } to { expected_shape } " )
724-
725- # Reshape and then expand
726- values_reshaped = impl .shuffle .reshape (
727- ctx ,
728- target ,
729- source_ir ,
730- f"{ name } _reshape_values" ,
731- values ,
732- tuple (broadcast_shape ),
733- )
734- values_expanded = impl .slice .expand (
735- ctx ,
736- target ,
737- source_ir ,
738- f"{ name } _expand_values" ,
739- values_reshaped ,
740- expected_shape ,
741- )
715+ values_expanded_shape = [N ] + remaining_shape [:target_f_dims ]
716+ values_expanded = impl .shuffle .reshape (
717+ ctx ,
718+ target ,
719+ source_ir ,
720+ f"{ name } _unsqueeze_values" ,
721+ values_permuted ,
722+ tuple (values_expanded_shape ),
723+ )
724+ broadcast_shape = []
725+ for exp_dim , val_dim in zip (expected_shape , values_expanded_shape ):
726+ if val_dim == 1 :
727+ broadcast_shape .append (exp_dim )
728+ elif val_dim == exp_dim :
729+ broadcast_shape .append (val_dim )
730+ else :
731+ raise ValueError (
732+ f"Cannot broadcast { values_expanded_shape } to { expected_shape } "
733+ )
734+ values_expanded = impl .slice .expand (
735+ ctx ,
736+ target ,
737+ source_ir ,
738+ f"{ name } _expand_values" ,
739+ values_expanded ,
740+ tuple (broadcast_shape ),
741+ )
742+ else :
743+ values_shape_padded = [1 ] * (
744+ len (expected_shape ) - len (values .shape )
745+ ) + list (values .shape )
746+ broadcast_shape = []
747+ for exp_dim , val_dim in zip (expected_shape , values_shape_padded ):
748+ if val_dim == 1 or exp_dim == val_dim :
749+ broadcast_shape .append (exp_dim )
750+ else :
751+ raise ValueError (
752+ f"Cannot broadcast { values .shape } to { expected_shape } "
753+ )
754+ values_reshaped = impl .shuffle .reshape (
755+ ctx ,
756+ target ,
757+ source_ir ,
758+ f"{ name } _reshape_values" ,
759+ values ,
760+ tuple (broadcast_shape ),
761+ )
762+ values_expanded = impl .slice .expand (
763+ ctx ,
764+ target ,
765+ source_ir ,
766+ f"{ name } _expand_values" ,
767+ values_reshaped ,
768+ expected_shape ,
769+ )
742770
743771 # Flatten values to (N * F_volume,)
744772 flattened_values = impl .shuffle .reshape (
@@ -750,6 +778,7 @@ def index_put_converter(
750778 (N * F_volume ,),
751779 )
752780
781+ indices_cat = cast_trt_tensor (ctx , indices_cat , trt .int32 , f"{ name } _idx_int32" )
753782 # Perform Scatter ND operation
754783 scatter_layer = ctx .net .add_scatter (
755784 input_tensor ,
0 commit comments