Skip to content

Commit 61e0a50

Browse files
[Bugfix] Avoid repeatedly creating dummy data during engine startup (#17935)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 1df491c commit 61e0a50

15 files changed: +99 additions, −4 deletions

vllm/engine/async_llm_engine.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,9 @@ async def start_profile(self) -> None:
12321232
async def stop_profile(self) -> None:
12331233
self.engine.stop_profile()
12341234

1235+
async def reset_mm_cache(self) -> None:
1236+
self.engine.reset_mm_cache()
1237+
12351238
async def reset_prefix_cache(self,
12361239
device: Optional[Device] = None) -> None:
12371240
self.engine.reset_prefix_cache(device)

vllm/engine/llm_engine.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
409409
# the next step without re-scheduling.
410410
self._skip_scheduling_next_step = False
411411

412+
# Don't keep the dummy data in memory
413+
self.reset_mm_cache()
414+
412415
def _initialize_kv_caches(self) -> None:
413416
"""Initialize the KV cache in the worker(s).
414417
@@ -913,6 +916,10 @@ def has_unfinished_requests_for_virtual_engine(
913916
"""
914917
return self.scheduler[virtual_engine].has_unfinished_seqs()
915918

919+
def reset_mm_cache(self) -> bool:
920+
"""Reset the multi-modal cache."""
921+
return self.input_preprocessor.mm_registry.reset_processor_cache()
922+
916923
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
917924
"""Reset prefix cache for all devices."""
918925

vllm/engine/multiprocessing/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ class RPCUProfileRequest(Enum):
123123
STOP_PROFILE = 2
124124

125125

126+
class RPCResetMultiModalCacheRequest(Enum):
127+
RESET = 1
128+
129+
126130
@dataclass
127131
class RPCResetPrefixCacheRequest:
128132
device: Device
@@ -164,6 +168,7 @@ class RPCAdapterLoadedResponse:
164168

165169
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
166170
RPCUProfileRequest, RPCLoadAdapterRequest,
171+
RPCResetMultiModalCacheRequest,
167172
RPCResetPrefixCacheRequest, RPCSleepRequest,
168173
RPCWakeUpRequest, RPCIsSleepingRequest]
169174

