@@ -10835,29 +10835,33 @@ def replace(tup, val):
10835
10835
collapsed_slice_dims = []
10836
10836
operand_batching_dims = []
10837
10837
start_indices_batching_dims = []
10838
+
10839
+ # We will squeeze the array. i is the index of the unsqueezed shape, while
10840
+ # new_i is the index of the squeezed shape. j is the index of the gather
10841
+ # indices.
10842
+ dims_to_squeeze = []
10843
+ new_i = 0
10838
10844
j = 0
10839
10845
for i in range (rank ):
10840
10846
if i == axis_int :
10841
10847
if mode != 'promise_in_bounds' :
10842
10848
indices = _normalize_index (indices , axis_size )
10843
10849
gather_indices .append (lax .reshape (indices , gather_index_shape ))
10844
10850
slice_sizes .append (1 )
10845
- start_index_map .append (i )
10846
- collapsed_slice_dims .append (i )
10851
+ start_index_map .append (new_i )
10852
+ collapsed_slice_dims .append (new_i )
10853
+ new_i += 1
10847
10854
j += 1
10848
10855
elif core .definitely_equal (idx_shape [i ], 1 ):
10849
10856
# If idx_shape[i] == 1, we can just take the entirety of the arr's axis
10850
10857
# and avoid forming an iota index.
10851
10858
offset_dims .append (i )
10852
10859
slice_sizes .append (arr_shape [i ])
10860
+ new_i += 1
10853
10861
elif core .definitely_equal (arr_shape [i ], 1 ):
10854
- # If the array dimension is 1 but the index dimension is not, we
10855
- # broadcast the array dimension to the index dimension by repeatedly
10856
- # gathering the first element.
10857
- gather_indices .append (zeros (gather_index_shape , dtype = index_dtype ))
10858
- slice_sizes .append (1 )
10859
- start_index_map .append (i )
10860
- collapsed_slice_dims .append (i )
10862
+ # If the array dimension is 1 but the index dimension is not, we will
10863
+ # squeeze this dimension.
10864
+ dims_to_squeeze .append (i )
10861
10865
j += 1
10862
10866
else :
10863
10867
# Otherwise, idx_shape[i] == arr_shape[i]. Mark the dimensions in both
@@ -10866,10 +10870,13 @@ def replace(tup, val):
10866
10870
slice_sizes .append (0 )
10867
10871
else :
10868
10872
slice_sizes .append (1 )
10869
- operand_batching_dims .append (i )
10873
+ operand_batching_dims .append (new_i )
10870
10874
start_indices_batching_dims .append (j )
10875
+ new_i += 1
10871
10876
j += 1
10872
10877
10878
+ # Squeeze a to remove singleton dimensions.
10879
+ a = lax .squeeze (a , dims_to_squeeze )
10873
10880
gather_indices_arr = lax .concatenate (gather_indices , dimension = j )
10874
10881
dnums = lax .GatherDimensionNumbers (
10875
10882
offset_dims = tuple (offset_dims ),
0 commit comments