@@ -181,15 +181,21 @@ def transact_cache_direct(
         return xk_cache_update, xv_cache_update
     else:
         # Decode. Write a single timestep.
-        # TODO: This needs to be reworked with index ops.
         assert xk_cache_update.shape[1] == 1
         assert xv_cache_update.shape[1] == 1
-        max_start_pos = 0
-        for row_index in range(bs):
-            row_start_pos = start_positions[row_index].item()
-            max_start_pos = max(row_start_pos, max_start_pos)
-            cache_k[row_index, row_start_pos] = xk_cache_update[row_index, 0]
-            cache_v[row_index, row_start_pos] = xv_cache_update[row_index, 0]
+        for b in range(bs):
+            # Make the batch index a tensor because index_put_ requires all indices to be
+            # tensors; this avoids start_positions[b].item(), which generates a lot of SymInts.
+            row_index = torch.tensor(
+                b, dtype=torch.int64, device=xk_cache_update.device
+            )
+            row_start_pos = start_positions[row_index]
+            cache_k.index_put_(
+                (row_index, row_start_pos), xk_cache_update[row_index, 0]
+            )
+            cache_v.index_put_(
+                (row_index, row_start_pos), xv_cache_update[row_index, 0]
+            )
     return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len]
 
 def transact_cache_paged(
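
For context, here is a minimal standalone sketch of the tensor-index pattern the new decode loop relies on: every index passed to index_put_ is a tensor, so no .item() call (and therefore no SymInt) is needed when the loop is traced. The shapes and start positions below are made up purely for illustration and are not taken from the model.

import torch

# Illustrative shapes only; the real model's dimensions differ.
bs, max_seq_len, n_heads, head_dim = 2, 16, 4, 8

cache_k = torch.zeros(bs, max_seq_len, n_heads, head_dim)
xk_cache_update = torch.randn(bs, 1, n_heads, head_dim)  # one decode timestep
start_positions = torch.tensor([5, 9], dtype=torch.int64)  # per-row write offsets

for b in range(bs):
    # Wrap the batch index in a 0-dim tensor so the index tuple is all tensors.
    row_index = torch.tensor(b, dtype=torch.int64)
    row_start_pos = start_positions[row_index]
    # In-place scatter of the single new timestep into this row's slot.
    cache_k.index_put_((row_index, row_start_pos), xk_cache_update[row_index, 0])

# Each row's new K values now sit at that row's own start position.
assert torch.equal(cache_k[0, 5], xk_cache_update[0, 0])
assert torch.equal(cache_k[1, 9], xk_cache_update[1, 0])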