
Commit c6147a3

fix profile_run

Signed-off-by: yan ma <[email protected]>
1 parent aaa438a · commit c6147a3

1 file changed: +21 −19 lines changed

vllm/worker/hpu_enc_dec_model_runner.py

Lines changed: 21 additions & 19 deletions
@@ -358,9 +358,13 @@ def _prepare_encoder_model_input_tensors(
 
         return attn_metadata
 
+    @torch.inference_mode()
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
+        kv_caches = [
+            torch.tensor([], dtype=torch.bfloat16, device=self.device)
+            for _ in range(num_layers)
+        ]
         max_batch_size = self.max_num_prefill_seqs
         _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
         max_seq_len = min(self.max_num_batched_tokens // max_batch_size,
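
Note on the hunk above: during profile_run no real KV cache has been allocated yet, so the runner needs a per-layer placeholder. What follows is a minimal, self-contained sketch of the placeholder pattern, not the runner code itself; the layer count and the CPU device are illustrative assumptions standing in for self.device on HPU.

import torch

num_layers = 4   # assumed layer count, for illustration only
device = "cpu"   # stand-in for self.device (HPU in the real runner)

# One empty placeholder per layer: unlike [None] * num_layers, each entry
# still reports a dtype and device, and numel() == 0 marks it as unpopulated.
kv_caches = [
    torch.tensor([], dtype=torch.bfloat16, device=device)
    for _ in range(num_layers)
]

assert all(c.numel() == 0 and c.dtype == torch.bfloat16 for c in kv_caches)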
@@ -446,24 +450,18 @@ def create_dummy_seq_group_metadata(self,
         sampling_params = SamplingParams(temperature=temperature)
         num_blocks = math.ceil(seq_len / self.block_size)
         cross_block_table: Optional[List[int]] = None
-        seq_len = max(seq_len, 1)
+        encoder_dummy_data \
+            = self.input_registry.dummy_data_for_profiling(
+                self.model_config,
+                seq_len,
+                self.mm_registry,
+                is_encoder_data=True)
         mm_counts = self.mm_registry.get_mm_limits_per_prompt(
             self.model_config)
         num_images = mm_counts["image"]
         max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
             self.model_config) * num_images
-        decoder_dummy_data \
-            = self.input_registry.dummy_data_for_profiling(
-                self.model_config,
-                seq_len,
-                self.mm_registry,
-                is_encoder_data=False)
-        encoder_dummy_data \
-            = self.input_registry.dummy_data_for_profiling(
-                self.model_config,
-                max_mm_tokens,
-                self.mm_registry,
-                is_encoder_data=True)
+        seq_len = max(seq_len, 1)
         if is_prompt:
             input_len = seq_len
             output_len = 0
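
Note on the hunk above: the encoder dummy data is now sized by the requested seq_len (previously max_mm_tokens) and is built before the seq_len = max(seq_len, 1) clamp. Below is a simplified sketch of the new ordering, with a stub standing in for InputRegistry.dummy_data_for_profiling; the stub body and its SimpleNamespace fields are assumptions for illustration only.

from types import SimpleNamespace

def dummy_data_for_profiling(model_config, seq_len, mm_registry,
                             is_encoder_data=False):
    # Stub: the real registry builds model-specific dummy tokens/images.
    return SimpleNamespace(seq_data=[0] * seq_len,
                           multi_modal_data=None,
                           multi_modal_placeholders=None)

seq_len = 0  # a degenerate request length, as can occur during warmup

# Encoder dummy data is derived from the caller's unclamped seq_len...
encoder_dummy_data = dummy_data_for_profiling(None, seq_len, None,
                                              is_encoder_data=True)
# ...and only afterwards is seq_len clamped for the decoder-side lengths.
seq_len = max(seq_len, 1)

print(len(encoder_dummy_data.seq_data), seq_len)  # prints: 0 1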
@@ -477,21 +475,25 @@ def create_dummy_seq_group_metadata(self,
             num_cross_blocks = min(self.bucketing_ctx.num_hpu_blocks,
                                    max_mm_tokens) // self.block_size
             cross_block_table = [_PAD_BLOCK_ID] * num_cross_blocks
-        prompt_token_ids = [0] * input_len
         output_token_ids = [1] * output_len
-        prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
-        seq_data = SequenceData(prompt_token_ids_array)
+        decoder_dummy_data = self.input_registry \
+            .dummy_data_for_profiling(self.model_config,
+                                      seq_len,
+                                      self.mm_registry,
+                                      is_encoder_data=False)
+        seq_data = decoder_dummy_data.seq_data
         seq_data.output_token_ids = output_token_ids
+
         return SequenceGroupMetadata(
             request_id=str(group_id),
             is_prompt=is_prompt,
             seq_data={group_id: seq_data},
             sampling_params=sampling_params,
             block_tables=block_tables,
             encoder_seq_data=encoder_dummy_data.seq_data,
-            multi_modal_data=decoder_dummy_data.multi_modal_data,
+            multi_modal_data=decoder_dummy_data.multi_modal_data or encoder_dummy_data.multi_modal_data,
             multi_modal_placeholders=decoder_dummy_data.
-            multi_modal_placeholders,
+            multi_modal_placeholders or encoder_dummy_data.multi_modal_placeholders,
             cross_block_table=cross_block_table)
 
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
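
Note on the final hunk: the hand-built SequenceData over zero-filled prompt tokens is replaced by the registry's decoder dummy data, and multi_modal_data / multi_modal_placeholders gain an `or` fallback to the encoder-side values. A minimal sketch of that fallback follows; the payload values are made up for illustration.

decoder_mm_data = None                       # decoder dummy data carries no images
encoder_mm_data = {"image": "dummy-image"}   # encoder dummy data does

multi_modal_data = decoder_mm_data or encoder_mm_data
print(multi_modal_data)  # {'image': 'dummy-image'}

# Usual `or` caveat: any falsy decoder payload (None, {}, []) selects the
# encoder payload, which appears to be the intent here for models whose
# multimodal inputs live on the encoder side.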
