Commit e52b086

replaces some logic with index ops (#176)
fixes bs1/direct cache exports.
1 parent: 4957412

File tree

1 file changed: +13 -7


sharktank/sharktank/layers/paged_llama_attention_block.py

@@ -181,15 +181,21 @@ def transact_cache_direct(
             return xk_cache_update, xv_cache_update
         else:
             # Decode. Write a single timestep.
-            # TODO: This needs to be reworked with index ops.
             assert xk_cache_update.shape[1] == 1
             assert xv_cache_update.shape[1] == 1
-            max_start_pos = 0
-            for row_index in range(bs):
-                row_start_pos = start_positions[row_index].item()
-                max_start_pos = max(row_start_pos, max_start_pos)
-                cache_k[row_index, row_start_pos] = xk_cache_update[row_index, 0]
-                cache_v[row_index, row_start_pos] = xv_cache_update[row_index, 0]
+            for b in range(bs):
+                # Make a tensor because indices must be all tensors, so we can avoid
+                # doing start_positions[row_index].item(), which generates a lot of SymInts.
+                row_index = torch.tensor(
+                    b, dtype=torch.int64, device=xk_cache_update.device
+                )
+                row_start_pos = start_positions[row_index]
+                cache_k.index_put(
+                    (row_index, row_start_pos), xk_cache_update[row_index, 0]
+                )
+                cache_v.index_put(
+                    (row_index, row_start_pos), xv_cache_update[row_index, 0]
+                )
             return cache_k[:, :kv_seq_len], cache_v[:, :kv_seq_len]

     def transact_cache_paged(
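
As a rough illustration of the pattern this commit adopts, here is a self-contained sketch; it is not code from the repo, and `cache`, `update`, and `start_positions` are made-up stand-ins for the kv-cache tensors. It indexes with 0-d tensors instead of Python ints from `.item()`, which (per the new comment in the diff) avoids generating SymInts during export. Note the sketch uses the in-place `index_put_`, whereas the diff calls the out-of-place `index_put`.

import torch

# Hypothetical stand-ins for the decode-step kv-cache update:
# cache is [bs, seq_len, n_heads, head_dim]; update holds one timestep per row.
bs, seq_len, n_heads, head_dim = 2, 16, 4, 8
cache = torch.zeros(bs, seq_len, n_heads, head_dim)
update = torch.randn(bs, 1, n_heads, head_dim)
start_positions = torch.tensor([3, 7], dtype=torch.int64)

for b in range(bs):
    # Index with 0-d tensors rather than start_positions[b].item();
    # the Python-int path materializes SymInts under torch.export.
    row_index = torch.tensor(b, dtype=torch.int64)
    row_start_pos = start_positions[row_index]
    # index_put_ writes in place; the out-of-place index_put returns
    # a new tensor and would leave `cache` unchanged if discarded.
    cache.index_put_((row_index, row_start_pos), update[row_index, 0])

assert torch.equal(cache[0, 3], update[0, 0])
assert torch.equal(cache[1, 7], update[1, 0])

Keeping the per-row loop makes the change minimal relative to the old code; a fully vectorized scatter over all rows at once would also work, but that is beyond what this commit does.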
