Skip to content

Commit

Permalink
Rewriter write_timestep to fuse the indexing writes (#1023)
Browse files Browse the repository at this point in the history
It is possible to fuse the timestep writing with the index computation
if we manually compute the offset. This reduces the scatter write to a
single dispatch avoiding the elementwise write operations.
  • Loading branch information
rsuderman authored Mar 3, 2025
1 parent 894f4b9 commit 7bddd07
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions sharktank/sharktank/layers/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def write_timestep(
"""
device = self.device
page_table = self.unflatten_page_table(state) # 6D
page_table = page_table.flatten(0, 3)
bs, *_ = seq_positions.shape
assert len(cache_partitions) == self.cache_partition_count

Expand Down Expand Up @@ -274,15 +275,18 @@ def write_timestep(

partitions = partitions.repeat(bs, 1)

indices = (page_id, transformer_block, partitions, page_offset)
index = page_id
index = index * self.transformer_block_count + transformer_block
index = index * self.cache_partition_count + partitions
index = index * self.block_seq_stride + page_offset
values = ops.to(cache_partition, dtype=page_table.dtype)
if page_table.dtype == torch.float8_e4m3fnuz:
# Workaround for Torch not supporting torch.Tensor.index_copy_ for f8.
page_table_as_int8 = page_table.view(dtype=torch.int8)
values_int8 = values.view(dtype=torch.int8)
page_table_as_int8.index_put_(indices=indices, values=values_int8)
page_table_as_int8.index_put_(indices=(index,), values=values_int8)
else:
page_table.index_put_(indices=indices, values=values)
page_table.index_put_(indices=(index,), values=values)

return

Expand Down

0 comments on commit 7bddd07

Please sign in to comment.