Commit 72a3f6b

Construct KVTransferConfig properly from Python instead of using JSON blobs without CLI (#17994)
Signed-off-by: Harry Mellor <[email protected]>
1 parent 98ea356 commit 72a3f6b

5 files changed: +37 additions, -31 deletions
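The change is mechanical across all five examples: instead of serializing settings into a JSON string for KVTransferConfig.from_cli, each example now passes the same fields as keyword arguments. A minimal sketch of the pattern, assuming the usual from vllm.config import used by these examples (values taken from the first diff below):

from vllm.config import KVTransferConfig

# Old style (removed in this commit): settings round-trip through a
# JSON blob that from_cli parses back into a config object.
ktc = KVTransferConfig.from_cli(
    '{"kv_connector":"LMCacheConnector","kv_role":"kv_producer",'
    '"kv_rank":0,"kv_parallel_size":2}')

# New style: the same settings passed directly as keyword arguments,
# readable and editable without mentally parsing JSON.
ktc = KVTransferConfig(kv_connector="LMCacheConnector",
                       kv_role="kv_producer",
                       kv_rank=0,
                       kv_parallel_size=2)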

examples/lmcache/disagg_prefill_lmcache_v0.py
Lines changed: 8 additions & 6 deletions

@@ -49,9 +49,10 @@ def run_prefill(prefill_done, prompts):

     sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)

-    ktc = KVTransferConfig.from_cli(
-        '{"kv_connector":"LMCacheConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
-    )
+    ktc = KVTransferConfig(kv_connector="LMCacheConnector",
+                           kv_role="kv_producer",
+                           kv_rank=0,
+                           kv_parallel_size=2)
     # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
     # memory. Reduce the value if your GPU has less memory.
     llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
@@ -78,9 +79,10 @@ def run_decode(prefill_done, prompts, timeout=1):

     sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

-    ktc = KVTransferConfig.from_cli(
-        '{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
-    )
+    ktc = KVTransferConfig(kv_connector="LMCacheConnector",
+                           kv_role="kv_consumer",
+                           kv_rank=1,
+                           kv_parallel_size=2)
     # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
     # of memory. Reduce the value if your GPU has less memory.
     llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",

examples/lmcache/kv_cache_sharing_lmcache_v1.py
Lines changed: 4 additions & 4 deletions

@@ -49,8 +49,8 @@ def run_store(store_done, prompts):

     sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

-    ktc = KVTransferConfig.from_cli(
-        '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
+    ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1",
+                           kv_role="kv_both")
     # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
     # memory. Reduce the value if your GPU has less memory.
     llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
@@ -76,8 +76,8 @@ def run_retrieve(store_done, prompts, timeout=1):

     sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

-    ktc = KVTransferConfig.from_cli(
-        '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}')
+    ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1",
+                           kv_role="kv_both")
     # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
     # of memory. Reduce the value if your GPU has less memory.
     llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",

examples/offline_inference/disaggregated-prefill-v1/decode_example.py
Lines changed: 11 additions & 10 deletions

@@ -16,16 +16,17 @@

 sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)

-llm = LLM(
-    model="meta-llama/Llama-3.2-1B-Instruct",
-    enforce_eager=True,
-    gpu_memory_utilization=0.8,
-    max_num_batched_tokens=64,
-    max_num_seqs=16,
-    kv_transfer_config=KVTransferConfig.from_cli(
-        '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",'
-        '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}'
-    )) #, max_model_len=2048, max_num_batched_tokens=2048)
+llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
+          enforce_eager=True,
+          gpu_memory_utilization=0.8,
+          max_num_batched_tokens=64,
+          max_num_seqs=16,
+          kv_transfer_config=KVTransferConfig(
+              kv_connector="SharedStorageConnector",
+              kv_role="kv_both",
+              kv_connector_extra_config={
+                  "shared_storage_path": "local_storage"
+              })) #, max_model_len=2048, max_num_batched_tokens=2048)

 # 1ST generation (prefill instance)
 outputs = llm.generate(prompts, sampling_params)

examples/offline_inference/disaggregated-prefill-v1/prefill_example.py
Lines changed: 6 additions & 5 deletions

@@ -17,11 +17,12 @@
 llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
           enforce_eager=True,
           gpu_memory_utilization=0.8,
-          kv_transfer_config=KVTransferConfig.from_cli(
-              '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", '
-              '"kv_connector_extra_config": '
-              '{"shared_storage_path": "local_storage"}}')
-          ) #, max_model_len=2048, max_num_batched_tokens=2048)
+          kv_transfer_config=KVTransferConfig(
+              kv_connector="SharedStorageConnector",
+              kv_role="kv_both",
+              kv_connector_extra_config={
+                  "shared_storage_path": "local_storage"
+              })) #, max_model_len=2048, max_num_batched_tokens=2048)

 # 1ST generation (prefill instance)
 outputs = llm.generate(
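A side benefit visible in these two SharedStorageConnector files: the nested kv_connector_extra_config settings become an ordinary Python dict instead of JSON escaped inside a string, so quoting mistakes in the nested payload can no longer slip through as malformed JSON. A short sketch reusing the values from the diffs above, with the same import assumption as before:

from vllm.config import KVTransferConfig

# The connector-specific payload is now a plain dict; here it tells
# SharedStorageConnector where on disk to store and look up KV blocks.
ktc = KVTransferConfig(kv_connector="SharedStorageConnector",
                       kv_role="kv_both",
                       kv_connector_extra_config={
                           "shared_storage_path": "local_storage"
                       })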

examples/offline_inference/disaggregated_prefill.py
Lines changed: 8 additions & 6 deletions

@@ -32,9 +32,10 @@ def run_prefill(prefill_done):
     # This instance is the prefill node (kv_producer, rank 0).
     # The number of parallel instances for KV cache transfer is set to 2,
     # as required for PyNcclConnector.
-    ktc = KVTransferConfig.from_cli(
-        '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}'
-    )
+    ktc = KVTransferConfig(kv_connector="PyNcclConnector",
+                           kv_role="kv_producer",
+                           kv_rank=0,
+                           kv_parallel_size=2)

     # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
     # memory. You may need to adjust the value to fit your GPU.
@@ -71,9 +72,10 @@ def run_decode(prefill_done):
     # This instance is the decode node (kv_consumer, rank 1).
     # The number of parallel instances for KV cache transfer is set to 2,
     # as required for PyNcclConnector.
-    ktc = KVTransferConfig.from_cli(
-        '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}'
-    )
+    ktc = KVTransferConfig(kv_connector="PyNcclConnector",
+                           kv_role="kv_consumer",
+                           kv_rank=1,
+                           kv_parallel_size=2)

     # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB
     # memory. You may need to adjust the value to fit your GPU.
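For context, a sketch of how the two configs above pair up in this example: the prefill and decode functions run in separate processes, each builds its own KVTransferConfig, and each passes it to its own LLM instance so that KV cache moves from rank 0 to rank 1 over NCCL. The model name below is illustrative (the diff context does not show it); the kv_transfer_config wiring follows the LLM calls shown in the other diffs of this commit:

from vllm import LLM
from vllm.config import KVTransferConfig

# In the prefill process (rank 0, the producer):
ktc = KVTransferConfig(kv_connector="PyNcclConnector",
                       kv_role="kv_producer",
                       kv_rank=0,
                       kv_parallel_size=2)

# The decode process would instead use kv_role="kv_consumer" and
# kv_rank=1, as in the second hunk above.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",  # illustrative
          gpu_memory_utilization=0.8,
          kv_transfer_config=ktc)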

0 commit comments