vllm/engine/multiprocessing/client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
RPCIsSleepingResponse,
3232
RPCLoadAdapterRequest,
3333
RPCProcessRequest,
34+
RPCResetMultiModalCacheRequest,
3435
RPCResetPrefixCacheRequest,
3536
RPCSleepRequest, RPCStartupRequest,
3637
RPCStartupResponse,
@@ -687,6 +688,13 @@ async def stop_profile(self) -> None:
687688
await self._send_one_way_rpc_request(
688689
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
689690

691+
async def reset_mm_cache(self) -> None:
692+
"""Reset the multi-modal cache"""
693+
694+
await self._send_one_way_rpc_request(
695+
request=RPCResetMultiModalCacheRequest.RESET,
696+
socket=self.input_socket)
697+
690698
async def reset_prefix_cache(self,
691699
device: Optional[Device] = None) -> None:
692700
"""Reset the prefix cache"""

vllm/engine/multiprocessing/engine.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
RPCIsSleepingResponse,
2323
RPCLoadAdapterRequest,
2424
RPCProcessRequest,
25+
RPCResetMultiModalCacheRequest,
2526
RPCResetPrefixCacheRequest,
2627
RPCSleepRequest, RPCStartupRequest,
2728
RPCStartupResponse,
@@ -269,6 +270,8 @@ def handle_new_input(self):
269270
self.stop_profile()
270271
elif isinstance(request, RPCLoadAdapterRequest):
271272
self._handle_load_adapter_request(request)
273+
elif isinstance(request, RPCResetMultiModalCacheRequest):
274+
self.reset_mm_cache()
272275
elif isinstance(request, RPCResetPrefixCacheRequest):
273276
self.reset_prefix_cache()
274277
elif isinstance(request, RPCSleepRequest):
@@ -409,6 +412,9 @@ def start_profile(self) -> None:
409412
def stop_profile(self) -> None:
410413
self.engine.stop_profile()
411414

415+
def reset_mm_cache(self) -> bool:
416+
return self.engine.reset_mm_cache()
417+
412418
def reset_prefix_cache(self) -> bool:
413419
return self.engine.reset_prefix_cache()
414420

vllm/engine/protocol.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,11 @@ async def stop_profile(self) -> None:
278278
"""Start profiling the engine"""
279279
...
280280

281+
@abstractmethod
282+
async def reset_mm_cache(self) -> None:
283+
"""Reset the multi-modal cache"""
284+
...
285+
281286
@abstractmethod
282287
async def reset_prefix_cache(self,
283288
device: Optional[Device] = None) -> None:

vllm/entrypoints/openai/api_server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,10 @@ async def build_async_engine_client(
150150

151151
async with build_async_engine_client_from_engine_args(
152152
engine_args, args.disable_frontend_multiprocessing) as engine:
153+
154+
# Don't keep the dummy data in memory
155+
await engine.reset_mm_cache()
156+
153157
yield engine
154158

155159

vllm/multimodal/processing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,6 +1026,11 @@ def put(
10261026
def put_item(self, item: ProcessingCacheItem) -> None:
10271027
self._cache[item.key] = item.value
10281028

1029+
def reset(self) -> bool:
1030+
self._cache.clear()
1031+
1032+
return True
1033+
10291034

10301035
class BaseProcessingInfo:
10311036
"""Base class to provide the information necessary for data processing."""

vllm/multimodal/registry.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ def __init__(self) -> None:
8888

8989
self._processing_cache = ProcessingCache(VLLM_MM_INPUT_CACHE_GIB)
9090

91+
def reset_processor_cache(self) -> bool:
92+
"""Reset the multi-modal processing cache."""
93+
self._processing_cache.reset()
94+
95+
return True # Success
96+
9197
@deprecated("Legacy input processor/mapper pipeline has been removed. "
9298
"Please update your model runner to use "
9399
"`seq_group_metadata.multi_modal_data` directly without "
@@ -106,7 +112,7 @@ def get_max_tokens_per_item_by_modality(
106112
if not model_config.is_multimodal_model:
107113
return {}
108114

109-
processor = self.create_processor(model_config, disable_cache=True)
115+
processor = self.create_processor(model_config, disable_cache=False)
110116
profiler = MultiModalProfiler(processor)
111117

112118
seq_len = model_config.max_model_len
@@ -190,7 +196,7 @@ def get_mm_limits_per_prompt(
190196
if not model_config.is_multimodal_model:
191197
return {}
192198

193-
processor = self.create_processor(model_config, disable_cache=True)
199+
processor = self.create_processor(model_config, disable_cache=False)
194200
profiler = MultiModalProfiler(processor)
195201
return profiler.get_mm_limits()
196202

@@ -286,7 +292,7 @@ def get_decoder_dummy_data(
286292
287293
The model is identified by ``model_config``.
288294
"""
289-
processor = self.create_processor(model_config, disable_cache=True)
295+
processor = self.create_processor(model_config, disable_cache=False)
290296
profiler = MultiModalProfiler(processor)
291297
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
292298

@@ -310,7 +316,7 @@ def get_encoder_dummy_data(
310316
311317
The model is identified by ``model_config``.
312318
"""
313-
processor = self.create_processor(model_config, disable_cache=True)
319+
processor = self.create_processor(model_config, disable_cache=False)
314320
profiler = MultiModalProfiler(processor)
315321
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
316322

vllm/v1/engine/async_llm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,11 @@ async def start_profile(self) -> None:
476476
async def stop_profile(self) -> None:
477477
await self.engine_core.profile_async(False)
478478

479+
async def reset_mm_cache(self) -> None:
480+
self.processor.mm_registry.reset_processor_cache()
481+
self.processor.mm_input_cache_client.reset()
482+
await self.engine_core.reset_mm_cache_async()
483+
479484
async def reset_prefix_cache(self,
480485
device: Optional[Device] = None) -> None:
481486
if device == Device.CPU:

vllm/v1/engine/core.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,15 @@ def shutdown(self):
286286
def profile(self, is_start: bool = True):
287287
self.model_executor.profile(is_start)
288288

289+
def reset_mm_cache(self):
290+
# NOTE: Since this is mainly for debugging, we don't attempt to
291+
# re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
292+
if self.scheduler.get_num_unfinished_requests():
293+
logger.warning("Resetting the multi-modal cache when requests are "
294+
"in progress may lead to desynced internal caches.")
295+
296+
self.mm_input_cache_server.reset()
297+
289298
def reset_prefix_cache(self):
290299
self.scheduler.reset_prefix_cache()
291300

vllm/v1/engine/core_client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def add_request(self, request: EngineCoreRequest) -> None:
8888
def profile(self, is_start: bool = True) -> None:
8989
raise NotImplementedError
9090

91+
def reset_mm_cache(self) -> None:
92+
raise NotImplementedError
93+
9194
def reset_prefix_cache(self) -> None:
9295
raise NotImplementedError
9396

@@ -143,6 +146,9 @@ async def add_request_async(self, request: EngineCoreRequest) -> None:
143146
async def profile_async(self, is_start: bool = True) -> None:
144147
raise NotImplementedError
145148

149+
async def reset_mm_cache_async(self) -> None:
150+
raise NotImplementedError
151+
146152
async def reset_prefix_cache_async(self) -> None:
147153
raise NotImplementedError
148154

@@ -214,6 +220,9 @@ def shutdown(self) -> None:
214220
def profile(self, is_start: bool = True) -> None:
215221
self.engine_core.profile(is_start)
216222

223+
def reset_mm_cache(self) -> None:
224+
self.engine_core.reset_mm_cache()
225+
217226
def reset_prefix_cache(self) -> None:
218227
self.engine_core.reset_prefix_cache()
219228

@@ -600,6 +609,9 @@ def abort_requests(self, request_ids: list[str]) -> None:
600609
def profile(self, is_start: bool = True) -> None:
601610
self.call_utility("profile", is_start)
602611

612+
def reset_mm_cache(self) -> None:
613+
self.call_utility("reset_mm_cache")
614+
603615
def reset_prefix_cache(self) -> None:
604616
self.call_utility("reset_prefix_cache")
605617

@@ -787,6 +799,9 @@ async def abort_requests_async(self, request_ids: list[str]) -> None:
787799
async def profile_async(self, is_start: bool = True) -> None:
788800
await self.call_utility_async("profile", is_start)
789801

802+
async def reset_mm_cache_async(self) -> None:
803+
await self.call_utility_async("reset_mm_cache")
804+
790805
async def reset_prefix_cache_async(self) -> None:
791806
await self.call_utility_async("reset_prefix_cache")
792807

vllm/v1/engine/llm_engine.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ def __init__(
101101
# for v0 compatibility
102102
self.model_executor = self.engine_core.engine_core.model_executor # type: ignore
103103

104+
# Don't keep the dummy data in memory
105+
self.reset_mm_cache()
106+
104107
@classmethod
105108
def from_vllm_config(
106109
cls,
@@ -240,6 +243,11 @@ def start_profile(self):
240243
def stop_profile(self):
241244
self.engine_core.profile(False)
242245

246+
def reset_mm_cache(self):
247+
self.processor.mm_registry.reset_processor_cache()
248+
self.processor.mm_input_cache_client.reset()
249+
self.engine_core.reset_mm_cache()
250+
243251
def reset_prefix_cache(self, device: Optional[Device] = None):
244252
self.engine_core.reset_prefix_cache()
245253

vllm/v1/engine/mm_input_cache.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,8 @@ def get_and_update_p1(
8383
full_mm_inputs.append(mm_input)
8484

8585
return full_mm_inputs
86+
87+
def reset(self) -> bool:
88+
self.mm_cache.clear()
89+
90+
return True

vllm/v1/engine/processor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def __init__(
5454
self.use_hash = self.mm_input_cache_client.use_cache or \
5555
self.cache_config.enable_prefix_caching
5656

57+
@property
58+
def mm_registry(self):
59+
return self.input_preprocessor.mm_registry
60+
5761
def _validate_logprobs(
5862
self,
5963
params: SamplingParams,

Comments (0)