Skip to content

Commit d520bd1

Browse files
authored
Replication of index was causing issues for kv cache writes (#715)
Signed-off-by: Rob Suderman <[email protected]>
1 parent: fecc081 · commit: d520bd1

File tree

3 files changed: 26 additions and 7 deletions.

sharktank/sharktank/examples/export_paged_llm_v1.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -334,7 +334,8 @@ def _(
334334

335335
bsizes = []
336336
for bs in args.bs:
337-
generate_batch_prefill(bs)
337+
if not args.skip_prefill:
338+
generate_batch_prefill(bs)
338339
if not args.skip_decode:
339340
generate_batch_decode(bs)
340341
bsizes.append(bs)

sharktank/sharktank/layers/kv_cache.py

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -456,12 +456,25 @@ def write_timestep(
456456
page_offset = (seq_positions % self.block_seq_stride).unsqueeze(1)
457457

458458
# [1, 1]
459-
partitions = torch.tensor(idx).unsqueeze(0)
459+
if isinstance(seq_positions, ReplicatedTensor):
460+
partitions = [
461+
torch.tensor(idx).unsqueeze(0)
462+
for _ in range(seq_positions.shard_count)
463+
]
464+
465+
transformer_block = [
466+
torch.full((bs, 1), transformer_block_index, device=device)
467+
for _ in range(seq_positions.shard_count)
468+
]
469+
470+
partitions = ReplicatedTensor(ts=partitions)
471+
transformer_block = ReplicatedTensor(ts=transformer_block)
472+
else:
473+
partitions = torch.tensor(idx).unsqueeze(0)
474+
transformer_block = torch.full(
475+
(bs, 1), transformer_block_index, device=device
476+
)
460477

461-
# [bs, 1]
462-
transformer_block = torch.full(
463-
(bs, 1), transformer_block_index, device=device
464-
)
465478
partitions = partitions.repeat(bs, 1)
466479

467480
indices = (page_id, transformer_block, partitions, page_offset)

sharktank/sharktank/utils/cli.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,9 +69,14 @@ def add_model_options(parser: argparse.ArgumentParser):
6969
default="torch",
7070
choices=["decomposed", "torch"],
7171
)
72+
parser.add_argument(
73+
"--skip-prefill",
74+
help="Skips exporting prefill",
75+
action="store_true",
76+
)
7277
parser.add_argument(
7378
"--skip-decode",
74-
help="Enables prefill only, skips decode",
79+
help="Skips export decode",
7580
action="store_true",
7681
)
7782

0 commit comments

Comments
 (0)