Skip to content

Commit 802aaa3

Browse files
fems14asinkLuno
authored and committed
[bugfix] Fixing KV Pool Memory Retention and Performance Degradation Issues (vllm-project#5751)
### What this PR does / why we need it?

1. Fixed memory retention on certain GPUs caused by missing PUT operations.
2. Fixed performance degradation resulting from architectural incompatibilities in the underlying refactor.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@2f4e654

---------

Signed-off-by: fems14 <[email protected]>
1 parent 66cb700 commit 802aaa3

File tree

6 files changed

+27
-22
lines changed

6 files changed

+27
-22
lines changed

tests/ut/distributed/mooncake/test_config_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
fake_engine = types.ModuleType("mooncake.engine")
77
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
88
sys.modules["mooncake.engine"] = fake_engine
9+
fake_store = types.ModuleType("mooncake.store")
10+
fake_store.ReplicateConfig = MagicMock() # type: ignore[attr-defined]
11+
sys.modules["mooncake.store"] = fake_store
912

1013
from vllm_ascend.distributed.kvpool.backend.mooncake_backend import ( # noqa: E402
1114
_convert_to_bytes, _parse_global_segment_size)

vllm_ascend/distributed/kvpool/ascend_store_connector.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -136,21 +136,9 @@ def get_finished(self,
136136
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
137137
"""Get the finished recving and sending requests."""
138138
assert self.connector_worker is not None
139-
meta = self._get_connector_metadata()
140139
done_sending, done_recving = self.connector_worker.get_finished(
141140
finished_req_ids)
142-
sended_and_finished: set[str] = set()
143-
for item in list(self.sended_but_unfinished_reqs):
144-
if item not in meta.unfinished_request_ids:
145-
sended_and_finished.add(item)
146-
self.sended_but_unfinished_reqs.remove(item)
147-
for item in done_sending:
148-
if item in meta.unfinished_request_ids:
149-
self.sended_but_unfinished_reqs.add(item)
150-
else:
151-
sended_and_finished.add(item)
152-
153-
return sended_and_finished, done_recving
141+
return done_sending, done_recving
154142

155143

156144
class LookupKeyServer:

vllm_ascend/distributed/kvpool/backend/mooncake_backend.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Union
77

88
# Third Party
9+
from mooncake.store import ReplicateConfig # type: ignore
910
from vllm.config import ParallelConfig
1011
from vllm.logger import logger
1112
from vllm.utils.network_utils import get_ip
@@ -56,7 +57,11 @@ def exists(self, keys: list[str]) -> list[int]:
5657
def put(self, keys: list[str], addrs: list[list[int]],
5758
sizes: list[list[int]]):
5859
try:
59-
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
60+
config = ReplicateConfig()
61+
config.preferred_segment = self.local_seg
62+
config.prefer_alloc_in_same_node = True
63+
res = self.store.batch_put_from_multi_buffers(
64+
keys, addrs, sizes, config)
6065
for value in res:
6166
if value < 0:
6267
logger.error(f"Failed to put key {keys},res:{res}")
@@ -66,7 +71,8 @@ def put(self, keys: list[str], addrs: list[list[int]],
6671
def get(self, keys: list[str], addrs: list[list[int]],
6772
sizes: list[list[int]]):
6873
try:
69-
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
74+
res = self.store.batch_get_into_multi_buffers(
75+
keys, addrs, sizes, True)
7076
for value in res:
7177
if value < 0:
7278
logger.error(f"Failed to get key {keys}, res:{res}")

vllm_ascend/distributed/kvpool/config_data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ class LoadSpec:
223223
# Whether the scheduler allow us to load the tokens
224224
can_load: bool
225225

226+
token_len: int = 0
227+
226228

227229
@dataclass
228230
class RequestTracker:

vllm_ascend/distributed/kvpool/kv_transfer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def _handle_request(self, req_meta: ReqMeta):
125125
token_len = req_meta.token_len_chunk
126126
block_ids = req_meta.block_ids
127127
req_id = req_meta.req_id
128-
is_last_chunk = req_meta.is_last_chunk
129128
current_event = req_meta.current_event
130129
starts = []
131130
ends = []
@@ -142,15 +141,15 @@ def _handle_request(self, req_meta: ReqMeta):
142141
keys = keys[self.tp_rank % self.put_step::self.put_step]
143142

144143
if not keys:
145-
if is_last_chunk:
146-
self.set_finished_request(req_id)
144+
with self.done_task_lock:
145+
self.stored_requests[req_id] -= 1
147146
return
148147

149148
skip_block_num = self.lookup(keys)
150149

151150
if skip_block_num == len(keys):
152-
if is_last_chunk:
153-
self.set_finished_request(req_id)
151+
with self.done_task_lock:
152+
self.stored_requests[req_id] -= 1
154153
return
155154

156155
starts = starts[skip_block_num:]
@@ -208,6 +207,7 @@ def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
208207
name="KVCacheStoreRecvingThread")
209208

210209
def _handle_request(self, req_meta: ReqMeta):
210+
token_len = req_meta.load_spec.token_len # type: ignore[union-attr]
211211
req_id = req_meta.req_id
212212
mask_num = (
213213
req_meta.load_spec.vllm_cached_tokens # type: ignore[union-attr]
@@ -216,7 +216,7 @@ def _handle_request(self, req_meta: ReqMeta):
216216
size_list = []
217217
key_list = []
218218
for start, end, key in self.token_database.process_tokens(
219-
req_meta.token_len_chunk, req_meta.block_hashes, mask_num):
219+
token_len, req_meta.block_hashes, mask_num):
220220
addr, size, _ = self.token_database.prepare_value(
221221
start, end, req_meta.block_ids)
222222
key_list.append(key.to_string())

vllm_ascend/distributed/kvpool/pool_worker.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ def __init__(
134134
self.use_mla, partitions)
135135

136136
real_backend = backend_map.get(self.backend.lower())
137+
138+
# be removed later
139+
if self.backend == "mooncake":
140+
self.head_or_tp_rank = self.tp_rank
141+
self.put_step = 1
142+
137143
self.m_store = real_backend( # type: ignore[misc]
138144
parallel_config)
139145

@@ -245,7 +251,7 @@ def start_load_kv(self, metadata: AscendConnectorMetadata):
245251
token_len = request.load_spec.kvpool_cached_tokens + 1
246252
else:
247253
token_len = request.load_spec.kvpool_cached_tokens
248-
request.token_len_chunk = token_len
254+
request.load_spec.token_len = token_len
249255
if self.use_layerwise:
250256
layerwise_retriever = self.retrieve_layer(request)
251257
next(layerwise_retriever) # first layer load

0 commit comments

Comments
 (0)