Commit 7d0557c

[Bugfix] Avoid repeatedly creating dummy data during engine startup
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 246e3e0 commit 7d0557c

File tree

8 files changed (+57, -4 lines)

vllm/engine/llm_engine.py

Lines changed: 7 additions & 0 deletions

@@ -409,6 +409,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
         # the next step without re-scheduling.
         self._skip_scheduling_next_step = False

+        # Don't keep the dummy data in memory
+        self.reset_mm_cache()
+
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).

@@ -913,6 +916,10 @@ def has_unfinished_requests_for_virtual_engine(
         """
         return self.scheduler[virtual_engine].has_unfinished_seqs()

+    def reset_mm_cache(self) -> bool:
+        """Reset the multi-modal cache."""
+        return self.input_preprocessor.mm_registry.reset_processor_cache()
+
     def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
         """Reset prefix cache for all devices."""
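With this change, the v0 engine evicts the dummy data created by startup profiling as soon as __init__ completes, and the reset is exposed publicly alongside reset_prefix_cache. A minimal usage sketch (engine construction and model name are illustrative, not part of the diff):

    from vllm import EngineArgs, LLMEngine

    # Startup profiling fills the multi-modal processing cache with dummy
    # data; after this commit, __init__ evicts it right away.
    engine = LLMEngine.from_engine_args(
        EngineArgs(model="llava-hf/llava-1.5-7b-hf"))

    # The reset can also be triggered manually, e.g. between benchmark runs;
    # it returns True on success, mirroring reset_prefix_cache.
    assert engine.reset_mm_cache()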

vllm/multimodal/processing.py

Lines changed: 5 additions & 0 deletions

@@ -1026,6 +1026,11 @@ def put(
     def put_item(self, item: ProcessingCacheItem) -> None:
         self._cache[item.key] = item.value

+    def reset(self) -> bool:
+        self._cache.clear()
+
+        return True
+

 class BaseProcessingInfo:
     """Base class to provide the information necessary for data processing."""

vllm/multimodal/registry.py

Lines changed: 10 additions & 4 deletions

@@ -88,6 +88,12 @@ def __init__(self) -> None:

         self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)

+    def reset_processor_cache(self) -> bool:
+        """Reset the multi-modal processing cache."""
+        self._processing_cache.reset()
+
+        return True  # Success
+
     @deprecated("Legacy input processor/mapper pipeline has been removed. "
                 "Please update your model runner to use "
                 "`seq_group_metadata.multi_modal_data` directly without "

@@ -106,7 +112,7 @@ def get_max_tokens_per_item_by_modality(
         if not model_config.is_multimodal_model:
             return {}

-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)

         seq_len = model_config.max_model_len

@@ -190,7 +196,7 @@ def get_mm_limits_per_prompt(
         if not model_config.is_multimodal_model:
             return {}

-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
         return profiler.get_mm_limits()

@@ -286,7 +292,7 @@ def get_decoder_dummy_data(

         The model is identified by ``model_config``.
         """
-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
         dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

@@ -310,7 +316,7 @@ def get_encoder_dummy_data(

         The model is identified by ``model_config``.
         """
-        processor = self.create_processor(model_config, disable_cache=True)
+        processor = self.create_processor(model_config, disable_cache=False)
         profiler = MultiModalProfiler(processor)
         dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
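This file is the core of the fix: all four profiling helpers run during engine startup, and with disable_cache=True each call rebuilt the dummy multi-modal inputs from scratch. Enabling the shared cache lets repeated calls reuse one result, and the engine then resets the cache once startup completes so the dummy data does not stay resident. A hedged sketch of the mechanism (names hypothetical):

    import functools

    # With disable_cache=True, dummy inputs were recomputed on every
    # profiling call; memoization makes repeat calls for the same
    # (model, seq_len) free, at the cost of holding the result until
    # the cache is cleared.
    @functools.lru_cache(maxsize=None)
    def build_dummy_inputs(model_id: str, seq_len: int) -> tuple[str, int]:
        print(f"processing dummy data for {model_id} @ {seq_len}")  # once
        return (model_id, seq_len)

    build_dummy_inputs("some-mm-model", 4096)  # processes
    build_dummy_inputs("some-mm-model", 4096)  # cache hit
    build_dummy_inputs.cache_clear()           # analogue of reset_mm_cache()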

vllm/v1/engine/core.py

Lines changed: 12 additions & 0 deletions

@@ -121,6 +121,9 @@ def __init__(self,
         self.batch_queue = queue.Queue(self.batch_queue_size)
         self.vllm_config = vllm_config

+        # Don't keep the dummy data in memory
+        self.reset_mm_cache()
+
     def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
         start = time.time()

@@ -277,6 +280,15 @@ def shutdown(self):
     def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)

+    def reset_mm_cache(self):
+        # NOTE: Since this is mainly for debugging, we don't attempt to
+        # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
+        if self.scheduler.get_num_unfinished_requests():
+            logger.warning("Resetting the multi-modal cache when requests are "
+                           "in progress may lead to desynced internal caches.")
+
+        self.mm_input_cache_server.reset()
+
     def reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()

vllm/v1/engine/core_client.py

Lines changed: 9 additions & 0 deletions

@@ -88,6 +88,9 @@ def add_request(self, request: EngineCoreRequest) -> None:
     def profile(self, is_start: bool = True) -> None:
         raise NotImplementedError

+    def reset_mm_cache(self) -> None:
+        raise NotImplementedError
+
     def reset_prefix_cache(self) -> None:
         raise NotImplementedError

@@ -214,6 +217,9 @@ def shutdown(self) -> None:
     def profile(self, is_start: bool = True) -> None:
         self.engine_core.profile(is_start)

+    def reset_mm_cache(self) -> None:
+        self.engine_core.reset_mm_cache()
+
     def reset_prefix_cache(self) -> None:
         self.engine_core.reset_prefix_cache()

@@ -600,6 +606,9 @@ def abort_requests(self, request_ids: list[str]) -> None:
     def profile(self, is_start: bool = True) -> None:
         self.call_utility("profile", is_start)

+    def reset_mm_cache(self) -> None:
+        self.call_utility("reset_mm_cache")
+
     def reset_prefix_cache(self) -> None:
         self.call_utility("reset_prefix_cache")
vllm/v1/engine/llm_engine.py

Lines changed: 5 additions & 0 deletions

@@ -240,6 +240,11 @@ def start_profile(self):
     def stop_profile(self):
         self.engine_core.profile(False)

+    def reset_mm_cache(self):
+        self.processor.mm_registry.reset_processor_cache()
+        self.processor.mm_input_cache_client.reset()
+        self.engine_core.reset_mm_cache()
+
     def reset_prefix_cache(self, device: Optional[Device] = None):
         self.engine_core.reset_prefix_cache()

vllm/v1/engine/mm_input_cache.py

Lines changed: 5 additions & 0 deletions

@@ -83,3 +83,8 @@ def get_and_update_p1(
             full_mm_inputs.append(mm_input)

         return full_mm_inputs
+
+    def reset(self) -> bool:
+        self.mm_cache.clear()
+
+        return True

vllm/v1/engine/processor.py

Lines changed: 4 additions & 0 deletions

@@ -54,6 +54,10 @@ def __init__(
         self.use_hash = self.mm_input_cache_client.use_cache or \
             self.cache_config.enable_prefix_caching

+    @property
+    def mm_registry(self):
+        return self.input_preprocessor.mm_registry
+
     def _validate_logprobs(
         self,
         params: SamplingParams,
