Skip to content

Commit a038133

Browse files
authored
Disable strictness for export of llama (#168)
Strictness validates correctness but this results in loading the tensors to memory. Disabling helps with export speed.
1 parent aead69f commit a038133

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

sharktank/sharktank/examples/export_paged_llm_v1.py

+7
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def main():
4545
help="Include verbose logging",
4646
action="store_true",
4747
)
48+
parser.add_argument(
49+
"--strict",
50+
help="Enables strictness during export",
51+
action="store_true",
52+
)
4853

4954
args = cli.parse(parser)
5055
dataset = cli.get_input_dataset(args)
@@ -117,6 +122,7 @@ def generate_batch_prefill(bs: int):
117122
name=f"prefill_bs{bs}",
118123
args=(tokens, seq_lens, seq_block_ids, cache_state),
119124
dynamic_shapes=dynamic_shapes,
125+
strict=args.strict,
120126
)
121127
def _(model, tokens, seq_lens, seq_block_ids, cache_state):
122128
sl = tokens.shape[1]
@@ -174,6 +180,7 @@ def generate_batch_decode(bs: int):
174180
cache_state,
175181
),
176182
dynamic_shapes=dynamic_shapes,
183+
strict=args.strict,
177184
)
178185
def _(
179186
model,

0 commit comments

Comments
 (0)