Skip to content

Commit 4ebbccf

Browse files
rsuderman and renxida
authored
Make sharktank export prefill/decode batch size separately (#1046)
Prefill and decode have different preferable batch sizes. We should export them separately. --------- Co-authored-by: Xida Ren <[email protected]>
1 parent 9641f3d commit 4ebbccf

File tree

6 files changed

+25
-12
lines changed

6 files changed

+25
-12
lines changed

app_tests/integration_tests/llm/model_management.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,8 @@ def export_model(self, weights_path: Path) -> Tuple[Path, Path]:
327327
f"--{weights_path.suffix.strip('.')}-file={weights_path}",
328328
f"--output-mlir={mlir_path}",
329329
f"--output-config={config_path}",
330-
f"--bs={bs_string}",
330+
f"--bs-prefill={bs_string}",
331+
f"--bs-decode={bs_string}",
331332
],
332333
check=True,
333334
)

docs/shortfin/llm/developer/e2e_llama8b_mi300x.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ python -m sharktank.examples.export_paged_llm_v1 \
6969
--irpa-file=$MODEL_PARAMS_PATH \
7070
--output-mlir=$MLIR_PATH \
7171
--output-config=$OUTPUT_CONFIG_PATH \
72-
--bs=$BS
72+
--bs-prefill=$BS
73+
--bs-decode=$BS
7374
```
7475

7576
## Compiling to `.vmfb`

docs/shortfin/llm/user/llama_serving.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ python -m sharktank.examples.export_paged_llm_v1 \
159159
--gguf-file=$MODEL_PARAMS_PATH \
160160
--output-mlir=$MLIR_PATH \
161161
--output-config=$OUTPUT_CONFIG_PATH \
162-
--bs=$EXPORT_BATCH_SIZES
162+
--bs-prefill=$EXPORT_BATCH_SIZES \
163+
--bs-decode=$EXPORT_BATCH_SIZES
163164
```
164165
165166
### Compile using IREE to a `.vmfb` file

sharktank/sharktank/evaluate/README.md

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ For Llama3.1 8B (FP16) model on a MI300 server:
2222
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py -k test_llama3_8B_f16 \
2323
--llama3-8b-f16-model-path=llama3.1_8b_instruct_fp16.irpa \
2424
--llama3-8b-tokenizer-path=tokenizer_config.json \
25-
--bs=4 \
25+
--bs-prefill=4 \
26+
--bs-decode=4 \
2627
--run-nightly-llama-tests
2728
```
2829

@@ -31,7 +32,8 @@ pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_torch_test.py -k test_llam
3132
pytest -n 8 -v -s sharktank/tests/evaluate/perplexity_iree_test.py -k test_llama3_8B_f16 \
3233
--llama3-8b-f16-model-path=llama3.1_8b_instruct_fp16.irpa \
3334
--llama3-8b-tokenizer-path=tokenizer_config.json \
34-
--bs=4 \
35+
--bs-prefill=4 \
36+
--bs-decode=4 \
3537
--iree-device=hip://1 \
3638
--iree-hip-target=gfx942 \
3739
--iree-hal-target-device=hip

sharktank/sharktank/examples/export_paged_llm_v1.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ def main():
4040
default="/tmp/batch_llama_v1.json",
4141
)
4242
parser.add_argument(
43-
"--bs",
43+
"--bs-prefill",
44+
help="Comma-separated batch size(s) to generate, e.g. `4` or `2,4`",
45+
type=lambda arg: [int(bs) for bs in arg.split(",")],
46+
default="4",
47+
)
48+
parser.add_argument(
49+
"--bs-decode",
4450
help="Comma-separated batch size(s) to generate, e.g. `4` or `2,4`",
4551
type=lambda arg: [int(bs) for bs in arg.split(",")],
4652
default="4",
@@ -336,13 +342,14 @@ def _(
336342
return logits
337343

338344
bsizes = []
339-
for bs in args.bs:
340-
if not args.skip_prefill:
345+
if not args.skip_prefill:
346+
for bs in args.bs_prefill:
341347
generate_batch_prefill(bs)
342-
if not args.skip_decode:
348+
if not args.skip_decode:
349+
for bs in args.bs_decode:
343350
generate_batch_decode(bs)
344-
bsizes.append(bs)
345-
config = generate_params_json(hp, bsizes, bsizes)
351+
352+
config = generate_params_json(hp, args.bs_prefill, args.bs_decode)
346353
print("GENERATED!")
347354

348355
if args.verbose:

sharktank/sharktank/utils/export_artifacts.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ def export_to_mlir(
192192
f"--irpa-file={self.irpa_path}",
193193
f"--output-mlir={mlir_path}",
194194
f"--output-config={json_path}",
195-
f"--bs={str(self.batch_size)}",
195+
f"--bs-prefill={str(self.batch_size)}",
196+
f"--bs-decode={str(self.batch_size)}",
196197
f"--block-seq-stride={self.block_seq_stride}",
197198
f"--attention-dtype={self.attention_dtype}",
198199
f"--activation-dtype={self.activation_dtype}",

0 commit comments

Comments (0)