@@ -358,9 +358,13 @@ def _prepare_encoder_model_input_tensors(
 
         return attn_metadata
 
+    @torch.inference_mode()
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
-        kv_caches = [None] * num_layers
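+        # empty per-layer tensors stand in for real KV caches during profiling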
+        kv_caches = [
+            torch.tensor([], dtype=torch.bfloat16, device=self.device)
+            for _ in range(num_layers)
+        ]
         max_batch_size = self.max_num_prefill_seqs
         _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape()
         max_seq_len = min(self.max_num_batched_tokens // max_batch_size,
@@ -446,24 +450,18 @@ def create_dummy_seq_group_metadata(self,
         sampling_params = SamplingParams(temperature=temperature)
         num_blocks = math.ceil(seq_len / self.block_size)
         cross_block_table: Optional[List[int]] = None
-        seq_len = max(seq_len, 1)
+        encoder_dummy_data \
+            = self.input_registry.dummy_data_for_profiling(
+                self.model_config,
+                seq_len,
+                self.mm_registry,
+                is_encoder_data=True)
         mm_counts = self.mm_registry.get_mm_limits_per_prompt(
             self.model_config)
         num_images = mm_counts["image"]
         max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
             self.model_config) * num_images
-        decoder_dummy_data \
-            = self.input_registry.dummy_data_for_profiling(
-                self.model_config,
-                seq_len,
-                self.mm_registry,
-                is_encoder_data=False)
-        encoder_dummy_data \
-            = self.input_registry.dummy_data_for_profiling(
-                self.model_config,
-                max_mm_tokens,
-                self.mm_registry,
-                is_encoder_data=True)
+        seq_len = max(seq_len, 1)
         if is_prompt:
             input_len = seq_len
             output_len = 0
@@ -477,21 +475,25 @@ def create_dummy_seq_group_metadata(self,
             num_cross_blocks = min(self.bucketing_ctx.num_hpu_blocks,
                                    max_mm_tokens) // self.block_size
             cross_block_table = [_PAD_BLOCK_ID] * num_cross_blocks
-        prompt_token_ids = [0] * input_len
         output_token_ids = [1] * output_len
-        prompt_token_ids_array = array('l', prompt_token_ids)  # noqa: F821
-        seq_data = SequenceData(prompt_token_ids_array)
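+        # the input registry builds dummy decoder inputs sized to seq_len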
+        decoder_dummy_data = self.input_registry \
+            .dummy_data_for_profiling(self.model_config,
+                                      seq_len,
+                                      self.mm_registry,
+                                      is_encoder_data=False)
+        seq_data = decoder_dummy_data.seq_data
         seq_data.output_token_ids = output_token_ids
+
         return SequenceGroupMetadata(
             request_id=str(group_id),
             is_prompt=is_prompt,
             seq_data={group_id: seq_data},
             sampling_params=sampling_params,
             block_tables=block_tables,
             encoder_seq_data=encoder_dummy_data.seq_data,
-            multi_modal_data=decoder_dummy_data.multi_modal_data,
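+            # decoder dummy data may carry no multimodal inputs; fall back
+            # to the encoder's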
+            multi_modal_data=decoder_dummy_data.multi_modal_data or encoder_dummy_data.multi_modal_data,
             multi_modal_placeholders=decoder_dummy_data.
-            multi_modal_placeholders,
+            multi_modal_placeholders or encoder_dummy_data.multi_modal_placeholders,
             cross_block_table=cross_block_table)
 
     def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: