From 81d27c8e31c26a435a062fbeaff66357d28a773c Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Sun, 19 Jan 2025 12:13:27 +0800
Subject: [PATCH 001/147] Refactor to add TypeBasedDispatcher to simplify
dispatching (#2958)
---
python/sglang/srt/managers/scheduler.py | 113 +++++-----
.../sglang/srt/managers/tokenizer_manager.py | 209 +++++++++---------
python/sglang/utils.py | 13 +-
3 files changed, 171 insertions(+), 164 deletions(-)
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index d62abaff931..d859a30a038 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -97,7 +97,7 @@
set_random_seed,
suppress_other_loggers,
)
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
logger = logging.getLogger(__name__)
@@ -422,6 +422,34 @@ def __init__(
},
)
+ self._dispatcher = TypeBasedDispatcher(
+ [
+ (TokenizedGenerateReqInput, self.handle_generate_request),
+ (TokenizedEmbeddingReqInput, self.handle_embedding_request),
+ (FlushCacheReq, self.flush_cache_wrapped),
+ (AbortReq, self.abort_request),
+ (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
+ (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+ (
+ UpdateWeightsFromDistributedReqInput,
+ self.update_weights_from_distributed,
+ ),
+ (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
+ (GetWeightsByNameReqInput, self.get_weights_by_name),
+ (ProfileReq, self.profile),
+ (OpenSessionReqInput, self.open_session),
+ (CloseSessionReqInput, self.close_session),
+ (
+ ReleaseMemoryOccupationReqInput,
+ lambda _: self.release_memory_occupation(),
+ ),
+ (
+ ResumeMemoryOccupationReqInput,
+ lambda _: self.resume_memory_occupation(),
+ ),
+ ]
+ )
+
def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one batch takes too long."""
self.watchdog_last_forward_ct = 0
@@ -563,57 +591,9 @@ def recv_requests(self) -> List[Req]:
def process_input_requests(self, recv_reqs: List):
for recv_req in recv_reqs:
- if isinstance(recv_req, TokenizedGenerateReqInput):
- self.handle_generate_request(recv_req)
- elif isinstance(recv_req, TokenizedEmbeddingReqInput):
- self.handle_embedding_request(recv_req)
- elif isinstance(recv_req, FlushCacheReq):
- self.flush_cache()
- elif isinstance(recv_req, AbortReq):
- self.abort_request(recv_req)
- elif isinstance(recv_req, UpdateWeightFromDiskReqInput):
- success, message = self.update_weights_from_disk(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightFromDiskReqOutput(success, message)
- )
- elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
- success, message = self.init_weights_update_group(recv_req)
- self.send_to_tokenizer.send_pyobj(
- InitWeightsUpdateGroupReqOutput(success, message)
- )
- elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
- success, message = self.update_weights_from_distributed(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightsFromDistributedReqOutput(success, message)
- )
- elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
- success, message = self.update_weights_from_tensor(recv_req)
- self.send_to_tokenizer.send_pyobj(
- UpdateWeightsFromTensorReqOutput(success, message)
- )
- elif isinstance(recv_req, GetWeightsByNameReqInput):
- parameter = self.get_weights_by_name(recv_req)
- self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
- elif isinstance(recv_req, ReleaseMemoryOccupationReqInput):
- self.release_memory_occupation()
- self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput())
- elif isinstance(recv_req, ResumeMemoryOccupationReqInput):
- self.resume_memory_occupation()
- self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput())
- elif isinstance(recv_req, ProfileReq):
- if recv_req == ProfileReq.START_PROFILE:
- self.start_profile()
- else:
- self.stop_profile()
- elif isinstance(recv_req, OpenSessionReqInput):
- session_id, success = self.open_session(recv_req)
- self.send_to_tokenizer.send_pyobj(
- OpenSessionReqOutput(session_id=session_id, success=success)
- )
- elif isinstance(recv_req, CloseSessionReqInput):
- self.close_session(recv_req)
- else:
- raise ValueError(f"Invalid request: {recv_req}")
+ output = self._dispatcher(recv_req)
+ if output is not None:
+ self.send_to_tokenizer.send_pyobj(output)
def handle_generate_request(
self,
@@ -1545,6 +1525,9 @@ def move_ready_grammar_requests(self):
self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
+ def flush_cache_wrapped(self, recv_req: FlushCacheReq):
+ self.flush_cache()
+
def flush_cache(self):
"""Flush the memory pool and cache."""
if len(self.waiting_queue) == 0 and (
@@ -1597,12 +1580,12 @@ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
- return success, message
+ return UpdateWeightFromDiskReqOutput(success, message)
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
- return success, message
+ return InitWeightsUpdateGroupReqOutput(success, message)
def update_weights_from_distributed(
self,
@@ -1615,7 +1598,7 @@ def update_weights_from_distributed(
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
- return success, message
+ return UpdateWeightsFromDistributedReqOutput(success, message)
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
"""Update the online model parameter from tensors."""
@@ -1626,11 +1609,11 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
- return success, message
+ return UpdateWeightsFromTensorReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
- return parameter
+ return GetWeightsByNameReqOutput(parameter)
def release_memory_occupation(self):
self.stashed_model_static_state = _export_static_state(
@@ -1638,6 +1621,7 @@ def release_memory_occupation(self):
)
self.memory_saver_adapter.pause()
self.flush_cache()
+ return ReleaseMemoryOccupationReqOutput()
def resume_memory_occupation(self):
self.memory_saver_adapter.resume()
@@ -1645,6 +1629,13 @@ def resume_memory_occupation(self):
self.tp_worker.worker.model_runner.model, self.stashed_model_static_state
)
del self.stashed_model_static_state
+ return ResumeMemoryOccupationReqOutput()
+
+ def profile(self, recv_req: ProfileReq):
+ if recv_req == ProfileReq.START_PROFILE:
+ self.start_profile()
+ else:
+ self.stop_profile()
def start_profile(self) -> None:
if self.profiler is None:
@@ -1660,20 +1651,20 @@ def stop_profile(self) -> None:
)
logger.info("Profiler is done")
- def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]:
+ def open_session(self, recv_req: OpenSessionReqInput):
# handle error
session_id = recv_req.session_id
if session_id in self.sessions:
logger.warning(f"session id {session_id} already exist, cannot open.")
- return session_id, False
+ return OpenSessionReqOutput(session_id, False)
elif session_id is None:
logger.warning(f"session id is None, cannot open.")
- return session_id, False
+ return OpenSessionReqOutput(session_id, False)
else:
self.sessions[session_id] = Session(
recv_req.capacity_of_str_len, session_id
)
- return session_id, True
+ return OpenSessionReqOutput(session_id, True)
def close_session(self, recv_req: CloseSessionReqInput):
# handle error
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 85dcbcbd04c..74f46538c93 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,7 +80,7 @@
get_zmq_socket,
kill_process_tree,
)
-from sglang.utils import get_exception_traceback
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -221,6 +221,43 @@ def __init__(
},
)
+ self._dispatcher = TypeBasedDispatcher(
+ [
+ (BatchStrOut, self._handle_batch_output),
+ (BatchEmbeddingOut, self._handle_batch_output),
+ (BatchTokenIDOut, self._handle_batch_output),
+ (OpenSessionReqOutput, self._handle_open_session_req_output),
+ (
+ UpdateWeightFromDiskReqOutput,
+ self._handle_update_weights_from_disk_req_output,
+ ),
+ (
+ InitWeightsUpdateGroupReqOutput,
+ self.init_weights_update_group_communicator.handle_recv,
+ ),
+ (
+ UpdateWeightsFromDistributedReqOutput,
+ self.update_weights_from_distributed_communicator.handle_recv,
+ ),
+ (
+ UpdateWeightsFromTensorReqOutput,
+ self.update_weights_from_tensor_communicator.handle_recv,
+ ),
+ (
+ GetWeightsByNameReqOutput,
+ self.get_weights_by_name_communicator.handle_recv,
+ ),
+ (
+ ReleaseMemoryOccupationReqOutput,
+ self.release_memory_occupation_communicator.handle_recv,
+ ),
+ (
+ ResumeMemoryOccupationReqOutput,
+ self.resume_memory_occupation_communicator.handle_recv,
+ ),
+ ]
+ )
+
async def generate_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -712,110 +749,64 @@ async def handle_loop(self):
"""The event loop that handles requests"""
while True:
- recv_obj: Union[
- BatchStrOut,
- BatchEmbeddingOut,
- BatchTokenIDOut,
- UpdateWeightFromDiskReqOutput,
- UpdateWeightsFromDistributedReqOutput,
- GetWeightsByNameReqOutput,
- InitWeightsUpdateGroupReqOutput,
- ReleaseMemoryOccupationReqOutput,
- ResumeMemoryOccupationReqOutput,
- ] = await self.recv_from_detokenizer.recv_pyobj()
-
- if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
- for i, rid in enumerate(recv_obj.rids):
- state = self.rid_to_state.get(rid, None)
- if state is None:
- continue
-
- meta_info = {
- "id": rid,
- "finish_reason": recv_obj.finished_reasons[i],
- "prompt_tokens": recv_obj.prompt_tokens[i],
- }
+ recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+ self._dispatcher(recv_obj)
- if getattr(state.obj, "return_logprob", False):
- self.convert_logprob_style(
- meta_info,
- state.obj.top_logprobs_num,
- state.obj.return_text_in_logprobs,
- recv_obj,
- i,
- )
-
- if not isinstance(recv_obj, BatchEmbeddingOut):
- meta_info.update(
- {
- "completion_tokens": recv_obj.completion_tokens[i],
- "cached_tokens": recv_obj.cached_tokens[i],
- }
- )
-
- if isinstance(recv_obj, BatchStrOut):
- out_dict = {
- "text": recv_obj.output_strs[i],
- "meta_info": meta_info,
- }
- elif isinstance(recv_obj, BatchTokenIDOut):
- out_dict = {
- "token_ids": recv_obj.output_ids[i],
- "meta_info": meta_info,
- }
- else:
- assert isinstance(recv_obj, BatchEmbeddingOut)
- out_dict = {
- "embedding": recv_obj.embeddings[i],
- "meta_info": meta_info,
- }
- state.out_list.append(out_dict)
- state.finished = recv_obj.finished_reasons[i] is not None
- state.event.set()
-
- if self.enable_metrics and state.obj.log_metrics:
- self.collect_metrics(state, recv_obj, i)
- if (
- self.dump_requests_folder
- and state.finished
- and state.obj.log_metrics
- ):
- self.dump_requests(state, out_dict)
- elif isinstance(recv_obj, OpenSessionReqOutput):
- self.session_futures[recv_obj.session_id].set_result(
- recv_obj.session_id if recv_obj.success else None
+ def _handle_batch_output(
+ self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
+ ):
+ for i, rid in enumerate(recv_obj.rids):
+ state = self.rid_to_state.get(rid, None)
+ if state is None:
+ continue
+
+ meta_info = {
+ "id": rid,
+ "finish_reason": recv_obj.finished_reasons[i],
+ "prompt_tokens": recv_obj.prompt_tokens[i],
+ }
+
+ if getattr(state.obj, "return_logprob", False):
+ self.convert_logprob_style(
+ meta_info,
+ state.obj.top_logprobs_num,
+ state.obj.return_text_in_logprobs,
+ recv_obj,
+ i,
)
- elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
- if self.server_args.dp_size == 1:
- self.model_update_result.set_result(recv_obj)
- else: # self.server_args.dp_size > 1
- self.model_update_tmp.append(recv_obj)
- # set future if the all results are recevied
- if len(self.model_update_tmp) == self.server_args.dp_size:
- self.model_update_result.set_result(self.model_update_tmp)
- elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for init parameter update group"
- self.init_weights_update_group_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for update weights from distributed"
- self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for update weights from distributed"
- self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, GetWeightsByNameReqOutput):
- self.get_weights_by_name_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
- self.release_memory_occupation_communicator.handle_recv(recv_obj)
- elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
- self.resume_memory_occupation_communicator.handle_recv(recv_obj)
+
+ if not isinstance(recv_obj, BatchEmbeddingOut):
+ meta_info.update(
+ {
+ "completion_tokens": recv_obj.completion_tokens[i],
+ "cached_tokens": recv_obj.cached_tokens[i],
+ }
+ )
+
+ if isinstance(recv_obj, BatchStrOut):
+ out_dict = {
+ "text": recv_obj.output_strs[i],
+ "meta_info": meta_info,
+ }
+ elif isinstance(recv_obj, BatchTokenIDOut):
+ out_dict = {
+ "token_ids": recv_obj.output_ids[i],
+ "meta_info": meta_info,
+ }
else:
- raise ValueError(f"Invalid object: {recv_obj=}")
+ assert isinstance(recv_obj, BatchEmbeddingOut)
+ out_dict = {
+ "embedding": recv_obj.embeddings[i],
+ "meta_info": meta_info,
+ }
+ state.out_list.append(out_dict)
+ state.finished = recv_obj.finished_reasons[i] is not None
+ state.event.set()
+
+ if self.enable_metrics and state.obj.log_metrics:
+ self.collect_metrics(state, recv_obj, i)
+ if self.dump_requests_folder and state.finished and state.obj.log_metrics:
+ self.dump_requests(state, out_dict)
def convert_logprob_style(
self,
@@ -943,6 +934,20 @@ def background_task():
# Schedule the task to run in the background without awaiting it
asyncio.create_task(asyncio.to_thread(background_task))
+ def _handle_open_session_req_output(self, recv_obj):
+ self.session_futures[recv_obj.session_id].set_result(
+ recv_obj.session_id if recv_obj.success else None
+ )
+
+ def _handle_update_weights_from_disk_req_output(self, recv_obj):
+ if self.server_args.dp_size == 1:
+ self.model_update_result.set_result(recv_obj)
+ else: # self.server_args.dp_size > 1
+ self.model_update_tmp.append(recv_obj)
+ # set future if all the results are received
+ if len(self.model_update_tmp) == self.server_args.dp_size:
+ self.model_update_result.set_result(self.model_update_tmp)
+
async def print_exception_wrapper(func):
"""
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 98e0f3f4f8d..98942fbb39c 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -15,7 +15,7 @@
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
-from typing import Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Type, Union
import numpy as np
import requests
@@ -363,3 +363,14 @@ def terminate_process(process):
def print_highlight(html_content: str):
html_content = str(html_content).replace("\n", "<br>")
display(HTML(f"<strong style='color: #00008B;'>{html_content}</strong>"))
+
+
+class TypeBasedDispatcher:
+ def __init__(self, mapping: List[Tuple[Type, Callable]]):
+ self._mapping = mapping
+
+ def __call__(self, obj: Any):
+ for ty, fn in self._mapping:
+ if isinstance(obj, ty):
+ return fn(obj)
+ raise ValueError(f"Invalid object: {obj}")
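The dispatcher simply walks its (type, handler) list and calls the first matching handler, raising on unknown types. A minimal usage sketch (the request class and handler below are hypothetical stand-ins for the real ones registered in scheduler.py and tokenizer_manager.py):

```python
from dataclasses import dataclass

from sglang.utils import TypeBasedDispatcher


@dataclass
class PingReq:  # hypothetical request type, for illustration only
    payload: str


def handle_ping(req: PingReq) -> str:
    return f"pong: {req.payload}"


dispatcher = TypeBasedDispatcher([(PingReq, handle_ping)])
print(dispatcher(PingReq("hello")))  # -> "pong: hello"

try:
    dispatcher(42)  # no registered type matches an int
except ValueError as e:
    print(e)  # "Invalid object: 42"
```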
From 7906d1d29863bc3b33c4bcfb942a5d61f9867127 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 18 Jan 2025 20:20:23 -0800
Subject: [PATCH 002/147] Remove the unused write_with_records (#2972)
---
python/sglang/srt/managers/schedule_batch.py | 1 -
python/sglang/srt/mem_cache/memory_pool.py | 28 +------------------
.../sglang/srt/model_executor/model_runner.py | 1 -
3 files changed, 1 insertion(+), 29 deletions(-)
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index cec2262c487..afbc98b7ca9 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -158,7 +158,6 @@ class ImageInputs:
im_end_id: Optional[torch.Tensor] = None
slice_start_id: Optional[torch.Tensor] = None
slice_end_id: Optional[torch.Tensor] = None
-
tgt_sizes: Optional[list] = None
@staticmethod
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index ab27e81b743..e307367223a 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -49,7 +49,6 @@ def __init__(
size: int,
max_context_len: int,
device: str,
- use_records: bool,
enable_memory_saver: bool,
):
memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -64,17 +63,9 @@ def __init__(
(size, max_context_len), dtype=torch.int32, device=device
)
self.free_slots = list(range(size))
- self.write_records = []
- self.use_records = use_records
-
- if self.use_records:
- self.write = self.write_with_records
- else:
- self.write = self.write_without_records
def write(self, indices, values):
- # Keep the signature for type checking. It will be assigned during runtime.
- raise NotImplementedError()
+ self.req_to_token[indices] = values
def available_size(self):
return len(self.free_slots)
@@ -96,23 +87,6 @@ def free(self, free_index: Union[int, List[int]]):
def clear(self):
self.free_slots = list(range(self.size))
- self.write_records = []
-
- def write_without_records(self, indices, values):
- self.req_to_token[indices] = values
-
- def write_with_records(self, indices, values):
- self.req_to_token[indices] = values
- self.write_records.append((indices, values))
-
- def get_write_records(self):
- ret = self.write_records
- self.write_records = []
- return ret
-
- def apply_write_records(self, write_records: List[Tuple]):
- for indices, values in write_records:
- self.req_to_token[indices] = values
class BaseTokenToKVPool:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index bca4711eb64..46920d92249 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -617,7 +617,6 @@ def init_memory_pool(
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
device=self.device,
- use_records=False,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if (
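With the record-keeping removed, `ReqToTokenPool.write` is a plain tensor index assignment. A standalone sketch of the same pattern (table shape and values are made up for illustration):

```python
import torch

# Miniature req_to_token table: 4 request slots, max context length 8.
req_to_token = torch.zeros((4, 8), dtype=torch.int32)

# Record the KV-cache slots of the first three tokens of request 2,
# mirroring `self.req_to_token[indices] = values` in the simplified write().
req_to_token[2, :3] = torch.tensor([10, 11, 12], dtype=torch.int32)

print(req_to_token[2, :4])  # tensor([10, 11, 12,  0], dtype=torch.int32)
```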
From 93b77c8e8a14d74bea70b643c4f40ea5f5fbc666 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 18 Jan 2025 21:45:00 -0800
Subject: [PATCH 003/147] Fix the request logging so that requests can be easily replayed (#2973)
---
python/sglang/srt/managers/configure_logging.py | 3 +++
python/sglang/srt/managers/io_struct.py | 1 +
python/sglang/srt/managers/tokenizer_manager.py | 11 +++++++++--
python/sglang/srt/utils.py | 6 +++---
4 files changed, 16 insertions(+), 5 deletions(-)
diff --git a/python/sglang/srt/managers/configure_logging.py b/python/sglang/srt/managers/configure_logging.py
index 3351cdc400c..187af4d9c08 100644
--- a/python/sglang/srt/managers/configure_logging.py
+++ b/python/sglang/srt/managers/configure_logging.py
@@ -27,6 +27,7 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--url", type=str, default="http://localhost:30000")
+ parser.add_argument("--log-requests", action="store_true")
parser.add_argument(
"--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
)
@@ -36,6 +37,8 @@
response = requests.post(
args.url + "/configure_logging",
json={
+ "log_requests": args.log_requests,
+ "log_requests_level": 1, # Log full requests
"dump_requests_folder": args.dump_requests_folder,
"dump_requests_threshold": args.dump_requests_threshold,
},
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 7f07055132f..c5a35ced00c 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -495,6 +495,7 @@ class ProfileReq(Enum):
@dataclass
class ConfigureLoggingReq:
log_requests: Optional[bool] = None
+ log_requests_level: Optional[int] = None
dump_requests_folder: Optional[str] = None
dump_requests_threshold: Optional[int] = None
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 74f46538c93..033a660df5e 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -117,6 +117,7 @@ def __init__(
self.server_args = server_args
self.enable_metrics = server_args.enable_metrics
self.log_requests = server_args.log_requests
+ self.log_requests_level = 0
# Init inter-process communication
context = zmq.asyncio.Context(2)
@@ -276,7 +277,10 @@ async def generate_request(
obj.normalize_batch_and_arguments()
if self.log_requests:
- logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+ max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+ logger.info(
+ f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
+ )
async with self.model_update_lock.reader_lock:
is_single = obj.is_single
@@ -419,7 +423,8 @@ async def _wait_one_response(
state.out_list = []
if state.finished:
if self.log_requests:
- msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+ max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+ msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
logger.info(msg)
del self.rid_to_state[obj.rid]
@@ -682,6 +687,8 @@ async def close_session(
def configure_logging(self, obj: ConfigureLoggingReq):
if obj.log_requests is not None:
self.log_requests = obj.log_requests
+ if obj.log_requests_level is not None:
+ self.log_requests_level = obj.log_requests_level
if obj.dump_requests_folder is not None:
self.dump_requests_folder = obj.dump_requests_folder
if obj.dump_requests_threshold is not None:
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 3e8b95b1597..c67b6635b30 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1262,9 +1262,9 @@ def dataclass_to_string_truncated(data, max_length=2048):
if isinstance(data, str):
if len(data) > max_length:
half_length = max_length // 2
- return f'"{data[:half_length]} ... {data[-half_length:]}"'
+ return f"{repr(data[:half_length])} ... {repr(data[-half_length:])}"
else:
- return f'"{data}"'
+ return f"{repr(data)}"
elif isinstance(data, (list, tuple)):
if len(data) > max_length:
half_length = max_length // 2
@@ -1275,7 +1275,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
return (
"{"
+ ", ".join(
- f"{k}: {dataclass_to_string_truncated(v, max_length)}"
+ f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
for k, v in data.items()
)
+ "}"
From 23196d5254ff9f9d7cadd6a028b264bf5db8b18c Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 18 Jan 2025 23:03:49 -0800
Subject: [PATCH 004/147] Simplify logits processor (#2974)
Co-authored-by: SangBin Cho
---
python/sglang/srt/layers/logits_processor.py | 71 ++++++++++++--------
1 file changed, 44 insertions(+), 27 deletions(-)
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 10f26467787..e5794f052c3 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -14,6 +14,7 @@
"""Logits processing."""
import dataclasses
+import logging
from typing import List, Optional, Union
import torch
@@ -32,6 +33,8 @@
ForwardMode,
)
+logger = logging.getLogger(__name__)
+
@dataclasses.dataclass
class LogitsProcessorOutput:
@@ -136,50 +139,61 @@ def forward(
logits_metadata.forward_mode.is_decode_or_idle()
or logits_metadata.forward_mode.is_target_verify()
):
- last_index = None
- last_hidden = hidden_states
- else:
+ pruned_states = hidden_states
+ sample_indices = None
+ elif (
+ logits_metadata.forward_mode.is_extend()
+ and not logits_metadata.extend_return_logprob
+ ):
+ # Prefill without input logprobs.
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
- last_hidden = hidden_states[last_index]
+ pruned_states = hidden_states[last_index]
+ sample_indices = None
+ else:
+ # Slice the requested tokens to compute logprob
+ sample_index_pt = -1
+ sample_indices = []
+ pt, pruned_states, pruned_input_ids = 0, [], []
+ for start_len, extend_len in zip(
+ logits_metadata.extend_logprob_start_lens_cpu,
+ logits_metadata.extend_seq_lens_cpu,
+ ):
+ pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+ sample_index_pt += extend_len - start_len
+ sample_indices.append(sample_index_pt)
+ pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+ pt += extend_len
+
+ pruned_states = torch.cat(pruned_states)
+
+ # Compute logits for both input and sampled tokens.
+ logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+ sampled_logits = (
+ logits[sample_indices] if sample_indices is not None else logits
+ )
- # Compute logits
- last_logits = self._get_logits(last_hidden, lm_head)
if (
not logits_metadata.extend_return_logprob
or logits_metadata.capture_hidden_mode.need_capture()
):
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
- next_token_logits=last_logits,
+ next_token_logits=sampled_logits,
hidden_states=(
hidden_states
if logits_metadata.capture_hidden_mode.is_full()
else (
- last_hidden
+ pruned_states
if logits_metadata.capture_hidden_mode.is_last()
else None
)
),
)
else:
- # Slice the requested tokens to compute logprob
- pt, pruned_states, pruned_input_ids = 0, [], []
- for start_len, extend_len in zip(
- logits_metadata.extend_logprob_start_lens_cpu,
- logits_metadata.extend_seq_lens_cpu,
- ):
- pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
- pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
- pt += extend_len
-
- # Compute the logits of all required tokens
- pruned_states = torch.cat(pruned_states)
- del hidden_states
- input_token_logits = self._get_logits(pruned_states, lm_head)
- del pruned_states
+ input_logprobs = logits
+ del hidden_states, logits
# Normalize the logprob w/o temperature, top-p
- input_logprobs = input_token_logits
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
input_logprobs, logits_metadata
)
@@ -194,17 +208,17 @@ def forward(
input_top_logprobs_val = input_top_logprobs_idx = None
input_token_logprobs = input_logprobs[
- torch.arange(input_logprobs.shape[0], device="cuda"),
+ torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
torch.cat(
[
torch.cat(pruned_input_ids)[1:],
- torch.tensor([0], device="cuda"),
+ torch.tensor([0], device=input_logprobs.device),
]
),
]
return LogitsProcessorOutput(
- next_token_logits=last_logits,
+ next_token_logits=sampled_logits,
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
@@ -214,8 +228,11 @@ def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
+ logits_metadata: LogitsMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
+ """Get logits from hidden_states."""
+
if hasattr(lm_head, "weight"):
logits = torch.matmul(hidden_states, lm_head.weight.T)
else:
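In the logprob branch above, `sample_indices` ends up holding, for each request, the position of its last kept token inside the concatenated `pruned_states`. A small worked example of that index arithmetic (the lengths are made up):

```python
# Two requests in the extend batch.
extend_seq_lens_cpu = [5, 3]            # tokens added per request
extend_logprob_start_lens_cpu = [2, 0]  # tokens skipped at the front of each

sample_index_pt = -1
sample_indices = []
for start_len, extend_len in zip(
    extend_logprob_start_lens_cpu, extend_seq_lens_cpu
):
    sample_index_pt += extend_len - start_len  # last kept token of this request
    sample_indices.append(sample_index_pt)

print(sample_indices)  # [2, 5]: rows 0-2 belong to request 0, rows 3-5 to request 1
```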
From d33cbb7e5857da4cf4023ecfac2706ffbd0c76b6 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sun, 19 Jan 2025 15:51:27 +0800
Subject: [PATCH 005/147] remove cub and add cccl (#2976)
---
.gitmodules | 6 +++---
sgl-kernel/3rdparty/cccl | 1 +
sgl-kernel/3rdparty/cub | 1 -
3 files changed, 4 insertions(+), 4 deletions(-)
create mode 160000 sgl-kernel/3rdparty/cccl
delete mode 160000 sgl-kernel/3rdparty/cub
diff --git a/.gitmodules b/.gitmodules
index c588176e7c0..c584a21e8bd 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,6 @@
[submodule "sgl-kernel/3rdparty/cutlass"]
path = sgl-kernel/3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
-[submodule "sgl-kernel/3rdparty/cub"]
- path = sgl-kernel/3rdparty/cub
- url = https://github.com/NVIDIA/cub.git
+[submodule "sgl-kernel/3rdparty/cccl"]
+ path = sgl-kernel/3rdparty/cccl
+ url = https://github.com/NVIDIA/cccl.git
diff --git a/sgl-kernel/3rdparty/cccl b/sgl-kernel/3rdparty/cccl
new file mode 160000
index 00000000000..b5fe509fd11
--- /dev/null
+++ b/sgl-kernel/3rdparty/cccl
@@ -0,0 +1 @@
+Subproject commit b5fe509fd11a925f90d6495176707cc1184eed9d
diff --git a/sgl-kernel/3rdparty/cub b/sgl-kernel/3rdparty/cub
deleted file mode 160000
index 0fc3c370163..00000000000
--- a/sgl-kernel/3rdparty/cub
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 0fc3c3701632a4be906765b73be20a9ad0da603d
From 53cc91e504a3865d4086ac0f73d7198e66c89833 Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Sun, 19 Jan 2025 16:34:01 +0800
Subject: [PATCH 006/147] [devcontainer] Fix mount and GPU & Support rust dev
(#2978)
---
.devcontainer/devcontainer.json | 7 +++++--
docker/Dockerfile.dev | 8 ++++++++
2 files changed, 13 insertions(+), 2 deletions(-)
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index aee28589864..66f7aecbf82 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -15,6 +15,9 @@
]
}
},
- "workspaceFolder": "/sgl-workspace/sglang",
- "forwardPorts": []
+ "forwardPorts": [],
+ "runArgs": [
+ "--gpus",
+ "all"
+ ]
}
diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev
index 70860d8ef88..20a373184b8 100644
--- a/docker/Dockerfile.dev
+++ b/docker/Dockerfile.dev
@@ -18,6 +18,8 @@ RUN apt-get update && apt-get install -y \
silversearcher-ag \
cloc \
unzip \
+ pkg-config \
+ libssl-dev \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
@@ -63,6 +65,12 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1
&& cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \
&& rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz
+# Install uv
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+
+# Install rust
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+
# Add yank script
COPY --chown=root:root <<-"EOF" /usr/local/bin/yank
#!/bin/bash
From ef18b0eda28b37082d158fade59a24b29f6a986c Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Sun, 19 Jan 2025 17:05:23 +0800
Subject: [PATCH 007/147] [router] Allow empty worker list for
sglang.launch_router (#2979)
---
.github/workflows/pr-test-rust.yml | 4 +--
scripts/ci_install_rust.sh | 11 +++++---
sgl-router/README.md | 10 +++++++
.../py_src/sglang_router/launch_router.py | 6 ++---
.../py_src/sglang_router/launch_server.py | 2 +-
sgl-router/py_src/sglang_router/version.py | 2 +-
sgl-router/py_test/test_launch_router.py | 26 ++++++++++++++-----
sgl-router/pyproject.toml | 2 +-
8 files changed, 46 insertions(+), 17 deletions(-)
diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml
index 928d0efa5b3..277ddef774e 100644
--- a/.github/workflows/pr-test-rust.yml
+++ b/.github/workflows/pr-test-rust.yml
@@ -40,7 +40,7 @@ jobs:
cd sgl-router/
cargo test
- e2e-rust:
+ e2e-python:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 2-gpu-runner
steps:
@@ -65,7 +65,7 @@ jobs:
python3 run_suite.py
finish:
- needs: [unit-test-rust, e2e-rust]
+ needs: [unit-test-rust, e2e-python]
runs-on: ubuntu-latest
steps:
- name: Finish
diff --git a/scripts/ci_install_rust.sh b/scripts/ci_install_rust.sh
index 724207fd782..519155dfbe8 100755
--- a/scripts/ci_install_rust.sh
+++ b/scripts/ci_install_rust.sh
@@ -1,9 +1,14 @@
#!/bin/bash
set -euxo pipefail
-# these are required for actix
-apt-get update
-apt-get install -y libssl-dev pkg-config
+# Check if sudo is available
+if command -v sudo >/dev/null 2>&1; then
+ sudo apt-get update
+ sudo apt-get install -y libssl-dev pkg-config
+else
+ apt-get update
+ apt-get install -y libssl-dev pkg-config
+fi
# Install rustup (Rust installer and version manager)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
diff --git a/sgl-router/README.md b/sgl-router/README.md
index f39d63625de..61c9e692c92 100644
--- a/sgl-router/README.md
+++ b/sgl-router/README.md
@@ -67,6 +67,16 @@ $ pip install -e .
**Note:** When modifying Rust code, you must rebuild the wheel for changes to take effect.
+### Troubleshooting
+
+1. If rust analyzer is not working in VSCode, set `rust-analyzer.linkedProjects` to the absolute path of `Cargo.toml` in your repo. For example:
+
+```json
+{
+ "rust-analyzer.linkedProjects": ["/workspaces/sglang/sgl-router/Cargo.toml"]
+}
+```
+
### CI/CD Setup
The continuous integration pipeline consists of three main steps:
diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py
index e4f26a8d4bc..28cd5d11fbb 100644
--- a/sgl-router/py_src/sglang_router/launch_router.py
+++ b/sgl-router/py_src/sglang_router/launch_router.py
@@ -27,7 +27,7 @@ def setup_logger():
@dataclasses.dataclass
class RouterArgs:
# Worker configuration
- worker_urls: List[str]
+ worker_urls: List[str] = dataclasses.field(default_factory=list)
host: str = "127.0.0.1"
port: int = 30000
@@ -141,8 +141,9 @@ def from_cli_args(
use_router_prefix: If True, look for arguments with 'router-' prefix
"""
prefix = "router_" if use_router_prefix else ""
+ worker_urls = args.worker_urls if args.worker_urls is not None else []
return cls(
- worker_urls=args.worker_urls,
+ worker_urls=worker_urls,
host=args.host,
port=args.port,
policy=getattr(args, f"{prefix}policy"),
@@ -237,7 +238,6 @@ def parse_router_args(args: List[str]) -> RouterArgs:
def main() -> None:
- logger = setup_logger()
router_args = parse_router_args(sys.argv[1:])
router = launch_router(router_args)
diff --git a/sgl-router/py_src/sglang_router/launch_server.py b/sgl-router/py_src/sglang_router/launch_server.py
index 6ee19241542..2f433269efa 100644
--- a/sgl-router/py_src/sglang_router/launch_server.py
+++ b/sgl-router/py_src/sglang_router/launch_server.py
@@ -23,7 +23,7 @@ def setup_logger():
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
- "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
+ "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
datefmt="%Y-%m-%d %H:%M:%S",
)
diff --git a/sgl-router/py_src/sglang_router/version.py b/sgl-router/py_src/sglang_router/version.py
index 485f44ac21b..b3f4756216d 100644
--- a/sgl-router/py_src/sglang_router/version.py
+++ b/sgl-router/py_src/sglang_router/version.py
@@ -1 +1 @@
-__version__ = "0.1.1"
+__version__ = "0.1.2"
diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py
index 1c3700d423b..94912f69491 100644
--- a/sgl-router/py_test/test_launch_router.py
+++ b/sgl-router/py_test/test_launch_router.py
@@ -22,11 +22,9 @@ def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) ->
class TestLaunchRouter(unittest.TestCase):
- def test_launch_router_no_exception(self):
-
- # Create SimpleNamespace with default arguments
- args = SimpleNamespace(
- worker_urls=["http://localhost:8000"],
+ def setUp(self):
+ """Set up default arguments for router tests."""
+ self.default_args = SimpleNamespace(
host="127.0.0.1",
port=30000,
policy="cache_aware",
@@ -39,6 +37,15 @@ def test_launch_router_no_exception(self):
verbose=False,
)
+ def create_router_args(self, **kwargs):
+ """Create router arguments by updating default args with provided kwargs."""
+ args_dict = vars(self.default_args).copy()
+ args_dict.update(kwargs)
+ return SimpleNamespace(**args_dict)
+
+ def run_router_process(self, args):
+ """Run router in a separate process and verify it starts successfully."""
+
def run_router():
try:
from sglang_router.launch_router import launch_router
@@ -51,7 +58,6 @@ def run_router():
print(e)
return 1
- # Start router in separate process
process = multiprocessing.Process(target=run_router)
try:
process.start()
@@ -62,6 +68,14 @@ def run_router():
finally:
terminate_process(process)
+ def test_launch_router_common(self):
+ args = self.create_router_args(worker_urls=["http://localhost:8000"])
+ self.run_router_process(args)
+
+ def test_launch_router_with_empty_worker_urls(self):
+ args = self.create_router_args(worker_urls=[])
+ self.run_router_process(args)
+
if __name__ == "__main__":
unittest.main()
diff --git a/sgl-router/pyproject.toml b/sgl-router/pyproject.toml
index 20096b6b491..90e82cecf37 100644
--- a/sgl-router/pyproject.toml
+++ b/sgl-router/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sglang-router"
-version = "0.1.1"
+version = "0.1.2"
description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances."
authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}]
requires-python = ">=3.8"
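With `worker_urls` defaulting to an empty list, the router can now be launched programmatically with no workers attached yet; a minimal sketch (leaving all other `RouterArgs` fields at their defaults is an assumption of this example):

```python
from sglang_router.launch_router import RouterArgs, launch_router

# An empty worker list is now accepted; workers can be registered with the
# router after it is up. Note that launch_router starts the router service.
router = launch_router(RouterArgs(worker_urls=[]))
```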
From 4719c1d04a10bd11258c0c05f08db6e7beab0414 Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Sun, 19 Jan 2025 17:11:06 +0800
Subject: [PATCH 008/147] [router] Fix sgl router path for release (#2980)
---
.github/workflows/release-pypi-router.yml | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/.github/workflows/release-pypi-router.yml b/.github/workflows/release-pypi-router.yml
index df20c211cb3..bba0c0fca53 100644
--- a/.github/workflows/release-pypi-router.yml
+++ b/.github/workflows/release-pypi-router.yml
@@ -7,7 +7,7 @@ on:
branches:
- main
paths:
- - sglang-router/pyproject.toml
+ - sgl-router/pyproject.toml
workflow_dispatch:
jobs:
@@ -26,9 +26,9 @@ jobs:
with:
path: sglang-repo
- - name: Move sglang-router folder to root and delete sglang-repo
+ - name: Move sgl-router folder to root and delete sglang-repo
run: |
- mv sglang-repo/sglang-router/* .
+ mv sglang-repo/sgl-router/* .
rm -rf sglang-repo
ls -alt
@@ -69,9 +69,9 @@ jobs:
with:
path: sglang-repo
- - name: Move sglang-router folder to root, copy the license file, and delete sglang-repo
+ - name: Move sgl-router folder to root, copy the license file, and delete sglang-repo
run: |
- mv sglang-repo/sglang-router/* .
+ mv sglang-repo/sgl-router/* .
mv sglang-repo/LICENSE .
rm -rf sglang-repo
ls -alt
From 5a176c92dfa13183deca012fe4c43d9d75815390 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sun, 19 Jan 2025 21:33:27 +0800
Subject: [PATCH 009/147] fix deepseek v2 with cpu device (#2975)
---
python/sglang/srt/layers/rotary_embedding.py | 114 ++++++++++++++++++-
python/sglang/srt/models/deepseek_v2.py | 4 +-
python/sglang/srt/models/minicpmv.py | 2 +-
python/sglang/srt/models/olmo2.py | 0
4 files changed, 115 insertions(+), 5 deletions(-)
mode change 100755 => 100644 python/sglang/srt/models/olmo2.py
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 7c18c683e96..bc38fa8c0f9 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -664,6 +664,7 @@ def __init__(
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
+ device: Optional[str] = "cuda",
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
@@ -676,13 +677,14 @@ def __init__(
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor
)
+ self.device = device
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base ** (
- torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda")
+ torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
/ self.rotary_dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
@@ -710,7 +712,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(
self.max_position_embeddings * self.scaling_factor,
- device="cuda",
+ device=self.device,
dtype=torch.float32,
)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
@@ -1174,3 +1176,111 @@ def get_rope(
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
+
+
+def get_rope_cpu(
+ head_size: int,
+ rotary_dim: int,
+ max_position: int,
+ base: int,
+ is_neox_style: bool = True,
+ rope_scaling: Optional[Dict[str, Any]] = None,
+ dtype: Optional[torch.dtype] = None,
+ partial_rotary_factor: float = 1.0,
+ device: Optional[str] = None,
+) -> RotaryEmbedding:
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if rope_scaling is not None:
+ # Transforms every value that is a list into a tuple for caching calls
+ rope_scaling_tuple = {
+ k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
+ }
+ rope_scaling_args = tuple(rope_scaling_tuple.items())
+ else:
+ rope_scaling_args = None
+ if partial_rotary_factor < 1.0:
+ rotary_dim = int(rotary_dim * partial_rotary_factor)
+ key = (
+ head_size,
+ rotary_dim,
+ max_position,
+ base,
+ is_neox_style,
+ rope_scaling_args,
+ dtype,
+ )
+ if key in _ROPE_DICT:
+ return _ROPE_DICT[key]
+
+ assert rope_scaling is not None
+ scaling_type = rope_scaling["rope_type"]
+ assert (
+ scaling_type == "deepseek_yarn"
+ ), "Only deepseek_yarn is supported for CPU for now"
+
+ scaling_factor = rope_scaling["factor"]
+ original_max_position = rope_scaling["original_max_position_embeddings"]
+ extra_kwargs = {
+ k: v
+ for k, v in rope_scaling.items()
+ if k
+ in (
+ "extrapolation_factor",
+ "attn_factor",
+ "beta_fast",
+ "beta_slow",
+ "mscale",
+ "mscale_all_dim",
+ )
+ }
+ extra_kwargs["device"] = device
+ rotary_emb = DeepseekScalingRotaryEmbedding(
+ head_size,
+ rotary_dim,
+ original_max_position,
+ base,
+ is_neox_style,
+ scaling_factor,
+ dtype,
+ **extra_kwargs,
+ )
+
+ _ROPE_DICT[key] = rotary_emb
+ return rotary_emb
+
+
+def get_rope_wrapper(
+ head_size: int,
+ rotary_dim: int,
+ max_position: int,
+ base: int,
+ is_neox_style: bool = True,
+ rope_scaling: Optional[Dict[str, Any]] = None,
+ dtype: Optional[torch.dtype] = None,
+ partial_rotary_factor: float = 1.0,
+ device: Optional[str] = None,
+):
+ if device != "cpu":
+ return get_rope(
+ head_size,
+ rotary_dim,
+ max_position,
+ base,
+ is_neox_style,
+ rope_scaling,
+ dtype,
+ partial_rotary_factor,
+ )
+
+ return get_rope_cpu(
+ head_size,
+ rotary_dim,
+ max_position,
+ base,
+ is_neox_style,
+ rope_scaling,
+ dtype,
+ partial_rotary_factor,
+ device,
+ )
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 0d327c0ca97..17d7fcf8924 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -48,7 +48,7 @@
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
@@ -271,7 +271,7 @@ def __init__(
quant_config=quant_config,
)
rope_scaling["rope_type"] = "deepseek_yarn"
- self.rotary_emb = get_rope(
+ self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py
index 5ff941b6c27..23147529a64 100644
--- a/python/sglang/srt/models/minicpmv.py
+++ b/python/sglang/srt/models/minicpmv.py
@@ -39,12 +39,12 @@
from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
+from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import (
diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py
old mode 100755
new mode 100644
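A sketch of calling the new wrapper for a CPU deployment; the head size and `rope_scaling` values below are illustrative, loosely patterned after DeepSeek-V2's `deepseek_yarn` configuration:

```python
import torch

from sglang.srt.layers.rotary_embedding import get_rope_wrapper

rotary_emb = get_rope_wrapper(
    head_size=64,
    rotary_dim=64,
    max_position=4096,
    base=10000,
    is_neox_style=False,
    rope_scaling={
        "rope_type": "deepseek_yarn",  # only type supported on CPU for now
        "factor": 40,
        "original_max_position_embeddings": 4096,
        "beta_fast": 32,
        "beta_slow": 1,
        "mscale": 1.0,
        "mscale_all_dim": 1.0,
    },
    dtype=torch.bfloat16,
    device="cpu",  # "cpu" routes to get_rope_cpu; any other device uses get_rope
)
```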
From 24cafe317746a1051ae965925eeaab539049a09f Mon Sep 17 00:00:00 2001
From: yizhang2077 <1109276519@qq.com>
Date: Sun, 19 Jan 2025 22:30:38 +0800
Subject: [PATCH 010/147] add config to switch from vllm custom allreduce to
sgl_kernel custom allreduce (#2981)
---
python/sglang/srt/_custom_ops.py | 115 +++++++++++------
.../device_communicators/custom_all_reduce.py | 117 ++++++++++++------
2 files changed, 160 insertions(+), 72 deletions(-)
diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index f59f67605b3..3c00a8552ff 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -3,6 +3,7 @@
import functools
import importlib
import logging
+import os
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
@@ -11,12 +12,19 @@
from sglang.srt.utils import is_hpu
logger = logging.getLogger(__name__)
+use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
if not is_hpu():
- try:
- import sgl_kernel
- except ImportError as e:
- logger.warning("Failed to import from custom_ar with %r", e)
+ if use_vllm_custom_allreduce:
+ try:
+ import vllm._C
+ except ImportError as e:
+ logger.warning("Failed to import from vllm._C with %r", e)
+ else:
+ try:
+ import sgl_kernel
+ except ImportError as e:
+ logger.warning("Failed to import from custom_ar with %r", e)
def hint_on_error(fn):
@@ -48,43 +56,78 @@ def wrapper(*args, **kwargs):
return wrapper
-# custom ar
-def init_custom_ar(
- rank_id: int,
- world_size: int,
- rank_data_base: torch.Tensor,
- buffers: List[int],
- tmp_result_buffers: List[int],
- barrier_in: List[int],
- barrier_out: List[int],
-) -> int:
- return sgl_kernel.ops.init_custom_reduce(
- rank_id,
- world_size,
- rank_data_base,
- buffers,
- tmp_result_buffers,
- barrier_in,
- barrier_out,
- )
-
-
-def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
- sgl_kernel.ops.custom_reduce(fa, inp, out)
-
+if use_vllm_custom_allreduce:
+ # custom ar
+ def init_custom_ar(
+ ipc_tensors: List[torch.Tensor],
+ rank_data: torch.Tensor,
+ rank: int,
+ full_nvlink: bool,
+ ) -> int:
+ return torch.ops._C_custom_ar.init_custom_ar(
+ ipc_tensors, rank_data, rank, full_nvlink
+ )
-def dispose(fa: int) -> None:
- sgl_kernel.ops.custom_dispose(fa)
+ def all_reduce(
+ fa: int,
+ inp: torch.Tensor,
+ out: torch.Tensor,
+ reg_buffer: int,
+ reg_buffer_sz_bytes: int,
+ ) -> None:
+ torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+
+ def dispose(fa: int) -> None:
+ torch.ops._C_custom_ar.dispose(fa)
+
+ def meta_size() -> int:
+ return torch.ops._C_custom_ar.meta_size()
+
+ def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
+ return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
+
+ def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+ return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+
+ def register_graph_buffers(
+ fa: int, handles: List[List[int]], offsets: List[List[int]]
+ ) -> None:
+ torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+
+else:
+ # custom ar
+ def init_custom_ar(
+ rank_id: int,
+ world_size: int,
+ rank_data_base: torch.Tensor,
+ buffers: List[int],
+ tmp_result_buffers: List[int],
+ barrier_in: List[int],
+ barrier_out: List[int],
+ ) -> int:
+ return sgl_kernel.ops.init_custom_reduce(
+ rank_id,
+ world_size,
+ rank_data_base,
+ buffers,
+ tmp_result_buffers,
+ barrier_in,
+ barrier_out,
+ )
+ def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+ sgl_kernel.ops.custom_reduce(fa, inp, out)
-def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
- return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
+ def dispose(fa: int) -> None:
+ sgl_kernel.ops.custom_dispose(fa)
+ def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+ return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
-def register_graph_buffers(
- fa: int, handles: List[List[int]], offsets: List[List[int]]
-) -> None:
- sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
+ def register_graph_buffers(
+ fa: int, handles: List[List[int]], offsets: List[List[int]]
+ ) -> None:
+ sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index ba9feb59d0c..28aa9d4811e 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -21,8 +21,10 @@
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
try:
- import sgl_kernel
-
+ if ops.use_vllm_custom_allreduce:
+ ops.meta_size()
+ else:
+ import sgl_kernel
custom_ar = True
except Exception:
# For AMD GPUs and CPUs
@@ -201,33 +203,58 @@ def __init__(
self.world_size = world_size
self.full_nvlink = full_nvlink
- # From TensorRT-LLM getMaxRequiredWorkspaceSize
- self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
-
- # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
- self.barrier_max_size = 8 * (36 + 2) * 8
-
- self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
- self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group)
- self.rank_data_base = torch.empty(
- 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
- )
- self.barrier_in_ptrs = self.create_shared_buffer(
- self.barrier_max_size, group=group
- )
- self.barrier_out_ptrs = self.create_shared_buffer(
- self.barrier_max_size, group=group
- )
-
- self._ptr = ops.init_custom_ar(
- rank,
- world_size,
- self.rank_data_base,
- self.buffer_ptrs,
- self.tmp_result_buffer_ptrs,
- self.barrier_in_ptrs,
- self.barrier_out_ptrs,
- )
+ if ops.use_vllm_custom_allreduce:
+ # Buffer memory is owned by this Python class and passed to C++.
+ # Metadata consists of two parts: metadata for synchronization and a
+ # temporary buffer for storing intermediate allreduce results.
+ self.meta_ptrs = self.create_shared_buffer(
+ ops.meta_size() + max_size, group=group
+ )
+ # This is a pre-registered IPC buffer. In eager mode, input tensors
+ # are first copied into this buffer before allreduce is performed
+ self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+ # This is a buffer for storing the tuples of pointers pointing to
+ # IPC buffers from all ranks. Each registered tuple has size of
+ # 8*world_size bytes where world_size is at most 8. Allocating 8MB
+ # is enough for 131072 such tuples. The largest model I've seen only
+ # needs fewer than 10000 registered tuples.
+ self.rank_data = torch.empty(
+ 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+ )
+ self._ptr = ops.init_custom_ar(
+ self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+ )
+ ops.register_buffer(self._ptr, self.buffer_ptrs)
+ else:
+ # From TensorRT-LLM getMaxRequiredWorkspaceSize
+ self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+
+ # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+ self.barrier_max_size = 8 * (36 + 2) * 8
+
+ self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+ self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+ max_size, group=group
+ )
+ self.rank_data_base = torch.empty(
+ 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+ )
+ self.barrier_in_ptrs = self.create_shared_buffer(
+ self.barrier_max_size, group=group
+ )
+ self.barrier_out_ptrs = self.create_shared_buffer(
+ self.barrier_max_size, group=group
+ )
+
+ self._ptr = ops.init_custom_ar(
+ rank,
+ world_size,
+ self.rank_data_base,
+ self.buffer_ptrs,
+ self.tmp_result_buffer_ptrs,
+ self.barrier_in_ptrs,
+ self.barrier_out_ptrs,
+ )
self.disabled = False
@staticmethod
@@ -307,6 +334,11 @@ def should_custom_ar(self, inp: torch.Tensor):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
+ if ops.use_vllm_custom_allreduce:
+ if self.world_size == 2 or self.full_nvlink:
+ return inp_size < self.max_size
+ return False
+
if self.world_size == 2:
return (
inp_size < self.max_size
@@ -326,6 +358,7 @@ def all_reduce(
inp: torch.Tensor,
*,
out: torch.Tensor = None,
+ registered: bool = False,
):
"""Performs an out-of-place all reduce.
@@ -335,7 +368,15 @@ def all_reduce(
"""
if out is None:
out = torch.empty_like(inp)
- ops.all_reduce(self._ptr, inp, out)
+ if ops.use_vllm_custom_allreduce:
+ if registered:
+ ops.all_reduce(self._ptr, inp, out, 0, 0)
+ else:
+ ops.all_reduce(
+ self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+ )
+ else:
+ ops.all_reduce(self._ptr, inp, out)
return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@@ -345,21 +386,25 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
- return self.all_reduce(input)
+ return self.all_reduce(input, registered=True)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
- return self.all_reduce(input)
+ return self.all_reduce(input, registered=False)
def close(self):
if not self.disabled and self._ptr:
ops.dispose(self._ptr)
- self.free_shared_buffer(self.buffer_ptrs)
- self.free_shared_buffer(self.tmp_result_buffer_ptrs)
- self.free_shared_buffer(self.barrier_in_ptrs)
- self.free_shared_buffer(self.barrier_out_ptrs)
+ if ops.use_vllm_custom_allreduce:
+ self.free_shared_buffer(self.meta_ptrs)
+ self.free_shared_buffer(self.buffer_ptrs)
+ else:
+ self.free_shared_buffer(self.buffer_ptrs)
+ self.free_shared_buffer(self.tmp_result_buffer_ptrs)
+ self.free_shared_buffer(self.barrier_in_ptrs)
+ self.free_shared_buffer(self.barrier_out_ptrs)
self._ptr = 0
def __del__(self):
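The backend is chosen once, when `sglang.srt._custom_ops` is imported, based on the environment variable; a minimal sketch (note that `os.environ.get` returns the raw string here, so any non-empty value, even "0", selects the vLLM path):

```python
import os

# Must be set before sglang.srt._custom_ops is first imported.
os.environ["USE_VLLM_CUSTOM_ALLREDUCE"] = "1"

from sglang.srt import _custom_ops as ops

print(ops.use_vllm_custom_allreduce)  # "1" (truthy) -> vLLM custom allreduce kernels
```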
From 6ada05d0ed52f099ec8ffb49c7f7aa7efc31cd49 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sun, 19 Jan 2025 23:33:04 +0800
Subject: [PATCH 011/147] feat: check for is_cuda for sgl_kernel import (#2984)
---
.../layers/moe/fused_moe_triton/fused_moe.py | 20 +++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index 01ecce1a6ed..c0d55808558 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -15,18 +15,18 @@
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+from sglang.srt.utils import (
+ direct_register_custom_op,
+ get_device_name,
+ is_cuda_available,
+ is_hip,
+)
-is_hip_flag = False
-if not is_hip():
- if torch.cuda.is_available():
- from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
- else:
- sgl_moe_align_block_size = None
+is_cuda = is_cuda_available()
+is_hip_flag = is_hip()
+if is_cuda:
+ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
- is_hip_flag = False
-else:
- is_hip_flag = True
logger = logging.getLogger(__name__)
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
From 3fc2b625891029bf6207186098e0450e85c0c638 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sun, 19 Jan 2025 23:45:39 +0800
Subject: [PATCH 012/147] update docker dev image (#2985)
---
docker/Dockerfile.dev | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev
index 20a373184b8..9d05ee5997e 100644
--- a/docker/Dockerfile.dev
+++ b/docker/Dockerfile.dev
@@ -20,6 +20,7 @@ RUN apt-get update && apt-get install -y \
unzip \
pkg-config \
libssl-dev \
+ bear \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean
@@ -41,7 +42,8 @@ RUN python3 -m pip install --no-cache-dir \
pytest \
black \
isort \
- icdiff
+ icdiff \
+ pre-commit
# Install diff-so-fancy
RUN curl -LSso /usr/local/bin/diff-so-fancy https://github.com/so-fancy/diff-so-fancy/releases/download/v1.4.4/diff-so-fancy \
From def5c31873d9a667ab375ef13f5a77a0e5493e25 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Mon, 20 Jan 2025 00:44:30 +0800
Subject: [PATCH 013/147] docs: update supported_models (#2987)
---
docs/references/supported_models.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md
index 860841816e0..23c98ea9305 100644
--- a/docs/references/supported_models.md
+++ b/docs/references/supported_models.md
@@ -81,6 +81,7 @@ To port a model from vLLM to SGLang, you can compare these two files [SGLang Lla
- Remove `Sample`.
- Change `forward()` functions, and add `forward_batch`.
- Add `EntryClass` at the end.
+ - Please ensure the new implementation uses **only SGLang components and does not rely on any vLLM components**.
### Registering an external model implementation
From a69cb5cff7389fb6ce1b4c45c52b6796e78ce0f3 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Mon, 20 Jan 2025 00:44:49 +0800
Subject: [PATCH 014/147] cleanup unused header in sgl_kernel (#2986)
---
.../epilogue/epilogue_per_row_per_col_scale.h | 7 ++-----
.../gemm/gemm_universal_base_compat.h | 13 +++----------
.../gemm/gemm_with_epilogue_visitor.h | 13 +++++--------
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu | 2 ++
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu | 3 ---
5 files changed, 12 insertions(+), 26 deletions(-)
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
index a9deeb9a7da..c83cf49ad83 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
+++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
@@ -3,11 +3,8 @@
#pragma once
-#include "cutlass/arch/memory.h"
-#include "cutlass/arch/memory_sm75.h"
-#include "cutlass/cutlass.h"
-#include "cutlass/fast_math.h"
-#include "cutlass/numeric_conversion.h"
+#include
+#include
namespace cutlass {
namespace epilogue {
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
index 10be552a8ec..33e82decc2b 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
+++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h
@@ -2,16 +2,9 @@
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
#pragma once
-#include "cutlass/arch/arch.h"
-#include "cutlass/cutlass.h"
-#include "cutlass/device_kernel.h"
-#include "cutlass/gemm/device/default_gemm_configuration.h"
-#include "cutlass/gemm/gemm.h"
-#include "cutlass/gemm/kernel/default_gemm_universal.h"
-#include "cutlass/gemm/kernel/gemm_universal.h"
-#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
-#include "cutlass/numeric_types.h"
-#include "cutlass/trace.h"
+#include
+#include
+#include
////////////////////////////////////////////////////////////////////////////////
diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
index cf0b9cfa3e9..674e191a077 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
+++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h
@@ -3,14 +3,11 @@
#pragma once
-#include "cutlass/complex.h"
-#include "cutlass/cutlass.h"
-#include "cutlass/fast_math.h"
-#include "cutlass/gemm/gemm.h"
-#include "cutlass/matrix_coord.h"
-#include "cutlass/semaphore.h"
-#include "cutlass/trace.h"
-#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
+#include
+#include
+#include
+#include
+#include
/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index b9879b114fe..99d0326cf07 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -1,3 +1,5 @@
+#include
+
#include "utils.hpp"
// trt_reduce
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
index d80beedec82..d647c349602 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu
@@ -3,9 +3,6 @@
#include
#include
-#include
-#include
-#include
#include "trt_reduce_internal.cuh"
From 8b6a4486ecbf83c915c6a9d3c727d188a22f455e Mon Sep 17 00:00:00 2001
From: giorgiopiatti-dfinity
Date: Sun, 19 Jan 2025 20:36:07 +0100
Subject: [PATCH 015/147] fix missing revision arg when loading tokenizer
(#2982)
---
python/sglang/srt/managers/detokenizer_manager.py | 1 +
python/sglang/srt/managers/scheduler.py | 2 ++
python/sglang/srt/managers/tokenizer_manager.py | 2 ++
python/sglang/srt/managers/tp_worker.py | 2 ++
python/sglang/srt/server.py | 1 +
5 files changed, 8 insertions(+)
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index f0605ee1fea..a8dc14f0103 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -71,6 +71,7 @@ def __init__(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
)
self.decode_status = LimitedCapacityDict()
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index d859a30a038..5df9c24cee1 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -206,6 +206,7 @@ def __init__(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
)
self.tokenizer = self.processor.tokenizer
else:
@@ -213,6 +214,7 @@ def __init__(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
)
# Check whether overlap can be enabled
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 033a660df5e..9cf6d9cc556 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -158,6 +158,7 @@ def __init__(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
)
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -171,6 +172,7 @@ def __init__(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
)
# Store states
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 47e3eea4084..fd4dbae9900 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -83,6 +83,7 @@ def __init__(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
)
self.tokenizer = self.processor.tokenizer
else:
@@ -90,6 +91,7 @@ def __init__(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
+ revision=server_args.revision,
)
self.device = self.model_runner.device
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index b3526520cd1..a2c1cb375dc 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -1027,6 +1027,7 @@ def get_tokenizer(self):
self.server_args.tokenizer_path,
tokenizer_mode=self.server_args.tokenizer_mode,
trust_remote_code=self.server_args.trust_remote_code,
+ revision=self.server_args.revision,
)
async def async_generate(
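Usage note: all five call sites now forward `server_args.revision`, so pinning a Hugging Face revision works end to end. A minimal sketch of the tokenizer-side call (the model path and revision below are placeholders, not taken from the patch):

from sglang.srt.hf_transformers_utils import get_tokenizer

tokenizer = get_tokenizer(
    "org/model-name",              # placeholder tokenizer path
    tokenizer_mode="auto",
    trust_remote_code=False,
    revision="main",               # the argument this patch now forwards
)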
From d77caa2b757044f84e0078336b43de531cdd5688 Mon Sep 17 00:00:00 2001
From: Seungduk Kim
Date: Mon, 20 Jan 2025 04:36:53 +0900
Subject: [PATCH 016/147] [#2812] Make the decode status dict capacity
adjustable by a CLI param (#2839)
---
.../srt/managers/detokenizer_manager.py | 23 ++++++++++++++++---
1 file changed, 20 insertions(+), 3 deletions(-)
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index a8dc14f0103..972f9595b2c 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -15,6 +15,7 @@
import dataclasses
import logging
+import os
import signal
from collections import OrderedDict
from typing import Dict, List, Union
@@ -35,6 +36,12 @@
logger = logging.getLogger(__name__)
+# Maximum number of request states that the detokenizer can hold. When exceeded,
+# the oldest request states will be evicted. Default: 65536 (1<<16).
+# For more details, see: https://github.com/sgl-project/sglang/issues/2812
+# Use power of 2 values for better memory allocation.
+DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16))
+
@dataclasses.dataclass
class DecodeStatus:
@@ -74,7 +81,7 @@ def __init__(
revision=server_args.revision,
)
- self.decode_status = LimitedCapacityDict()
+ self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES)
def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
@@ -156,7 +163,17 @@ def event_loop(self):
# Incremental decoding
output_strs = []
for i in range(bs):
- s = self.decode_status[recv_obj.rids[i]]
+ try:
+ s = self.decode_status[recv_obj.rids[i]]
+ except KeyError:
+ raise RuntimeError(
+ f"Decode status not found for request {recv_obj.rids[i]}. "
+ "It may be due to the request being evicted from the decode status due to memory pressure. "
+ "Please increase the maximum number of requests by setting "
+ "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. "
+ f"The current value is {DETOKENIZER_MAX_STATES}. "
+ "For more details, see: https://github.com/sgl-project/sglang/issues/2812"
+ )
new_text = read_texts[i][len(surr_texts[i]) :]
if recv_obj.finished_reasons[i] is None:
# Streaming chunk: update the decode status
@@ -197,7 +214,7 @@ def event_loop(self):
class LimitedCapacityDict(OrderedDict):
- def __init__(self, capacity=1 << 15, *args, **kwargs):
+ def __init__(self, capacity: int, *args, **kwargs):
super().__init__(*args, **kwargs)
self.capacity = capacity
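Usage note: the capacity can now be raised without code changes through the environment variable read at import time. A minimal sketch, assuming it is set before the detokenizer module is imported:

import os

# Must be set before sglang.srt.managers.detokenizer_manager is imported,
# because DETOKENIZER_MAX_STATES is read at module import time.
os.environ["SGLANG_DETOKENIZER_MAX_STATES"] = str(1 << 18)  # 262144 instead of the 65536 default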
From 2c05f81f157fdd5e532baea78bb0121a0ba2c1a0 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Mon, 20 Jan 2025 04:21:29 +0800
Subject: [PATCH 017/147] fix custom op version compatibility (#2988)
---
python/pyproject.toml | 2 +-
python/sglang/srt/layers/rotary_embedding.py | 4 +++-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 379a4c9acf8..f1fcc4679d8 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -27,7 +27,7 @@ runtime_common = [
]
srt = [
"sglang[runtime_common]", "cuda-python",
- "sgl-kernel>=0.0.2.post14", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
+ "sgl-kernel>=0.0.2.post14", "torch", "vllm==0.6.4.post1",
"flashinfer==0.1.6"
]
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index bc38fa8c0f9..964152905be 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -8,6 +8,8 @@
import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.layers.custom_op_util import register_custom_op
+
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
@@ -51,7 +53,7 @@ def _apply_rotary_emb(
return torch.stack((o1, o2), dim=-1).flatten(-2)
-@CustomOp.register("rotary_embedding")
+@register_custom_op("sglang_rotary_embedding")
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""
From 3bcf5ecea7a1a9a4f34606c739230b99d453d09b Mon Sep 17 00:00:00 2001
From: Enrique Shockwave <33002121+qeternity@users.noreply.github.com>
Date: Sun, 19 Jan 2025 20:34:41 +0000
Subject: [PATCH 018/147] support regex in xgrammar backend (#2983)
---
docs/backend/openai_api_completions.ipynb | 2 +-
docs/backend/structured_outputs.ipynb | 3 +-
docs/references/sampling_params.md | 2 +-
python/pyproject.toml | 2 +-
.../srt/constrained/xgrammar_backend.py | 12 +-
test/srt/run_suite.py | 1 +
test/srt/test_regex_constrained.py | 186 ++++++++++++++++++
7 files changed, 200 insertions(+), 8 deletions(-)
create mode 100644 test/srt/test_regex_constrained.py
diff --git a/docs/backend/openai_api_completions.ipynb b/docs/backend/openai_api_completions.ipynb
index 8660da2f98f..42cdbb11210 100644
--- a/docs/backend/openai_api_completions.ipynb
+++ b/docs/backend/openai_api_completions.ipynb
@@ -219,7 +219,7 @@
"SGLang supports two grammar backends:\n",
"\n",
"- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
- "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints.\n",
+ "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
" - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n",
"\n",
"Initialize the XGrammar backend using `--grammar-backend xgrammar` flag\n",
diff --git a/docs/backend/structured_outputs.ipynb b/docs/backend/structured_outputs.ipynb
index 55ca0b627f9..a5e6f2335b5 100644
--- a/docs/backend/structured_outputs.ipynb
+++ b/docs/backend/structured_outputs.ipynb
@@ -16,7 +16,8 @@
"SGLang supports two grammar backends:\n",
"\n",
"- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
- "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints and currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md).\n",
+ "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
+ " - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n",
"\n",
"We suggest using XGrammar whenever possible for its better performance. For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n",
"\n",
diff --git a/docs/references/sampling_params.md b/docs/references/sampling_params.md
index 5dad3fd1259..cdc53da61a4 100644
--- a/docs/references/sampling_params.md
+++ b/docs/references/sampling_params.md
@@ -189,7 +189,7 @@ You can specify a JSON schema, regular expression or [EBNF](https://en.wikipedia
SGLang supports two grammar backends:
- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.
-- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints.
+- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.
- XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)
Initialize the XGrammar backend using `--grammar-backend xgrammar` flag
diff --git a/python/pyproject.toml b/python/pyproject.toml
index f1fcc4679d8..f97c9c26679 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -23,7 +23,7 @@ runtime_common = [
"packaging", "pillow", "prometheus-client>=0.20.0",
"psutil", "pydantic", "python-multipart",
"pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
- "xgrammar>=0.1.6"
+ "xgrammar>=0.1.10"
]
srt = [
"sglang[runtime_common]", "cuda-python",
diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py
index b0b2c31c2ac..c423a567eda 100644
--- a/python/sglang/srt/constrained/xgrammar_backend.py
+++ b/python/sglang/srt/constrained/xgrammar_backend.py
@@ -19,6 +19,7 @@
import torch
from xgrammar import (
CompiledGrammar,
+ Grammar,
GrammarCompiler,
GrammarMatcher,
TokenizerInfo,
@@ -133,10 +134,13 @@ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
return None
elif key_type == "regex":
- logger.warning(
- "regex hasn't been supported by xgrammar yet. This is skipped."
- )
- return None
+ try:
+ ctx = self.grammar_compiler.compile_grammar(
+ Grammar.from_regex(key_string)
+ )
+ except RuntimeError as e:
+ logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
+ return None
else:
raise ValueError(f"Invalid key_type: {key_type}")
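Usage note: with xgrammar >= 0.1.10 a regular expression compiles directly into a grammar, which is what the hunk above relies on. A minimal sketch of that compile path, assuming `grammar_compiler` is a `GrammarCompiler` built from the model's `TokenizerInfo` as elsewhere in this backend:

from xgrammar import Grammar

# `grammar_compiler` is assumed to be an existing xgrammar.GrammarCompiler instance.
compiled = grammar_compiler.compile_grammar(Grammar.from_regex(r"^\d{3}$"))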
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index fb1c6abf29b..2ed2522755a 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -31,6 +31,7 @@
"test_openai_server.py",
"test_pytorch_sampling_backend.py",
"test_radix_attention.py",
+ "test_regex_constrained.py",
"test_release_memory_occupation.py",
"test_request_length_validation.py",
"test_retract_decode.py",
diff --git a/test/srt/test_regex_constrained.py b/test/srt/test_regex_constrained.py
new file mode 100644
index 00000000000..6d5acec15e2
--- /dev/null
+++ b/test/srt/test_regex_constrained.py
@@ -0,0 +1,186 @@
+"""
+python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email
+python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting
+"""
+
+import json
+import unittest
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+ DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ DEFAULT_URL_FOR_TEST,
+ popen_launch_server,
+)
+
+
+def setup_class(cls, disable_overlap: bool):
+ cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+ cls.base_url = DEFAULT_URL_FOR_TEST
+
+ other_args = [
+ "--max-running-requests",
+ "10",
+ "--grammar-backend",
+ "xgrammar",
+ ]
+
+ if disable_overlap:
+ other_args += ["--disable-overlap-schedule"]
+
+ cls.process = popen_launch_server(
+ cls.model,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=other_args,
+ )
+
+
+class TestRegexConstrained(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ setup_class(cls, disable_overlap=False)
+ cls.check_jump_forward = False
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_process_tree(cls.process.pid)
+
+ def run_decode(
+ self,
+ regex,
+ prompt,
+ return_logprob=False,
+ top_logprobs_num=0,
+ n=1,
+ ):
+ response = requests.post(
+ self.base_url + "/generate",
+ json={
+ "text": prompt,
+ "sampling_params": {
+ "temperature": 0 if n == 1 else 0.5,
+ "max_new_tokens": 128,
+ "n": n,
+ "regex": regex,
+ },
+ "stream": False,
+ "return_logprob": return_logprob,
+ "top_logprobs_num": top_logprobs_num,
+ "logprob_start_len": 0,
+ },
+ )
+
+ ret = response.json()
+ print(json.dumps(ret, indent=2))
+ print("=" * 100)
+
+ if not isinstance(ret, list):
+ self.fail(f"Expected response to be a list, but got {type(ret)}")
+
+ for item in ret:
+ text = item.get("text", "").strip()
+ if not text:
+ self.fail("Generated text is empty.")
+
+ if not self.regex_match(text, regex):
+ self.fail(f"Text '{text}' does not match regex pattern.")
+
+ def regex_match(self, text, pattern):
+ import re
+
+ return re.match(pattern, text) is not None
+
+ def test_regex_generate_email(self):
+ pattern = r"^user@example\.com$"
+ prompt = "Generate an email address:"
+
+ self.run_decode(
+ regex=pattern,
+ prompt=prompt,
+ n=3,
+ )
+
+ def test_regex_generate_greeting(self):
+ pattern = r"^(Hello|Hi|Hey)$"
+ prompt = "Generate a greeting:"
+
+ self.run_decode(
+ regex=pattern,
+ prompt=prompt,
+ n=3,
+ )
+
+ def test_regex_generate_number(self):
+ pattern = r"^\d{3}$"
+ prompt = "Generate a three-digit number:"
+
+ self.run_decode(
+ regex=pattern,
+ prompt=prompt,
+ n=3,
+ )
+
+ def test_regex_generate_phone(self):
+ pattern = r"^\(\d{3}\) \d{3}-\d{4}$"
+ prompt = "Generate a phone number:"
+
+ self.run_decode(
+ regex=pattern,
+ prompt=prompt,
+ n=3,
+ )
+
+ def test_regex_generate_date(self):
+ pattern = r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"
+ prompt = "Generate a date in YYYY-MM-DD format:"
+
+ self.run_decode(
+ regex=pattern,
+ prompt=prompt,
+ n=3,
+ )
+
+ def test_regex_generate_hex_color(self):
+ pattern = r"^#[0-9A-F]{6}$"
+ prompt = "Generate a hex color code:"
+
+ self.run_decode(
+ regex=pattern,
+ prompt=prompt,
+ n=3,
+ )
+
+ def test_regex_generate_complex_json(self):
+ pattern = r'^\{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*\}$'
+ prompt = "Generate a simple JSON with name, age, and city:"
+
+ self.run_decode(
+ regex=pattern,
+ prompt=prompt,
+ n=3,
+ )
+
+ def test_regex_generate_custom_log_format(self):
+ pattern = r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$"
+ prompt = "Generate a log entry:"
+
+ self.run_decode(
+ regex=pattern,
+ prompt=prompt,
+ n=3,
+ )
+
+
+class TestJumpForward(TestRegexConstrained):
+ @classmethod
+ def setUpClass(cls):
+ setup_class(cls, disable_overlap=True)
+ cls.check_jump_forward = True
+
+
+if __name__ == "__main__":
+ unittest.main()
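Usage note: a client-side sketch of the feature the tests above exercise; the server is assumed to be running locally with `--grammar-backend xgrammar`, and the URL and prompt are placeholders:

import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",  # placeholder local server URL
    json={
        "text": "Generate a hex color code:",
        "sampling_params": {"max_new_tokens": 16, "regex": r"^#[0-9A-F]{6}$"},
    },
)
print(resp.json()["text"])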
From e403d2375719a79c4b9e1e998474aa1ee3384399 Mon Sep 17 00:00:00 2001
From: Hongpeng Guo
Date: Sun, 19 Jan 2025 14:46:53 -0800
Subject: [PATCH 019/147] [Feature] Add sampler custom logits processor (#2396)
Signed-off-by: Hongpeng Guo
---
python/sglang/srt/layers/sampler.py | 30 ++++-
python/sglang/srt/managers/io_struct.py | 19 +++
python/sglang/srt/managers/schedule_batch.py | 2 +
python/sglang/srt/managers/scheduler.py | 14 ++
.../sglang/srt/managers/session_controller.py | 1 +
.../sglang/srt/managers/tokenizer_manager.py | 1 +
.../srt/sampling/custom_logit_processor.py | 38 ++++++
.../srt/sampling/sampling_batch_info.py | 122 +++++++++++++++++-
python/sglang/srt/sampling/sampling_params.py | 4 +-
python/sglang/srt/server.py | 4 +
python/sglang/srt/server_args.py | 8 ++
test/srt/test_srt_endpoint.py | 63 ++++++++-
12 files changed, 302 insertions(+), 4 deletions(-)
create mode 100644 python/sglang/srt/sampling/custom_logit_processor.py
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 23037650a31..e8b25da0704 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -1,11 +1,12 @@
import logging
-from typing import List
+from typing import Dict, List
import torch
from torch import nn
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
@@ -35,6 +36,10 @@ def forward(
):
logits = logits_output.next_token_logits
+ # Apply the custom logit processors if registered in the sampling info.
+ if sampling_info.has_custom_logit_processor:
+ self._apply_custom_logit_processor(logits, sampling_info)
+
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where(
@@ -121,6 +126,29 @@ def forward(
return batch_next_token_ids
+ def _apply_custom_logit_processor(
+ self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
+ ):
+ """Apply custom logit processors to the logits.
+ This function will modify the logits in-place."""
+
+ for _, (
+ processor,
+ batch_mask,
+ ) in sampling_batch_info.custom_logit_processor.items():
+ # Get the batch indices that need to be processed
+ batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+
+ # Apply the processor to the logits
+ logits[batch_mask] = processor(
+ logits[batch_mask],
+ [sampling_batch_info.custom_params[i] for i in batch_indices],
+ )
+
+ logger.debug(
+ f"Custom logit processor {processor.__class__.__name__} is applied."
+ )
+
def top_k_top_p_min_p_sampling_from_probs_torch(
probs: torch.Tensor,
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index c5a35ced00c..5a803dd997a 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -22,6 +22,7 @@
from typing import Dict, List, Optional, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -69,6 +70,8 @@ class GenerateReqInput:
# Session info for continual prompting
session_params: Optional[Union[List[Dict], Dict]] = None
+ # Custom logit processor (serialized function)
+ custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None
def normalize_batch_and_arguments(self):
if (
@@ -183,6 +186,13 @@ def normalize_batch_and_arguments(self):
else:
assert self.parallel_sample_num == 1
+ if self.custom_logit_processor is None:
+ self.custom_logit_processor = [None] * num
+ elif not isinstance(self.custom_logit_processor, list):
+ self.custom_logit_processor = [self.custom_logit_processor] * num
+ else:
+ assert self.parallel_sample_num == 1
+
def regenerate_rid(self):
self.rid = uuid.uuid4().hex
return self.rid
@@ -202,6 +212,11 @@ def __getitem__(self, i):
log_metrics=self.log_metrics,
modalities=self.modalities[i] if self.modalities else None,
lora_path=self.lora_path[i] if self.lora_path is not None else None,
+ custom_logit_processor=(
+ self.custom_logit_processor[i]
+ if self.custom_logit_processor is not None
+ else None
+ ),
)
@@ -234,6 +249,10 @@ class TokenizedGenerateReqInput:
# Session info for continual prompting
session_params: Optional[SessionParams] = None
+ # Custom logit processor (serialized function)
+ # TODO (hpguo): Add an example and update doc string here
+ custom_logit_processor: Optional[str] = None
+
@dataclass
class EmbeddingReqInput:
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index afbc98b7ca9..a09810a3871 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -232,6 +232,7 @@ def __init__(
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None,
+ custom_logit_processor: Optional[str] = None,
eos_token_ids: Optional[Set[int]] = None,
):
# Input and output info
@@ -252,6 +253,7 @@ def __init__(
# Sampling info
self.sampling_params = sampling_params
self.lora_path = lora_path
+ self.custom_logit_processor = custom_logit_processor
# Memory pool info
self.req_pool_idx = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 5df9c24cee1..a89bd1bc4f4 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -614,6 +614,19 @@ def handle_generate_request(
fake_input_ids = [1] * seq_length
recv_req.input_ids = fake_input_ids
+ # Handle custom logit processor passed to the request
+ custom_logit_processor = recv_req.custom_logit_processor
+ if (
+ not self.server_args.enable_custom_logit_processor
+ and custom_logit_processor is not None
+ ):
+ logger.warning(
+ "The SGLang server is not configured to enable custom logit processor."
+ "The custom logit processor passed in will be ignored."
+ "Please set --enable-custom-logits-processor to enable this feature."
+ )
+ custom_logit_processor = None
+
req = Req(
recv_req.rid,
recv_req.input_text,
@@ -624,6 +637,7 @@ def handle_generate_request(
stream=recv_req.stream,
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
+ custom_logit_processor=custom_logit_processor,
eos_token_ids=self.model_config.hf_eos_token_id,
)
req.tokenizer = self.tokenizer
diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py
index e9c0c909d52..4f4af636757 100644
--- a/python/sglang/srt/managers/session_controller.py
+++ b/python/sglang/srt/managers/session_controller.py
@@ -131,6 +131,7 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
sampling_params=req.sampling_params,
lora_path=req.lora_path,
session_id=self.session_id,
+ custom_logit_processor=req.custom_logit_processor,
)
if last_req is not None:
new_req.image_inputs = last_req.image_inputs
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 9cf6d9cc556..3e349300553 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -381,6 +381,7 @@ async def _tokenize_one_request(
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_params=session_params,
+ custom_logit_processor=obj.custom_logit_processor,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py
new file mode 100644
index 00000000000..a64b2498f23
--- /dev/null
+++ b/python/sglang/srt/sampling/custom_logit_processor.py
@@ -0,0 +1,38 @@
+import json
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import Any, Dict, List, Optional
+
+import dill
+import torch
+
+
+@lru_cache(maxsize=None)
+def _cache_from_str(json_str: str):
+ """Deserialize a json string to a Callable object.
+ This function is cached to avoid redundant deserialization.
+ """
+ data = json.loads(json_str)
+ return dill.loads(bytes.fromhex(data["callable"]))
+
+
+class CustomLogitProcessor(ABC):
+ """Abstract base class for callable functions."""
+
+ @abstractmethod
+ def __call__(
+ self,
+ logits: torch.Tensor,
+ custom_param_list: Optional[List[Dict[str, Any]]] = None,
+ ) -> torch.Tensor:
+ """Define the callable behavior."""
+ raise NotImplementedError
+
+ def to_str(self) -> str:
+ """Serialize the callable function to a JSON-compatible string."""
+ return json.dumps({"callable": dill.dumps(self).hex()})
+
+ @classmethod
+ def from_str(cls, json_str: str):
+ """Deserialize a callable function from a JSON string."""
+ return _cache_from_str(json_str)
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 6eda63c706a..d4c5c32386a 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -3,7 +3,7 @@
import dataclasses
import logging
import threading
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
import torch
@@ -14,6 +14,7 @@
from sgl_kernel import sampling_scaling_penalties
import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
logger = logging.getLogger(__name__)
@@ -36,6 +37,9 @@ class SamplingBatchInfo:
# Dispatch in CUDA graph
need_min_p_sampling: bool
+ # Whether any request has custom logit processor
+ has_custom_logit_processor: bool
+
# Bias Tensors
vocab_size: int
grammars: Optional[List] = None
@@ -52,6 +56,14 @@ class SamplingBatchInfo:
# Device
device: str = "cuda"
+ # Custom Parameters
+ custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
+
+ # Custom Logit Processor
+ custom_logit_processor: Optional[
+ Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
+ ] = None
+
@classmethod
def from_schedule_batch(
cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
@@ -76,6 +88,36 @@ def from_schedule_batch(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)
+ # Check if any request has custom logit processor
+ has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+
+ if has_custom_logit_processor:
+ # Merge the same type of custom logit processors together
+ processor_dict = {}
+ for i, r in enumerate(reqs):
+ if r.custom_logit_processor is None:
+ continue
+ processor_str = r.custom_logit_processor
+ if processor_str not in processor_dict:
+ processor_dict[processor_str] = []
+ processor_dict[processor_str].append(i)
+
+ merged_custom_logit_processor = {
+ hash(processor_str): (
+ # The deserialized custom logit processor object
+ CustomLogitProcessor.from_str(processor_str),
+ # The mask tensor for the requests that use this custom logit processor
+ torch.zeros(len(reqs), dtype=torch.bool)
+ .scatter_(0, torch.tensor(true_indices), True)
+ .to(device, non_blocking=True),
+ )
+ for processor_str, true_indices in processor_dict.items()
+ }
+ custom_params = [r.sampling_params.custom_params for r in reqs]
+ else:
+ merged_custom_logit_processor = None
+ custom_params = None
+
ret = cls(
temperatures=temperatures,
top_ps=top_ps,
@@ -83,8 +125,11 @@ def from_schedule_batch(
min_ps=min_ps,
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+ has_custom_logit_processor=has_custom_logit_processor,
vocab_size=vocab_size,
device=device,
+ custom_params=custom_params,
+ custom_logit_processor=merged_custom_logit_processor,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
@@ -184,6 +229,8 @@ def update_regex_vocab_mask(self):
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+ if self.has_custom_logit_processor:
+ self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)
for item in [
"temperatures",
@@ -196,6 +243,26 @@ def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor)
if value is not None: # logit_bias can be None
setattr(self, item, value[new_indices])
+ def _filter_batch_custom_logit_processor(
+ self, unfinished_indices: List[int], new_indices: torch.Tensor
+ ):
+ """Filter the custom logit processor and custom params"""
+ if not self.custom_logit_processor:
+ return
+ self.custom_logit_processor = {
+ k: (p, mask[new_indices])
+ for k, (p, mask) in self.custom_logit_processor.items()
+ if any(
+ mask[new_indices]
+ ) # ignore the custom logit processor whose mask is all False
+ }
+ self.custom_params = [self.custom_params[i] for i in unfinished_indices]
+
+ if len(self) == 0:
+ self.custom_logit_processor = None
+ self.custom_params = None
+ self.has_custom_logit_processor = False
+
@staticmethod
def merge_bias_tensor(
lhs: torch.Tensor,
@@ -221,6 +288,39 @@ def merge_bias_tensor(
return None
+ @staticmethod
+ def merge_custom_logit_processor(
+ lhs: Optional[Dict[str, torch.Tensor]],
+ rhs: Optional[Dict[str, torch.Tensor]],
+ bs1: int,
+ bs2: int,
+ device: str,
+ ):
+ if lhs is None and rhs is None:
+ return None
+ lhs, rhs = lhs or {}, rhs or {}
+
+ keys = set(lhs.keys()).union(set(rhs.keys()))
+ merged_dict = {}
+
+ for k in keys:
+ # Get the logit processor object
+ processor = lhs[k][0] if k in lhs else rhs[k][0]
+ # Get and merge the mask tensors from the two dicts
+ left_mask = (
+ lhs[k][1]
+ if k in lhs
+ else torch.zeros(bs1, dtype=torch.bool, device=device)
+ )
+ right_mask = (
+ rhs[k][1]
+ if k in rhs
+ else torch.zeros(bs2, dtype=torch.bool, device=device)
+ )
+ merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+
+ return merged_dict
+
def merge_batch(self, other: "SamplingBatchInfo"):
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
@@ -240,6 +340,26 @@ def merge_batch(self, other: "SamplingBatchInfo"):
)
self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
+ # Merge the custom logit processors and custom params lists
+ if self.has_custom_logit_processor or other.has_custom_logit_processor:
+ # Merge the custom logit processors
+ self.custom_logit_processor = (
+ SamplingBatchInfo.merge_custom_logit_processor(
+ self.custom_logit_processor,
+ other.custom_logit_processor,
+ len(self),
+ len(other),
+ self.device,
+ )
+ )
+ # Merge the custom params lists
+ self.custom_params = self.custom_params or [None] * len(self)
+ other.custom_params = other.custom_params or [None] * len(other)
+ self.custom_params.extend(other.custom_params)
+
+ # Set the flag to True if any of the two has custom logit processor
+ self.has_custom_logit_processor = True
+
def apply_logits_bias(self, logits: torch.Tensor):
# Apply logit_bias
if self.logit_bias is not None:
diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py
index d1d932693c6..2224fb0919a 100644
--- a/python/sglang/srt/sampling/sampling_params.py
+++ b/python/sglang/srt/sampling/sampling_params.py
@@ -13,7 +13,7 @@
# ==============================================================================
"""Sampling parameters for text generation."""
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
_SAMPLING_EPS = 1e-6
@@ -48,6 +48,7 @@ def __init__(
no_stop_trim: bool = False,
ignore_eos: bool = False,
skip_special_tokens: bool = True,
+ custom_params: Optional[Dict[str, Any]] = None,
) -> None:
self.temperature = temperature
self.top_p = top_p
@@ -71,6 +72,7 @@ def __init__(
self.json_schema = json_schema
self.ebnf = ebnf
self.no_stop_trim = no_stop_trim
+ self.custom_params = custom_params
# Process some special cases
if self.temperature < _SAMPLING_EPS:
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index a2c1cb375dc..2cb2cd95dc8 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -773,6 +773,7 @@ def generate(
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
+ custom_logit_processor: Optional[Union[List[str], str]] = None,
stream: bool = False,
):
obj = GenerateReqInput(
@@ -784,6 +785,7 @@ def generate(
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
stream=stream,
+ custom_logit_processor=custom_logit_processor,
)
# get the current event loop
@@ -824,6 +826,7 @@ async def async_generate(
logprob_start_len: Optional[Union[List[int], int]] = None,
top_logprobs_num: Optional[Union[List[int], int]] = None,
lora_path: Optional[List[Optional[str]]] = None,
+ custom_logit_processor: Optional[Union[str, List[str]]] = None,
stream: bool = False,
):
obj = GenerateReqInput(
@@ -835,6 +838,7 @@ async def async_generate(
top_logprobs_num=top_logprobs_num,
lora_path=lora_path,
stream=stream,
+ custom_logit_processor=custom_logit_processor,
)
ret = await generate_request(obj, None)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 052e316b7c4..6dd0b945654 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -159,6 +159,9 @@ class ServerArgs:
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
+ # Custom logit processor
+ enable_custom_logit_processor: bool = False
+
def __post_init__(self):
# Set missing default values
if self.tokenizer_path is None:
@@ -865,6 +868,11 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
)
+ parser.add_argument(
+ "--enable-custom-logit-processor",
+ action="store_true",
+ help="Enable users to pass custom logit processors to the server (disabled by default for security)",
+ )
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index 0fd71efcb0b..7afdc9bf41c 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -5,10 +5,12 @@
import json
import unittest
+from concurrent.futures import ThreadPoolExecutor
import numpy as np
import requests
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
@@ -24,7 +26,10 @@ def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
- cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
+ cls.model,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=("--enable-custom-logit-processor",),
)
@classmethod
@@ -248,6 +253,62 @@ def test_logprob_grammar(self):
self.assertTrue(all(x is not None for x in logprobs))
+ def run_custom_logit_processor(self, target_token_id: int):
+ """Test custom logit processor with custom params."""
+
+ custom_params = {"token_id": target_token_id}
+
+ class DeterministicLogitProcessor(CustomLogitProcessor):
+ """A dummy logit processor that changes the logits to always
+ sample the given token id.
+ """
+
+ def __call__(self, logits, custom_param_list):
+ assert logits.shape[0] == len(custom_param_list)
+ key = "token_id"
+
+ for i, param_dict in enumerate(custom_param_list):
+ # Mask all other tokens
+ logits[i, :] = -float("inf")
+ # Assign highest probability to the specified token
+ logits[i, param_dict[key]] = 0.0
+ return logits
+
+ prompts = "Question: Is Paris the Capital of France? Answer:"
+
+ # Base case json data to be posted to the server.
+ base_json = {
+ "text": prompts,
+ "sampling_params": {"temperature": 0.0},
+ "return_logprob": True,
+ }
+
+ # Custom json data with custom logit processor and params.
+ custom_json = base_json.copy()
+ custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
+ custom_json["sampling_params"]["custom_params"] = custom_params
+
+ custom_response = requests.post(
+ self.base_url + "/generate",
+ json=custom_json,
+ ).json()
+
+ output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
+ sampled_tokens = [x[1] for x in output_token_logprobs]
+
+ # The logit processor should always sample the given token because the modified logits are deterministic.
+ self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
+
+ def test_custom_logit_processor(self):
+ """Test custom logit processor with a single request."""
+ self.run_custom_logit_processor(target_token_id=5)
+
+ def test_custom_logit_processor_batch(self):
+ """Test custom logit processor with a batch of requests."""
+ target_token_ids = list(range(32))
+ with ThreadPoolExecutor(len(target_token_ids)) as executor:
+ list(executor.map(self.run_custom_logit_processor, target_token_ids))
+
def test_get_server_info(self):
response = requests.get(self.base_url + "/get_server_info")
response_json = response.json()
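Usage note: putting the pieces together, a client serializes a processor with `to_str()` and attaches per-request parameters through `custom_params`; the server must be launched with `--enable-custom-logit-processor`. A minimal sketch mirroring the test above (the URL and token id are placeholders):

import requests
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

class BanTokenProcessor(CustomLogitProcessor):
    """Masks one token id per request, taken from custom_params."""

    def __call__(self, logits, custom_param_list):
        for i, params in enumerate(custom_param_list):
            logits[i, params["banned_token_id"]] = -float("inf")
        return logits

payload = {
    "text": "Question: Is Paris the Capital of France? Answer:",
    "sampling_params": {"temperature": 0.0, "custom_params": {"banned_token_id": 13}},
    "custom_logit_processor": BanTokenProcessor().to_str(),
}
resp = requests.post("http://127.0.0.1:30000/generate", json=payload)  # placeholder URL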
From 61f42b5732a0740ed9a416a098b96e7e6e14f277 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 19 Jan 2025 17:10:29 -0800
Subject: [PATCH 020/147] Move sgl.Runtime under sglang/lang (#2990)
---
.../frontend_language/usage/json_decode.py | 2 +-
.../models/character_generation/1/model.py | 4 +-
examples/runtime/async_io_api.py | 46 -----
python/sglang/api.py | 7 +-
python/sglang/bench_offline_throughput.py | 3 +-
.../sglang/lang/backend/runtime_endpoint.py | 169 +++++++++++++++++-
python/sglang/launch_server_llavavid.py | 25 ---
python/sglang/srt/constrained/__init__.py | 16 --
.../srt/constrained/base_grammar_backend.py | 21 +++
python/sglang/srt/managers/scheduler.py | 109 +++++------
.../sglang/srt/managers/tokenizer_manager.py | 4 +-
python/sglang/srt/server.py | 160 -----------------
python/sglang/test/runners.py | 20 +--
scripts/deprecated/test_jump_forward.py | 2 +-
test/lang/test_srt_backend.py | 2 +-
test/srt/models/test_qwen_models.py | 2 +-
test/srt/models/test_reward_models.py | 4 +-
17 files changed, 267 insertions(+), 329 deletions(-)
delete mode 100644 examples/runtime/async_io_api.py
delete mode 100644 python/sglang/launch_server_llavavid.py
delete mode 100644 python/sglang/srt/constrained/__init__.py
diff --git a/examples/frontend_language/usage/json_decode.py b/examples/frontend_language/usage/json_decode.py
index ce8f5ba7062..5dc3522d512 100644
--- a/examples/frontend_language/usage/json_decode.py
+++ b/examples/frontend_language/usage/json_decode.py
@@ -9,7 +9,7 @@
from pydantic import BaseModel
import sglang as sgl
-from sglang.srt.constrained import build_regex_from_object
+from sglang.srt.constrained.outlines_backend import build_regex_from_object
character_regex = (
r"""\{\n"""
diff --git a/examples/frontend_language/usage/triton/models/character_generation/1/model.py b/examples/frontend_language/usage/triton/models/character_generation/1/model.py
index 5550e93984b..4bf86f1b691 100644
--- a/examples/frontend_language/usage/triton/models/character_generation/1/model.py
+++ b/examples/frontend_language/usage/triton/models/character_generation/1/model.py
@@ -3,8 +3,8 @@
from pydantic import BaseModel
import sglang as sgl
-from sglang import function, set_default_backend
-from sglang.srt.constrained import build_regex_from_object
+from sglang import function
+from sglang.srt.constrained.outlines_backend import build_regex_from_object
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
diff --git a/examples/runtime/async_io_api.py b/examples/runtime/async_io_api.py
deleted file mode 100644
index 23d3d0b90bf..00000000000
--- a/examples/runtime/async_io_api.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""
-Usage:
-
-python3 async_io.py
-"""
-
-import asyncio
-
-from sglang import Runtime
-
-
-async def generate(
- engine,
- prompt,
- sampling_params,
-):
- tokenizer = engine.get_tokenizer()
-
- messages = [
- {
- "role": "system",
- "content": "You will be given question answer tasks.",
- },
- {"role": "user", "content": prompt},
- ]
-
- prompt = tokenizer.apply_chat_template(
- messages, tokenize=False, add_generation_prompt=True
- )
-
- stream = engine.add_request(prompt, sampling_params)
-
- async for output in stream:
- print(output, end="", flush=True)
- print()
-
-
-if __name__ == "__main__":
- runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
- print("--- runtime ready ---\n")
-
- prompt = "Who is Alan Turing?"
- sampling_params = {"max_new_tokens": 128}
- asyncio.run(generate(runtime, prompt, sampling_params))
-
- runtime.shutdown()
diff --git a/python/sglang/api.py b/python/sglang/api.py
index 9a30ad492da..a9c5fa9da99 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -1,6 +1,5 @@
"""Public APIs of the language."""
-import os
import re
from typing import Callable, List, Optional, Union
@@ -33,17 +32,13 @@ def decorator(func):
def Runtime(*args, **kwargs):
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-
# Avoid importing unnecessary dependency
- from sglang.srt.server import Runtime
+ from sglang.lang.backend.runtime_endpoint import Runtime
return Runtime(*args, **kwargs)
def Engine(*args, **kwargs):
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-
# Avoid importing unnecessary dependency
from sglang.srt.server import Engine
diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
index 54b042c115d..6b31ac40e11 100644
--- a/python/sglang/bench_offline_throughput.py
+++ b/python/sglang/bench_offline_throughput.py
@@ -27,7 +27,8 @@
sample_random_requests,
set_ulimit,
)
-from sglang.srt.server import Engine, Runtime
+from sglang.lang.backend.runtime_endpoint import Runtime
+from sglang.srt.server import Engine
from sglang.srt.server_args import ServerArgs
diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py
index a0032591226..23e9f1afbc6 100644
--- a/python/sglang/lang/backend/runtime_endpoint.py
+++ b/python/sglang/lang/backend/runtime_endpoint.py
@@ -1,6 +1,11 @@
+import atexit
import json
+import multiprocessing
import warnings
-from typing import List, Optional
+from typing import Dict, List, Optional, Union
+
+import aiohttp
+import requests
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
@@ -14,6 +19,9 @@
REGEX_STR,
SglSamplingParams,
)
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import is_port_available, kill_process_tree
from sglang.utils import http_request
@@ -325,3 +333,162 @@ def _assert_success(self, res):
def compute_normalized_prompt_logprobs(input_logprobs):
values = [x[0] for x in input_logprobs if x[0]]
return sum(values) / len(values)
+
+
+class Runtime:
+ """
+ A wrapper for the HTTP server.
+ This is used for launching the server in a Python program without
+ using the command line interface.
+
+ It is mainly used for the frontend language.
+ You should use the Engine class if you want to do normal offline processing.
+ """
+
+ def __init__(
+ self,
+ log_level: str = "error",
+ *args,
+ **kwargs,
+ ):
+ """See the arguments in server_args.py::ServerArgs"""
+ from sglang.srt.server import launch_server
+
+ self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
+
+ # Register shutdown to run when the Python program terminates, so users don't have to call .shutdown() explicitly.
+ atexit.register(self.shutdown)
+
+ # Pre-allocate ports
+ for port in range(self.server_args.port, 40000):
+ if is_port_available(port):
+ break
+ self.server_args.port = port
+
+ self.url = self.server_args.url()
+ self.generate_url = self.url + "/generate"
+
+ # NOTE: We store pid instead of proc to fix some issues during __delete__
+ self.pid = None
+ pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False)
+
+ proc = multiprocessing.Process(
+ target=launch_server,
+ args=(self.server_args, pipe_writer),
+ )
+ proc.start()
+ pipe_writer.close()
+ self.pid = proc.pid
+
+ try:
+ init_state = pipe_reader.recv()
+ except EOFError:
+ init_state = ""
+
+ if init_state != "ready":
+ self.shutdown()
+ raise RuntimeError(
+ "Initialization failed. Please see the error messages above."
+ )
+
+ self.endpoint = RuntimeEndpoint(self.url)
+
+ def shutdown(self):
+ if self.pid is not None:
+ kill_process_tree(self.pid)
+ self.pid = None
+
+ def cache_prefix(self, prefix: str):
+ self.endpoint.cache_prefix(prefix)
+
+ def get_tokenizer(self):
+ return get_tokenizer(
+ self.server_args.tokenizer_path,
+ tokenizer_mode=self.server_args.tokenizer_mode,
+ trust_remote_code=self.server_args.trust_remote_code,
+ revision=self.server_args.revision,
+ )
+
+ async def async_generate(
+ self,
+ prompt: str,
+ sampling_params: Optional[Dict] = None,
+ ):
+ if self.server_args.skip_tokenizer_init:
+ json_data = {
+ "input_ids": prompt,
+ "sampling_params": sampling_params,
+ "stream": True,
+ }
+ else:
+ json_data = {
+ "text": prompt,
+ "sampling_params": sampling_params,
+ "stream": True,
+ }
+ pos = 0
+
+ timeout = aiohttp.ClientTimeout(total=3 * 3600)
+ async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
+ async with session.post(self.generate_url, json=json_data) as response:
+ async for chunk, _ in response.content.iter_chunks():
+ chunk = chunk.decode("utf-8")
+ if chunk and chunk.startswith("data:"):
+ if chunk == "data: [DONE]\n\n":
+ break
+ data = json.loads(chunk[5:].strip("\n"))
+ if "text" in data:
+ cur = data["text"][pos:]
+ if cur:
+ yield cur
+ pos += len(cur)
+ else:
+ yield data
+
+ add_request = async_generate
+
+ def generate(
+ self,
+ prompt: Union[str, List[str]],
+ sampling_params: Optional[Dict] = None,
+ return_logprob: Optional[Union[List[bool], bool]] = False,
+ logprob_start_len: Optional[Union[List[int], int]] = None,
+ top_logprobs_num: Optional[Union[List[int], int]] = None,
+ lora_path: Optional[List[Optional[str]]] = None,
+ ):
+ json_data = {
+ "text": prompt,
+ "sampling_params": sampling_params,
+ "return_logprob": return_logprob,
+ "logprob_start_len": logprob_start_len,
+ "top_logprobs_num": top_logprobs_num,
+ "lora_path": lora_path,
+ }
+ assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
+ response = requests.post(
+ self.url + "/generate",
+ json=json_data,
+ )
+ return json.dumps(response.json())
+
+ def encode(
+ self,
+ prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+ ):
+ json_data = {"text": prompt}
+ response = requests.post(self.url + "/encode", json=json_data)
+ return json.dumps(response.json())
+
+ async def get_server_info(self):
+ async with aiohttp.ClientSession() as session:
+ async with session.get(f"{self.url}/get_server_info") as response:
+ if response.status == 200:
+ return await response.json()
+ else:
+ error_data = await response.json()
+ raise RuntimeError(
+ f"Failed to get server info. {error_data['error']['message']}"
+ )
+
+ def __del__(self):
+ self.shutdown()
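Usage note: with the class relocated, frontend code imports `Runtime` from `sglang.lang.backend.runtime_endpoint` (or keeps going through the `sglang.api.Runtime` wrapper updated above). A minimal sketch; the model path is a placeholder:

from sglang.lang.backend.runtime_endpoint import Runtime

runtime = Runtime(model_path="org/model-name")  # placeholder model path
try:
    print(runtime.generate("Who is Alan Turing?", {"max_new_tokens": 16}))
finally:
    runtime.shutdown()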
diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py
deleted file mode 100644
index 138c2127e16..00000000000
--- a/python/sglang/launch_server_llavavid.py
+++ /dev/null
@@ -1,25 +0,0 @@
-"""Launch the inference server for Llava-video model."""
-
-import json
-import sys
-
-from sglang.srt.server import launch_server, prepare_server_args
-
-if __name__ == "__main__":
- server_args = prepare_server_args(sys.argv[1:])
-
- model_override_args = {}
- model_override_args["mm_spatial_pool_stride"] = 2
- model_override_args["architectures"] = ["LlavaVidForCausalLM"]
- model_override_args["num_frames"] = 16
- model_override_args["model_type"] = "llavavid"
- if model_override_args["num_frames"] == 32:
- model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
- model_override_args["max_sequence_length"] = 4096 * 2
- model_override_args["tokenizer_model_max_length"] = 4096 * 2
- model_override_args["model_max_length"] = 4096 * 2
- if "34b" in server_args.model_path.lower():
- model_override_args["image_token_index"] = 64002
- server_args.json_model_override_args = json.dumps(model_override_args)
-
- launch_server(server_args)
diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py
deleted file mode 100644
index 458d1925241..00000000000
--- a/python/sglang/srt/constrained/__init__.py
+++ /dev/null
@@ -1,16 +0,0 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-# TODO(lmzheng): make this an optional dependency
-from sglang.srt.constrained.outlines_backend import build_regex_from_object
diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py
index 7c88229cf16..6f304ea171e 100644
--- a/python/sglang/srt/constrained/base_grammar_backend.py
+++ b/python/sglang/srt/constrained/base_grammar_backend.py
@@ -18,6 +18,8 @@
from threading import Event, Lock
from typing import Any, Optional, Tuple
+from sglang.srt.server_args import ServerArgs
+
@dataclass
class CacheEntry:
@@ -69,3 +71,22 @@ def get_future_value(self, key: Tuple[str, str]) -> Future:
def reset(self):
with self.cache_lock:
self.cache.clear()
+
+
+def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
+ if server_args.grammar_backend == "outlines":
+ from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
+
+ grammar_backend = OutlinesGrammarBackend(
+ tokenizer,
+ whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+ allow_jump_forward=not server_args.disable_jump_forward,
+ )
+ elif server_args.grammar_backend == "xgrammar":
+ from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
+
+ grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
+ else:
+ raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
+
+ return grammar_backend
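For readers skimming the patch: the factory above replaces the inline backend selection that previously lived in the scheduler (see the next hunk). A minimal usage sketch, assuming a Hugging Face tokenizer as the tokenizer object and an illustrative model path (neither is part of this patch; the scheduler passes its own tokenizer and model config):

```python
# Hedged usage sketch of the new factory. The model path and the way the
# tokenizer/vocab size are obtained are illustrative assumptions only.
from transformers import AutoTokenizer

from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.server_args import ServerArgs

model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
server_args = ServerArgs(model_path=model_path, grammar_backend="xgrammar")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Selects OutlinesGrammarBackend or XGrammarGrammarBackend based on
# server_args.grammar_backend and raises ValueError for unknown values.
grammar_backend = create_grammar_backend(server_args, tokenizer, len(tokenizer))
```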
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index a89bd1bc4f4..ece5b266455 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -34,6 +34,7 @@
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -149,9 +150,7 @@ def __init__(
else 1
)
- # Init inter-process communication
- context = zmq.Context(2)
-
+ # Distributed rank info
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
compute_dp_attention_world_info(
@@ -162,6 +161,8 @@ def __init__(
)
)
+ # Init inter-process communication
+ context = zmq.Context(2)
if self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -243,7 +244,7 @@ def __init__(
nccl_port=port_args.nccl_port,
)
- # Launch worker for speculative decoding if need
+ # Launch a worker for speculative decoding if needed
if self.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -316,6 +317,8 @@ def __init__(
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
+ self.spec_num_total_accepted_tokens = 0
+ self.spec_num_total_forward_ct = 0
self.last_decode_stats_tic = time.time()
self.stream_interval = server_args.stream_interval
self.current_stream = torch.get_device_module(self.device).current_stream()
@@ -337,28 +340,9 @@ def __init__(
# Init the grammar backend for constrained generation
self.grammar_queue: List[Req] = []
if not server_args.skip_tokenizer_init:
- if server_args.grammar_backend == "outlines":
- from sglang.srt.constrained.outlines_backend import (
- OutlinesGrammarBackend,
- )
-
- self.grammar_backend = OutlinesGrammarBackend(
- self.tokenizer,
- whitespace_pattern=server_args.constrained_json_whitespace_pattern,
- allow_jump_forward=not server_args.disable_jump_forward,
- )
- elif server_args.grammar_backend == "xgrammar":
- from sglang.srt.constrained.xgrammar_backend import (
- XGrammarGrammarBackend,
- )
-
- self.grammar_backend = XGrammarGrammarBackend(
- self.tokenizer, vocab_size=self.model_config.vocab_size
- )
- else:
- raise ValueError(
- f"Invalid grammar backend: {server_args.grammar_backend}"
- )
+ self.grammar_backend = create_grammar_backend(
+ server_args, self.tokenizer, self.model_config.vocab_size
+ )
else:
self.grammar_backend = None
@@ -424,7 +408,8 @@ def __init__(
},
)
- self._dispatcher = TypeBasedDispatcher(
+ # Init request dispatcher
+ self._request_dispatcher = TypeBasedDispatcher(
[
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
@@ -480,10 +465,6 @@ def event_loop_normal(self):
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
-
- if self.server_args.enable_dp_attention: # TODO: simplify this
- batch = self.prepare_dp_attn_batch(batch)
-
self.cur_batch = batch
if batch:
@@ -506,10 +487,6 @@ def event_loop_overlap(self):
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
-
- if self.server_args.enable_dp_attention: # TODO: simplify this
- batch = self.prepare_dp_attn_batch(batch)
-
self.cur_batch = batch
if batch:
@@ -517,7 +494,7 @@ def event_loop_overlap(self):
result_queue.append((batch.copy(), result))
if self.last_batch is None:
- # Create a dummy first batch to start the pipeline for overlap scheduler.
+ # Create a dummy first batch to start the pipeline for overlap schedule.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
@@ -593,7 +570,7 @@ def recv_requests(self) -> List[Req]:
def process_input_requests(self, recv_reqs: List):
for recv_req in recv_reqs:
- output = self._dispatcher(recv_req)
+ output = self._request_dispatcher(recv_req)
if output is not None:
self.send_to_tokenizer.send_pyobj(output)
@@ -798,15 +775,32 @@ def log_decode_stats(self):
self.num_generated_tokens = 0
self.last_decode_stats_tic = time.time()
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
- logger.info(
- f"Decode batch. "
- f"#running-req: {num_running_reqs}, "
- f"#token: {num_used}, "
- f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
- f"gen throughput (token/s): {gen_throughput:.2f}, "
- f"#queue-req: {len(self.waiting_queue)}"
- )
+ if self.spec_algorithm.is_none():
+ msg = (
+ f"Decode batch. "
+ f"#running-req: {num_running_reqs}, "
+ f"#token: {num_used}, "
+ f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+ f"gen throughput (token/s): {gen_throughput:.2f}, "
+ f"#queue-req: {len(self.waiting_queue)}"
+ )
+ else:
+ accept_length = (
+ self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+ )
+ self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+ msg = (
+ f"Decode batch. "
+ f"#running-req: {num_running_reqs}, "
+ f"#token: {num_used}, "
+ f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+ f"accept len: {accept_length:.2f}, "
+ f"gen throughput (token/s): {gen_throughput:.2f}, "
+ f"#queue-req: {len(self.waiting_queue)}"
+ )
+
+ logger.info(msg)
if self.enable_metrics:
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
@@ -855,16 +849,23 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
else:
self.running_batch.merge_batch(self.last_batch)
- # Run prefill first if possible
new_batch = self.get_new_batch_prefill()
if new_batch is not None:
- return new_batch
+ # Run prefill first if possible
+ ret = new_batch
+ else:
+ # Run decode
+ if self.running_batch is None:
+ ret = None
+ else:
+ self.running_batch = self.update_running_batch(self.running_batch)
+ ret = self.running_batch
- # Run decode
- if self.running_batch is None:
- return None
- self.running_batch = self.update_running_batch(self.running_batch)
- return self.running_batch
+ # Handle DP attention
+ if self.server_args.enable_dp_attention:
+ ret = self.prepare_dp_attn_batch(ret)
+
+ return ret
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Check if the grammar is ready in the grammar queue
@@ -1053,6 +1054,10 @@ def run_batch(
model_worker_batch,
num_accepted_tokens,
) = self.draft_worker.forward_batch_speculative_generation(batch)
+ self.spec_num_total_accepted_tokens += (
+ num_accepted_tokens + batch.batch_size()
+ )
+ self.spec_num_total_forward_ct += batch.batch_size()
self.num_generated_tokens += num_accepted_tokens
else:
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
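The renamed `_request_dispatcher` above (and `_result_dispatcher` in the next file) rely on the same type-based dispatch pattern. The actual `TypeBasedDispatcher` lives in `sglang.utils` and may differ in detail; the sketch below only illustrates the idea so the `process_input_requests` change is easier to follow:

```python
# Minimal sketch of a type-based dispatcher (illustrative only; the real
# TypeBasedDispatcher in sglang.utils may be implemented differently).
from typing import Any, Callable, List, Tuple, Type


class SimpleTypeDispatcher:
    def __init__(self, mapping: List[Tuple[Type, Callable[[Any], Any]]]):
        # (request type, handler) pairs, checked in registration order.
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        for ty, fn in self._mapping:
            if isinstance(obj, ty):
                return fn(obj)
        raise ValueError(f"Invalid object: {obj}")


# The scheduler registers one handler per request type; a handler's return
# value (if not None) is sent back to the tokenizer process.
dispatcher = SimpleTypeDispatcher([(int, lambda x: x + 1), (str, str.upper)])
assert dispatcher(41) == 42
assert dispatcher("ok") == "OK"
```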
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 3e349300553..d6178a959d0 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -224,7 +224,7 @@ def __init__(
},
)
- self._dispatcher = TypeBasedDispatcher(
+ self._result_dispatcher = TypeBasedDispatcher(
[
(BatchStrOut, self._handle_batch_output),
(BatchEmbeddingOut, self._handle_batch_output),
@@ -760,7 +760,7 @@ async def handle_loop(self):
while True:
recv_obj = await self.recv_from_detokenizer.recv_pyobj()
- self._dispatcher(recv_obj)
+ self._result_dispatcher(recv_obj)
def _handle_batch_output(
self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 2cb2cd95dc8..0b4d9c37218 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -45,8 +45,6 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
-from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.data_parallel_controller import (
run_data_parallel_controller_process,
)
@@ -90,7 +88,6 @@
assert_pkg_version,
configure_logger,
delete_directory,
- is_port_available,
kill_process_tree,
maybe_set_triton_cache_manager,
prepare_model_and_tokenizer,
@@ -960,160 +957,3 @@ def resume_memory_occupation(self):
obj = ResumeMemoryOccupationReqInput()
loop = asyncio.get_event_loop()
loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
-
-
-class Runtime:
- """
- A wrapper for the HTTP server.
- This is used for launching the server in a python program without
- using the commond line interface.
-
- It is mainly used for the frontend language.
- You should use the Engine class above if you want to do normal offline processing.
- """
-
- def __init__(
- self,
- log_level: str = "error",
- *args,
- **kwargs,
- ):
- """See the arguments in server_args.py::ServerArgs"""
- self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
-
- # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
- atexit.register(self.shutdown)
-
- # Pre-allocate ports
- for port in range(self.server_args.port, 40000):
- if is_port_available(port):
- break
- self.server_args.port = port
-
- self.url = self.server_args.url()
- self.generate_url = self.url + "/generate"
-
- # NOTE: We store pid instead of proc to fix some issues during __delete__
- self.pid = None
- pipe_reader, pipe_writer = mp.Pipe(duplex=False)
-
- proc = mp.Process(
- target=launch_server,
- args=(self.server_args, pipe_writer),
- )
- proc.start()
- pipe_writer.close()
- self.pid = proc.pid
-
- try:
- init_state = pipe_reader.recv()
- except EOFError:
- init_state = ""
-
- if init_state != "ready":
- self.shutdown()
- raise RuntimeError(
- "Initialization failed. Please see the error messages above."
- )
-
- self.endpoint = RuntimeEndpoint(self.url)
-
- def shutdown(self):
- if self.pid is not None:
- kill_process_tree(self.pid)
- self.pid = None
-
- def cache_prefix(self, prefix: str):
- self.endpoint.cache_prefix(prefix)
-
- def get_tokenizer(self):
- return get_tokenizer(
- self.server_args.tokenizer_path,
- tokenizer_mode=self.server_args.tokenizer_mode,
- trust_remote_code=self.server_args.trust_remote_code,
- revision=self.server_args.revision,
- )
-
- async def async_generate(
- self,
- prompt: str,
- sampling_params: Optional[Dict] = None,
- ):
- if self.server_args.skip_tokenizer_init:
- json_data = {
- "input_ids": prompt,
- "sampling_params": sampling_params,
- "stream": True,
- }
- else:
- json_data = {
- "text": prompt,
- "sampling_params": sampling_params,
- "stream": True,
- }
- pos = 0
-
- timeout = aiohttp.ClientTimeout(total=3 * 3600)
- async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
- async with session.post(self.generate_url, json=json_data) as response:
- async for chunk, _ in response.content.iter_chunks():
- chunk = chunk.decode("utf-8")
- if chunk and chunk.startswith("data:"):
- if chunk == "data: [DONE]\n\n":
- break
- data = json.loads(chunk[5:].strip("\n"))
- if "text" in data:
- cur = data["text"][pos:]
- if cur:
- yield cur
- pos += len(cur)
- else:
- yield data
-
- add_request = async_generate
-
- def generate(
- self,
- prompt: Union[str, List[str]],
- sampling_params: Optional[Dict] = None,
- return_logprob: Optional[Union[List[bool], bool]] = False,
- logprob_start_len: Optional[Union[List[int], int]] = None,
- top_logprobs_num: Optional[Union[List[int], int]] = None,
- lora_path: Optional[List[Optional[str]]] = None,
- ):
- json_data = {
- "text": prompt,
- "sampling_params": sampling_params,
- "return_logprob": return_logprob,
- "logprob_start_len": logprob_start_len,
- "top_logprobs_num": top_logprobs_num,
- "lora_path": lora_path,
- }
- assert not isinstance(lora_path, list) or len(lora_path) == len(prompt)
- response = requests.post(
- self.url + "/generate",
- json=json_data,
- )
- return json.dumps(response.json())
-
- def encode(
- self,
- prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
- ):
- json_data = {"text": prompt}
- response = requests.post(self.url + "/encode", json=json_data)
- return json.dumps(response.json())
-
- async def get_server_info(self):
- async with aiohttp.ClientSession() as session:
- async with session.get(f"{self.url}/get_server_info") as response:
- if response.status == 200:
- return await response.json()
- else:
- error_data = await response.json()
- raise RuntimeError(
- f"Failed to get server info. {error_data['error']['message']}"
- )
-
- def __del__(self):
- self.shutdown()
diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py
index f22f9cafaf3..fc9a9793715 100644
--- a/python/sglang/test/runners.py
+++ b/python/sglang/test/runners.py
@@ -23,7 +23,7 @@
from transformers import AutoModelForCausalLM
from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server import Runtime
+from sglang.srt.server import Engine
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
DEFAULT_PROMPTS = [
@@ -278,7 +278,7 @@ def __init__(
):
self.model_type = model_type
self.is_generation = model_type == "generation"
- self.runtime = Runtime(
+ self.engine = Engine(
model_path=model_path,
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
@@ -306,7 +306,7 @@ def forward(
top_output_logprobs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
for i, prompt in enumerate(prompts):
- response = self.runtime.generate(
+ response = self.engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
@@ -314,7 +314,6 @@ def forward(
logprob_start_len=0,
top_logprobs_num=NUM_TOP_LOGPROBS,
)
- response = json.loads(response)
output_strs.append(response["text"])
top_input_logprobs.append(
[
@@ -343,8 +342,7 @@ def forward(
top_output_logprobs=top_output_logprobs,
)
else:
- response = self.runtime.encode(prompts)
- response = json.loads(response)
+ response = self.engine.encode(prompts)
if self.model_type == "embedding":
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
@@ -366,20 +364,18 @@ def batch_forward(
# the return value contains logprobs from prefill
output_strs = []
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
- response = self.runtime.generate(
+ response = self.engine.generate(
prompts,
lora_path=lora_paths if lora_paths else None,
sampling_params=sampling_params,
)
- response = json.loads(response)
output_strs = [r["text"] for r in response]
return ModelOutput(
output_strs=output_strs,
)
else:
- response = self.runtime.encode(prompts)
- response = json.loads(response)
+ response = self.engine.encode(prompts)
if self.model_type == "embedding":
logits = [x["embedding"] for x in response]
return ModelOutput(embed_logits=logits)
@@ -391,8 +387,8 @@ def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
- self.runtime.shutdown()
- del self.runtime
+ self.engine.shutdown()
+ del self.engine
def monkey_patch_gemma2_sdpa():
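The runner changes above hinge on the different return conventions of the two wrappers: the removed `Runtime.generate`/`Runtime.encode` returned JSON strings, while `Engine.generate`/`Engine.encode` return already-parsed Python objects, so the `json.loads` calls can be dropped. A hedged sketch of the calling convention the updated runner assumes (the model path is illustrative):

```python
# Hedged sketch: Engine returns a dict (or a list of dicts for batched
# prompts) directly, so no json.loads is needed on the caller side.
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
out = engine.generate(
    "The capital of France is",
    sampling_params={"max_new_tokens": 8, "temperature": 0},
)
print(out["text"])  # plain dict access; previously: json.loads(response)["text"]
engine.shutdown()
```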
diff --git a/scripts/deprecated/test_jump_forward.py b/scripts/deprecated/test_jump_forward.py
index 60074a04005..315a50b5ba7 100644
--- a/scripts/deprecated/test_jump_forward.py
+++ b/scripts/deprecated/test_jump_forward.py
@@ -4,7 +4,7 @@
from pydantic import BaseModel, constr
import sglang as sgl
-from sglang.srt.constrained import build_regex_from_object
+from sglang.srt.constrained.outlines_backend import build_regex_from_object
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py
index b99606fc1cb..0d7cc910557 100644
--- a/test/lang/test_srt_backend.py
+++ b/test/lang/test_srt_backend.py
@@ -73,7 +73,7 @@ def test_hellaswag_select(self):
# Run twice to capture more bugs
for _ in range(2):
accuracy, latency = test_hellaswag_select()
- self.assertGreater(accuracy, 0.71)
+ self.assertGreater(accuracy, 0.70)
def test_gen_min_new_tokens(self):
test_gen_min_new_tokens()
diff --git a/test/srt/models/test_qwen_models.py b/test/srt/models/test_qwen_models.py
index 903fd45d550..9e61930a76e 100644
--- a/test/srt/models/test_qwen_models.py
+++ b/test/srt/models/test_qwen_models.py
@@ -71,7 +71,7 @@ def test_gsm8k(self):
metrics = run_eval(args)
print(metrics)
- self.assertGreater(metrics["accuracy"], 0.8)
+ self.assertGreater(metrics["accuracy"], 0.79)
if __name__ == "__main__":
diff --git a/test/srt/models/test_reward_models.py b/test/srt/models/test_reward_models.py
index 0d80a4d0cde..69ad563671b 100644
--- a/test/srt/models/test_reward_models.py
+++ b/test/srt/models/test_reward_models.py
@@ -20,8 +20,8 @@
from sglang.test.runners import HFRunner, SRTRunner
MODELS = [
- ("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2),
- ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 3e-2),
+ ("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2),
+ ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 4e-2),
]
TORCH_DTYPES = [torch.float16]
From cd493b5afc27ed1b0f5700809c896af16204f0d9 Mon Sep 17 00:00:00 2001

From: Lianmin Zheng
Date: Sun, 19 Jan 2025 18:36:59 -0800
Subject: [PATCH 021/147] Improve metrics, logging, and import ordering
(#2992)
---
.github/workflows/pr-test.yml | 2 +-
.../runtime/engine/offline_batch_inference.py | 5 +++
python/sglang/__init__.py | 44 +++++++++----------
.../sglang/lang/backend/runtime_endpoint.py | 20 ++++++---
python/sglang/srt/managers/scheduler.py | 6 ++-
python/sglang/srt/metrics/collector.py | 21 ++++++---
sgl-router/py_src/sglang_router/__init__.py | 10 ++---
test/srt/run_suite.py | 3 +-
8 files changed, 63 insertions(+), 48 deletions(-)
diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 51117127ada..b910683e7da 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -52,7 +52,7 @@ jobs:
runs-on: 1-gpu-runner
strategy:
matrix:
- range: [0-6, 6-16, 16-23, 23-30, 30-38, 38-100]
+ range: [0-6, 6-15, 15-22, 22-32, 32-37, 37-100]
steps:
- name: Checkout code
uses: actions/checkout@v3
diff --git a/examples/runtime/engine/offline_batch_inference.py b/examples/runtime/engine/offline_batch_inference.py
index 724051eab53..92e68dcd72c 100644
--- a/examples/runtime/engine/offline_batch_inference.py
+++ b/examples/runtime/engine/offline_batch_inference.py
@@ -1,3 +1,8 @@
+"""
+Usage:
+python3 offline_batch_inference.py --model meta-llama/Llama-3.1-8B-Instruct
+"""
+
import argparse
import dataclasses
diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py
index de9134857a6..70d58043d40 100644
--- a/python/sglang/__init__.py
+++ b/python/sglang/__init__.py
@@ -1,5 +1,6 @@
-# SGL API Components
+# SGLang public APIs
+# Frontend Language APIs
from sglang.api import (
Engine,
Runtime,
@@ -23,16 +24,26 @@
user_end,
video,
)
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.lang.choices import (
greedy_token_selection,
token_length_normalized,
unconditional_likelihood_normalized,
)
+from sglang.utils import LazyImport
+
+Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
+LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
+OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
+VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
+
+# Other configs
+from sglang.global_config import global_config
+from sglang.version import __version__
-# SGLang DSL APIs
__all__ = [
- "Runtime",
"Engine",
+ "Runtime",
"assistant",
"assistant_begin",
"assistant_end",
@@ -52,27 +63,14 @@
"user_begin",
"user_end",
"video",
+ "RuntimeEndpoint",
"greedy_token_selection",
"token_length_normalized",
"unconditional_likelihood_normalized",
+ "Anthropic",
+ "LiteLLM",
+ "OpenAI",
+ "VertexAI",
+ "global_config",
+ "__version__",
]
-
-# Global Configurations
-from sglang.global_config import global_config
-
-__all__ += ["global_config"]
-
-from sglang.version import __version__
-
-__all__ += ["__version__"]
-
-# SGLang Backends
-from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.utils import LazyImport
-
-Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
-LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
-OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
-VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
-
-__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"]
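The reordered `__init__.py` keeps the `LazyImport` wrappers for the optional model-provider backends. The real `LazyImport` lives in `sglang.utils`; the sketch below only illustrates the deferred-import idea, so that `import sglang` does not require the `anthropic`, `litellm`, or `openai` packages to be installed:

```python
# Illustrative sketch of a lazy import wrapper (the actual sglang.utils.LazyImport
# may differ). The target module is imported only when the symbol is first used.
import importlib


class LazyImportSketch:
    def __init__(self, module_name: str, class_name: str):
        self.module_name = module_name
        self.class_name = class_name

    def _load(self):
        module = importlib.import_module(self.module_name)
        return getattr(module, self.class_name)

    def __call__(self, *args, **kwargs):
        # Resolve the real class on first use and construct an instance.
        return self._load()(*args, **kwargs)


OpenAI = LazyImportSketch("sglang.lang.backend.openai", "OpenAI")
# The `openai` package is only needed if OpenAI(...) is actually called.
```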
diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py
index 23e9f1afbc6..c139db6f04c 100644
--- a/python/sglang/lang/backend/runtime_endpoint.py
+++ b/python/sglang/lang/backend/runtime_endpoint.py
@@ -19,9 +19,6 @@
REGEX_STR,
SglSamplingParams,
)
-from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import is_port_available, kill_process_tree
from sglang.utils import http_request
@@ -342,7 +339,7 @@ class Runtime:
using the commond line interface.
It is mainly used for the frontend language.
- You should use the Engine class if you want to do normal offline processing.
+ You should use the Engine class if you want to do normal offline processing without the frontend language.
"""
def __init__(
@@ -352,13 +349,14 @@ def __init__(
**kwargs,
):
"""See the arguments in server_args.py::ServerArgs"""
+ # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
+ # client code without installing the SRT server and its dependencies if they want.
from sglang.srt.server import launch_server
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import is_port_available
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)
- # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
- atexit.register(self.shutdown)
-
# Pre-allocate ports
for port in range(self.server_args.port, 40000):
if is_port_available(port):
@@ -380,6 +378,10 @@ def __init__(
pipe_writer.close()
self.pid = proc.pid
+ # Before the Python program terminates, call shutdown implicitly so that users don't have to call .shutdown() explicitly.
+ atexit.register(self.shutdown)
+
+ # TODO: remove this pipe_writer mechanism and use `/health_generate` instead.
try:
init_state = pipe_reader.recv()
except EOFError:
@@ -394,6 +396,8 @@ def __init__(
self.endpoint = RuntimeEndpoint(self.url)
def shutdown(self):
+ from sglang.srt.utils import kill_process_tree
+
if self.pid is not None:
kill_process_tree(self.pid)
self.pid = None
@@ -402,6 +406,8 @@ def cache_prefix(self, prefix: str):
self.endpoint.cache_prefix(prefix)
def get_tokenizer(self):
+ from sglang.srt.hf_transformers_utils import get_tokenizer
+
return get_tokenizer(
self.server_args.tokenizer_path,
tokenizer_mode=self.server_args.tokenizer_mode,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index ece5b266455..416abe21cd3 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -785,8 +785,9 @@ def log_decode_stats(self):
f"gen throughput (token/s): {gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)
+ spec_accept_length = 0
else:
- accept_length = (
+ spec_accept_length = (
self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
)
self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
@@ -795,7 +796,7 @@ def log_decode_stats(self):
f"#running-req: {num_running_reqs}, "
f"#token: {num_used}, "
f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
- f"accept len: {accept_length:.2f}, "
+ f"accept len: {spec_accept_length:.2f}, "
f"gen throughput (token/s): {gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}"
)
@@ -807,6 +808,7 @@ def log_decode_stats(self):
self.stats.token_usage = num_used / self.max_total_num_tokens
self.stats.gen_throughput = gen_throughput
self.stats.num_queue_reqs = len(self.waiting_queue)
+ self.stats.spec_accept_length = spec_accept_length
self.metrics_collector.log_stats(self.stats)
def check_memory(self):
diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py
index 070b405be42..26eb2fc27d2 100644
--- a/python/sglang/srt/metrics/collector.py
+++ b/python/sglang/srt/metrics/collector.py
@@ -25,6 +25,7 @@ class SchedulerStats:
gen_throughput: float = 0.0
num_queue_reqs: int = 0
cache_hit_rate: float = 0.0
+ spec_accept_length: float = 0.0
class SchedulerMetricsCollector:
@@ -37,42 +38,49 @@ def __init__(self, labels: Dict[str, str]) -> None:
self.num_running_reqs = Gauge(
name="sglang:num_running_reqs",
- documentation="The number of running requests",
+ documentation="The number of running requests.",
labelnames=labels.keys(),
multiprocess_mode="sum",
)
self.num_used_tokens = Gauge(
name="sglang:num_used_tokens",
- documentation="The number of used tokens",
+ documentation="The number of used tokens.",
labelnames=labels.keys(),
multiprocess_mode="sum",
)
self.token_usage = Gauge(
name="sglang:token_usage",
- documentation="The token usage",
+ documentation="The token usage.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
self.gen_throughput = Gauge(
name="sglang:gen_throughput",
- documentation="The generate throughput (token/s)",
+ documentation="The generation throughput (token/s).",
labelnames=labels.keys(),
multiprocess_mode="sum",
)
self.num_queue_reqs = Gauge(
name="sglang:num_queue_reqs",
- documentation="The number of requests in the waiting queue",
+ documentation="The number of requests in the waiting queue.",
labelnames=labels.keys(),
multiprocess_mode="sum",
)
self.cache_hit_rate = Gauge(
name="sglang:cache_hit_rate",
- documentation="The cache hit rate",
+ documentation="The prefix cache hit rate.",
+ labelnames=labels.keys(),
+ multiprocess_mode="mostrecent",
+ )
+
+ self.spec_accept_length = Gauge(
+ name="sglang:spec_accept_length",
+ documentation="The average acceptance length of speculative decoding.",
labelnames=labels.keys(),
multiprocess_mode="mostrecent",
)
@@ -88,6 +96,7 @@ def log_stats(self, stats: SchedulerStats) -> None:
self._log_gauge(self.gen_throughput, stats.gen_throughput)
self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
+ self._log_gauge(self.spec_accept_length, stats.spec_accept_length)
class TokenizerMetricsCollector:
diff --git a/sgl-router/py_src/sglang_router/__init__.py b/sgl-router/py_src/sglang_router/__init__.py
index 285ee173ba9..081740479ca 100644
--- a/sgl-router/py_src/sglang_router/__init__.py
+++ b/sgl-router/py_src/sglang_router/__init__.py
@@ -1,11 +1,7 @@
# a lightweihgt wrapper on router with argument type and comments
-from sglang_router_rs import PolicyType
-
# no wrapper on policy type => direct export
-from .router import Router
-
-__all__ = ["Router", "PolicyType"]
-
+from sglang_router.router import Router
from sglang_router.version import __version__
+from sglang_router_rs import PolicyType
-__all__ += ["__version__"]
+__all__ = ["Router", "PolicyType", "__version__"]
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 2ed2522755a..69a5470bee4 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -42,8 +42,7 @@
"test_srt_endpoint.py",
"test_torch_compile.py",
"test_torch_compile_moe.py",
- # Temporarily disable this because it requires PyTorch >= 2.5
- # "test_torch_native_attention_backend.py",
+ "test_torch_native_attention_backend.py",
"test_torchao.py",
"test_triton_attention_kernels.py",
"test_triton_attention_backend.py",
From 0ffcfdf474d34858ce5641c11c0b5559861d188b Mon Sep 17 00:00:00 2001
From: Chayenne
Date: Sun, 19 Jan 2025 20:22:47 -0800
Subject: [PATCH 022/147] Docs: Only use X-Grammar in structured outputs (#2991)
---
docs/backend/openai_api_completions.ipynb | 131 ++--------------------
docs/backend/structured_outputs.ipynb | 93 +++------------
2 files changed, 22 insertions(+), 202 deletions(-)
diff --git a/docs/backend/openai_api_completions.ipynb b/docs/backend/openai_api_completions.ipynb
index 42cdbb11210..58b524108db 100644
--- a/docs/backend/openai_api_completions.ipynb
+++ b/docs/backend/openai_api_completions.ipynb
@@ -41,10 +41,10 @@
")\n",
"\n",
"server_process = execute_shell_command(\n",
- " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\"\n",
+ " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30020 --host 0.0.0.0\"\n",
")\n",
"\n",
- "wait_for_server(\"http://localhost:30000\")"
+ "wait_for_server(\"http://localhost:30020\")"
]
},
{
@@ -68,7 +68,7 @@
"source": [
"import openai\n",
"\n",
- "client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n",
+ "client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
@@ -214,125 +214,8 @@
"metadata": {},
"source": [
"## Structured Outputs (JSON, Regex, EBNF)\n",
- "You can specify a JSON schema, [regular expression](https://en.wikipedia.org/wiki/Regular_expression) or [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) to constrain the model output. The model output will be guaranteed to follow the given constraints. Only one constraint parameter (`json_schema`, `regex`, or `ebnf`) can be specified for a request.\n",
"\n",
- "SGLang supports two grammar backends:\n",
- "\n",
- "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
- "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
- " - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n",
- "\n",
- "Initialize the XGrammar backend using `--grammar-backend xgrammar` flag\n",
- "```bash\n",
- "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
- "--port 30000 --host 0.0.0.0 --grammar-backend [xgrammar|outlines] # xgrammar or outlines (default: outlines)\n",
- "```\n",
- "\n",
- "### JSON"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import json\n",
- "\n",
- "json_schema = json.dumps(\n",
- " {\n",
- " \"type\": \"object\",\n",
- " \"properties\": {\n",
- " \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n",
- " \"population\": {\"type\": \"integer\"},\n",
- " },\n",
- " \"required\": [\"name\", \"population\"],\n",
- " }\n",
- ")\n",
- "\n",
- "response = client.chat.completions.create(\n",
- " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
- " messages=[\n",
- " {\n",
- " \"role\": \"user\",\n",
- " \"content\": \"Give me the information of the capital of France in the JSON format.\",\n",
- " },\n",
- " ],\n",
- " temperature=0,\n",
- " max_tokens=128,\n",
- " response_format={\n",
- " \"type\": \"json_schema\",\n",
- " \"json_schema\": {\"name\": \"foo\", \"schema\": json.loads(json_schema)},\n",
- " },\n",
- ")\n",
- "\n",
- "print_highlight(response.choices[0].message.content)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Regular expression (use default \"outlines\" backend)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "response = client.chat.completions.create(\n",
- " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
- " messages=[\n",
- " {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
- " ],\n",
- " temperature=0,\n",
- " max_tokens=128,\n",
- " extra_body={\"regex\": \"(Paris|London)\"},\n",
- ")\n",
- "\n",
- "print_highlight(response.choices[0].message.content)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### EBNF (use \"xgrammar\" backend)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# terminate the existing server(that's using default outlines backend) for this demo\n",
- "terminate_process(server_process)\n",
- "\n",
- "# start new server with xgrammar backend\n",
- "server_process = execute_shell_command(\n",
- " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --grammar-backend xgrammar\"\n",
- ")\n",
- "wait_for_server(\"http://localhost:30000\")\n",
- "\n",
- "# EBNF example\n",
- "ebnf_grammar = r\"\"\"\n",
- " root ::= \"Hello\" | \"Hi\" | \"Hey\"\n",
- " \"\"\"\n",
- "response = client.chat.completions.create(\n",
- " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
- " messages=[\n",
- " {\"role\": \"system\", \"content\": \"You are a helpful EBNF test bot.\"},\n",
- " {\"role\": \"user\", \"content\": \"Say a greeting.\"},\n",
- " ],\n",
- " temperature=0,\n",
- " max_tokens=32,\n",
- " extra_body={\"ebnf\": ebnf_grammar},\n",
- ")\n",
- "\n",
- "print_highlight(response.choices[0].message.content)"
+ "For the OpenAI-compatible structured outputs API, refer to [Structured Outputs](https://docs.sglang.ai/backend/structured_outputs.html#OpenAI-Compatible-API) for more details.\n"
]
},
{
@@ -362,7 +245,7 @@
"import time\n",
"from openai import OpenAI\n",
"\n",
- "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n",
+ "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
"\n",
"requests = [\n",
" {\n",
@@ -465,7 +348,7 @@
"import time\n",
"from openai import OpenAI\n",
"\n",
- "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n",
+ "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
"\n",
"requests = []\n",
"for i in range(100):\n",
@@ -542,7 +425,7 @@
"from openai import OpenAI\n",
"import os\n",
"\n",
- "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n",
+ "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n",
"\n",
"requests = []\n",
"for i in range(500):\n",
diff --git a/docs/backend/structured_outputs.ipynb b/docs/backend/structured_outputs.ipynb
index a5e6f2335b5..e413743ccfd 100644
--- a/docs/backend/structured_outputs.ipynb
+++ b/docs/backend/structured_outputs.ipynb
@@ -17,11 +17,12 @@
"\n",
"- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n",
"- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n",
- " - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n",
"\n",
- "We suggest using XGrammar whenever possible for its better performance. For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n",
+ "We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n",
"\n",
- "To use Xgrammar, simply add `--grammar-backend` xgrammar when launching the server. If no backend is specified, Outlines will be used as the default."
+ "To use XGrammar, simply add `--grammar-backend xgrammar` when launching the server. If no backend is specified, Outlines will be used as the default.\n",
+ "\n",
+ "For better output quality, **it's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify, 'Please generate the output in the following JSON format: ...'.\n"
]
},
{
@@ -93,7 +94,7 @@
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
- " \"content\": \"Give me the information of the capital of France in the JSON format.\",\n",
+ " \"content\": \"Please generate the information of the capital of France in the JSON format.\",\n",
" },\n",
" ],\n",
" temperature=0,\n",
@@ -197,20 +198,6 @@
"print_highlight(response.choices[0].message.content)"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "terminate_process(server_process)\n",
- "server_process = execute_shell_command(\n",
- " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\"\n",
- ")\n",
- "\n",
- "wait_for_server(\"http://localhost:30000\")"
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -237,15 +224,6 @@
"print_highlight(response.choices[0].message.content)"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "terminate_process(server_process)"
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -253,21 +231,6 @@
"## Native API and SGLang Runtime (SRT)"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "server_process = execute_shell_command(\n",
- " \"\"\"\n",
- "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010 --grammar-backend xgrammar\n",
- "\"\"\"\n",
- ")\n",
- "\n",
- "wait_for_server(\"http://localhost:30010\")"
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -301,7 +264,7 @@
"\n",
"# Make API request\n",
"response = requests.post(\n",
- " \"http://localhost:30010/generate\",\n",
+ " \"http://localhost:30000/generate\",\n",
" json={\n",
" \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n",
" \"sampling_params\": {\n",
@@ -346,7 +309,7 @@
"\n",
"# JSON\n",
"response = requests.post(\n",
- " \"http://localhost:30010/generate\",\n",
+ " \"http://localhost:30000/generate\",\n",
" json={\n",
" \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n",
" \"sampling_params\": {\n",
@@ -376,7 +339,7 @@
"import requests\n",
"\n",
"response = requests.post(\n",
- " \"http://localhost:30010/generate\",\n",
+ " \"http://localhost:30000/generate\",\n",
" json={\n",
" \"text\": \"Give me the information of the capital of France.\",\n",
" \"sampling_params\": {\n",
@@ -399,22 +362,6 @@
"print_highlight(response.json())"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "terminate_process(server_process)\n",
- "server_process = execute_shell_command(\n",
- " \"\"\"\n",
- "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010\n",
- "\"\"\"\n",
- ")\n",
- "\n",
- "wait_for_server(\"http://localhost:30010\")"
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -429,7 +376,7 @@
"outputs": [],
"source": [
"response = requests.post(\n",
- " \"http://localhost:30010/generate\",\n",
+ " \"http://localhost:30000/generate\",\n",
" json={\n",
" \"text\": \"Paris is the capital of\",\n",
" \"sampling_params\": {\n",
@@ -466,7 +413,7 @@
"source": [
"import sglang as sgl\n",
"\n",
- "llm_xgrammar = sgl.Engine(\n",
+ "llm = sgl.Engine(\n",
" model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", grammar_backend=\"xgrammar\"\n",
")"
]
@@ -514,7 +461,7 @@
" \"json_schema\": json.dumps(CapitalInfo.model_json_schema()),\n",
"}\n",
"\n",
- "outputs = llm_xgrammar.generate(prompts, sampling_params)\n",
+ "outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\") # validate the output by the pydantic model\n",
@@ -554,7 +501,7 @@
"\n",
"sampling_params = {\"temperature\": 0.1, \"top_p\": 0.95, \"json_schema\": json_schema}\n",
"\n",
- "outputs = llm_xgrammar.generate(prompts, sampling_params)\n",
+ "outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
@@ -591,22 +538,12 @@
" ),\n",
"}\n",
"\n",
- "outputs = llm_xgrammar.generate(prompts, sampling_params)\n",
+ "outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "llm_xgrammar.shutdown()\n",
- "llm_outlines = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")"
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
@@ -627,7 +564,7 @@
"\n",
"sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"regex\": \"(France|England)\"}\n",
"\n",
- "outputs = llm_outlines.generate(prompts, sampling_params)\n",
+ "outputs = llm.generate(prompts, sampling_params)\n",
"for prompt, output in zip(prompts, outputs):\n",
" print_highlight(\"===============================\")\n",
" print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")"
@@ -639,7 +576,7 @@
"metadata": {},
"outputs": [],
"source": [
- "llm_outlines.shutdown()"
+ "llm.shutdown()"
]
}
],
From 1a820e38a2fcc6d0e0324605bb39baec23d81f8d Mon Sep 17 00:00:00 2001
From: Chaitanya Sri Krishna Lolla
Date: Mon, 20 Jan 2025 10:30:35 +0530
Subject: [PATCH 023/147] Remove dependency of pynvml on ROCm (#2995)
---
.../device_communicators/custom_all_reduce.py | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index 28aa9d4811e..d4506b9f04c 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -6,7 +6,6 @@
from functools import wraps
from typing import Callable, List, Optional, TypeVar, Union
-import pynvml
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -20,6 +19,14 @@
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
+logger = logging.getLogger(__name__)
+
+if is_cuda():
+ try:
+ import pynvml
+ except ImportError as e:
+ logger.warning("Failed to import pynvml with %r", e)
+
try:
if ops.use_vllm_custom_allreduce:
ops.meta_size()
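For reference, the guarded import above turns a hard `pynvml` dependency into a CUDA-only, best-effort one. A standalone sketch of the same pattern, under the assumption that `torch.version.cuda` can stand in for sglang's `is_cuda()` helper:

```python
# Hedged sketch of the optional-import guard added above. On ROCm (or any
# non-CUDA build) pynvml is simply skipped; on CUDA a missing package only
# produces a warning instead of an ImportError at import time.
import logging

import torch

logger = logging.getLogger(__name__)

pynvml = None
if torch.version.cuda is not None:  # stand-in for sglang.srt.utils.is_cuda()
    try:
        import pynvml
    except ImportError as e:
        logger.warning("Failed to import pynvml with %r", e)
```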
From 44a966977083f3a7d7cc2a268f46a63e76d049a8 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Mon, 20 Jan 2025 13:21:36 +0800
Subject: [PATCH 024/147] keep rotary_embedding only (#2997)
---
python/sglang/srt/layers/rotary_embedding.py | 60 ++++++--------------
1 file changed, 16 insertions(+), 44 deletions(-)
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 964152905be..43478f39d2c 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -144,28 +144,14 @@ def forward_cuda(
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
- # ops.rotary_embedding()/batched_rotary_embedding()
- # are in-place operations that update the query and key tensors.
- if offsets is not None:
- ops.batched_rotary_embedding(
- positions,
- query,
- key,
- self.head_size,
- self.cos_sin_cache,
- self.is_neox_style,
- self.rotary_dim,
- offsets,
- )
- else:
- ops.rotary_embedding(
- positions,
- query,
- key,
- self.head_size,
- self.cos_sin_cache,
- self.is_neox_style,
- )
+ ops.rotary_embedding(
+ positions,
+ query,
+ key,
+ self.head_size,
+ self.cos_sin_cache,
+ self.is_neox_style,
+ )
return query, key
def forward_xpu(
@@ -178,28 +164,14 @@ def forward_xpu(
from vllm._ipex_ops import ipex_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
- # ops.rotary_embedding()/batched_rotary_embedding()
- # are in-place operations that update the query and key tensors.
- if offsets is not None:
- ops.batched_rotary_embedding(
- positions,
- query,
- key,
- self.head_size,
- self.cos_sin_cache,
- self.is_neox_style,
- self.rotary_dim,
- offsets,
- )
- else:
- ops.rotary_embedding(
- positions,
- query,
- key,
- self.head_size,
- self.cos_sin_cache,
- self.is_neox_style,
- )
+ ops.rotary_embedding(
+ positions,
+ query,
+ key,
+ self.head_size,
+ self.cos_sin_cache,
+ self.is_neox_style,
+ )
return query, key
def forward_hpu(
From 03464890e0e0d048ebc1aa407e5235d7338f6aff Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 19 Jan 2025 22:09:24 -0800
Subject: [PATCH 025/147] Separate two entry points: Engine and HTTP server
(#2996)
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
---
docs/references/supported_models.md | 2 +-
python/sglang/api.py | 2 +-
python/sglang/bench_offline_throughput.py | 2 +-
python/sglang/bench_one_batch.py | 2 +-
python/sglang/bench_one_batch_server.py | 2 +-
.../sglang/lang/backend/runtime_endpoint.py | 2 +-
python/sglang/launch_server.py | 2 +-
python/sglang/srt/entrypoints/engine.py | 449 +++++++++
python/sglang/srt/entrypoints/http_server.py | 579 +++++++++++
python/sglang/srt/managers/io_struct.py | 1 -
.../sglang/srt/managers/tokenizer_manager.py | 7 +-
python/sglang/srt/server.py | 949 +-----------------
python/sglang/test/runners.py | 3 +-
.../py_src/sglang_router/launch_server.py | 2 +-
test/srt/test_metrics.py | 1 -
test/srt/test_nightly_gsm8k_eval.py | 6 +-
test/srt/test_nightly_human_eval.py | 4 +-
test/srt/test_srt_engine.py | 148 ++-
18 files changed, 1121 insertions(+), 1042 deletions(-)
create mode 100644 python/sglang/srt/entrypoints/engine.py
create mode 100644 python/sglang/srt/entrypoints/http_server.py
diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md
index 23c98ea9305..60551b2c1da 100644
--- a/docs/references/supported_models.md
+++ b/docs/references/supported_models.md
@@ -91,7 +91,7 @@ Here is how you can do it:
```python
from sglang.srt.models.registry import ModelRegistry
-from sglang.srt.server import launch_server
+from sglang.srt.entrypoints.http_server import launch_server
# for a single model, you can add it to the registry
ModelRegistry.models[model_name] = model_class
diff --git a/python/sglang/api.py b/python/sglang/api.py
index a9c5fa9da99..7ef306380a9 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -40,7 +40,7 @@ def Runtime(*args, **kwargs):
def Engine(*args, **kwargs):
# Avoid importing unnecessary dependency
- from sglang.srt.server import Engine
+ from sglang.srt.entrypoints.engine import Engine
return Engine(*args, **kwargs)
diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
index 6b31ac40e11..b0a715e61cc 100644
--- a/python/sglang/bench_offline_throughput.py
+++ b/python/sglang/bench_offline_throughput.py
@@ -28,7 +28,7 @@
set_ulimit,
)
from sglang.lang.backend.runtime_endpoint import Runtime
-from sglang.srt.server import Engine
+from sglang.srt.entrypoints.engine import Engine
from sglang.srt.server_args import ServerArgs
diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 99fba8be913..473f478ad5c 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -57,12 +57,12 @@
import torch.distributed as dist
from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
-from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers
diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py
index 01cc561e1ce..5f0759a7ce1 100644
--- a/python/sglang/bench_one_batch_server.py
+++ b/python/sglang/bench_one_batch_server.py
@@ -22,7 +22,7 @@
import numpy as np
import requests
-from sglang.srt.server import launch_server
+from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py
index c139db6f04c..01f10b9f063 100644
--- a/python/sglang/lang/backend/runtime_endpoint.py
+++ b/python/sglang/lang/backend/runtime_endpoint.py
@@ -351,7 +351,7 @@ def __init__(
"""See the arguments in server_args.py::ServerArgs"""
# We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run
# client code without installing SRT server and its dependency if they want.
- from sglang.srt.server import launch_server
+ from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_port_available
diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index 6b0c25711c6..caae7b0f6cc 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -3,7 +3,7 @@
import os
import sys
-from sglang.srt.server import launch_server
+from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
new file mode 100644
index 00000000000..310e92c23d9
--- /dev/null
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -0,0 +1,449 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+The entry point of the inference server. (SRT = SGLang Runtime)
+
+This file implements the Python APIs for the inference engine.
+"""
+
+import asyncio
+import atexit
+import dataclasses
+import logging
+import multiprocessing as mp
+import os
+import signal
+import threading
+from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
+
+# Fix a bug in Python threading
+setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
+
+import torch
+import uvloop
+
+from sglang.srt.managers.data_parallel_controller import (
+ run_data_parallel_controller_process,
+)
+from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
+from sglang.srt.managers.io_struct import (
+ EmbeddingReqInput,
+ GenerateReqInput,
+ GetWeightsByNameReqInput,
+ InitWeightsUpdateGroupReqInput,
+ ReleaseMemoryOccupationReqInput,
+ ResumeMemoryOccupationReqInput,
+ UpdateWeightsFromDistributedReqInput,
+ UpdateWeightsFromTensorReqInput,
+)
+from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils import (
+ MultiprocessingSerializer,
+ assert_pkg_version,
+ configure_logger,
+ kill_process_tree,
+ maybe_set_triton_cache_manager,
+ prepare_model_and_tokenizer,
+ set_prometheus_multiproc_dir,
+ set_ulimit,
+)
+from sglang.version import __version__
+
+logger = logging.getLogger(__name__)
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+
+class Engine:
+ """
+ The entry point to the inference engine.
+
+ - The engine consists of three components:
+ 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
+ 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
+ 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+
+ Note:
+ 1. The HTTP server, Engine, and TokenizerManager all run in the main process.
+ 2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
+ """
+
+ def __init__(self, **kwargs):
+ """
+ The arguments of this function are the same as `sglang/srt/server_args.py::ServerArgs`.
+ Please refer to `ServerArgs` for the documentation.
+ """
+ if "server_args" in kwargs:
+ # Directly load server_args
+ server_args = kwargs["server_args"]
+ else:
+ # Construct server_args from kwargs
+ if "log_level" not in kwargs:
+ # Do not print logs by default
+ kwargs["log_level"] = "error"
+ server_args = ServerArgs(**kwargs)
+
+ # Shut down the subprocesses automatically when the program exits
+ atexit.register(self.shutdown)
+
+ # Launch subprocesses
+ tokenizer_manager, scheduler_info = _launch_subprocesses(
+ server_args=server_args
+ )
+ self.tokenizer_manager = tokenizer_manager
+ self.scheduler_info = scheduler_info
+
+ def generate(
+ self,
+ # The input prompt. It can be a single prompt or a batch of prompts.
+ prompt: Optional[Union[List[str], str]] = None,
+ sampling_params: Optional[Union[List[Dict], Dict]] = None,
+ # The token ids for text; one can either specify text or input_ids.
+ input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+ return_logprob: Optional[Union[List[bool], bool]] = False,
+ logprob_start_len: Optional[Union[List[int], int]] = None,
+ top_logprobs_num: Optional[Union[List[int], int]] = None,
+ lora_path: Optional[List[Optional[str]]] = None,
+ custom_logit_processor: Optional[Union[List[str], str]] = None,
+ stream: bool = False,
+ ) -> Union[Dict, Iterator[Dict]]:
+ """
+ The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
+ Please refer to `GenerateReqInput` for the documentation.
+ """
+ obj = GenerateReqInput(
+ text=prompt,
+ input_ids=input_ids,
+ sampling_params=sampling_params,
+ return_logprob=return_logprob,
+ logprob_start_len=logprob_start_len,
+ top_logprobs_num=top_logprobs_num,
+ lora_path=lora_path,
+ custom_logit_processor=custom_logit_processor,
+ stream=stream,
+ )
+ loop = asyncio.get_event_loop()
+ generator = self.tokenizer_manager.generate_request(obj, None)
+
+ if stream:
+
+ def generator_wrapper():
+ while True:
+ try:
+ chunk = loop.run_until_complete(generator.__anext__())
+ yield chunk
+ except StopAsyncIteration:
+ break
+
+ return generator_wrapper()
+ else:
+ ret = loop.run_until_complete(generator.__anext__())
+ return ret
+
+ async def async_generate(
+ self,
+ # The input prompt. It can be a single prompt or a batch of prompts.
+ prompt: Optional[Union[List[str], str]] = None,
+ sampling_params: Optional[Union[List[Dict], Dict]] = None,
+ # The token ids for text; one can either specify text or input_ids.
+ input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+ return_logprob: Optional[Union[List[bool], bool]] = False,
+ logprob_start_len: Optional[Union[List[int], int]] = None,
+ top_logprobs_num: Optional[Union[List[int], int]] = None,
+ lora_path: Optional[List[Optional[str]]] = None,
+ custom_logit_processor: Optional[Union[List[str], str]] = None,
+ stream: bool = False,
+ ) -> Union[Dict, AsyncIterator[Dict]]:
+ """
+ The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
+ Please refer to `GenerateReqInput` for the documentation.
+ """
+ obj = GenerateReqInput(
+ text=prompt,
+ input_ids=input_ids,
+ sampling_params=sampling_params,
+ return_logprob=return_logprob,
+ logprob_start_len=logprob_start_len,
+ top_logprobs_num=top_logprobs_num,
+ lora_path=lora_path,
+ stream=stream,
+ custom_logit_processor=custom_logit_processor,
+ )
+ generator = self.tokenizer_manager.generate_request(obj, None)
+
+        if stream:
+ return generator
+ else:
+ return await generator.__anext__()
+
+ def encode(
+ self,
+ prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
+ ) -> Dict:
+ """
+        The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
+ Please refer to `EmbeddingReqInput` for the documentation.
+ """
+
+ obj = EmbeddingReqInput(text=prompt)
+ loop = asyncio.get_event_loop()
+ generator = self.tokenizer_manager.generate_request(obj, None)
+ ret = loop.run_until_complete(generator.__anext__())
+ return ret
+
+ def shutdown(self):
+        """Shut down the engine."""
+ kill_process_tree(os.getpid(), include_parent=False)
+
+ def start_profile(self):
+ self.tokenizer_manager.start_profile()
+
+ def stop_profile(self):
+ self.tokenizer_manager.stop_profile()
+
+ def get_server_info(self):
+ return {
+ **dataclasses.asdict(self.tokenizer_manager.server_args), # server args
+ **self.scheduler_info,
+ "version": __version__,
+ }
+
+ def init_weights_update_group(
+ self,
+ master_address: str,
+ master_port: int,
+ rank_offset: int,
+ world_size: int,
+ group_name: str,
+ backend: str = "nccl",
+ ):
+ """Initialize parameter update group."""
+ obj = InitWeightsUpdateGroupReqInput(
+ master_address=master_address,
+ master_port=master_port,
+ rank_offset=rank_offset,
+ world_size=world_size,
+ group_name=group_name,
+ backend=backend,
+ )
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.init_weights_update_group(obj, None)
+ )
+
+ def update_weights_from_distributed(self, name: str, dtype, shape):
+ """Update weights from distributed source."""
+ obj = UpdateWeightsFromDistributedReqInput(
+ name=name,
+ dtype=dtype,
+ shape=shape,
+ )
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.update_weights_from_distributed(obj, None)
+ )
+
+ def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
+ """Update weights from distributed source."""
+ obj = UpdateWeightsFromTensorReqInput(
+ serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
+ )
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.update_weights_from_tensor(obj, None)
+ )
+
+ def get_weights_by_name(self, name: str, truncate_size: int = 100):
+ """Get weights by parameter name."""
+ obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.get_weights_by_name(obj, None)
+ )
+
+ def release_memory_occupation(self):
+ """Release GPU occupation temporarily."""
+ obj = ReleaseMemoryOccupationReqInput()
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.release_memory_occupation(obj, None)
+ )
+
+ def resume_memory_occupation(self):
+ """Resume GPU occupation."""
+ obj = ResumeMemoryOccupationReqInput()
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.resume_memory_occupation(obj, None)
+ )
+
+
+def _set_envs_and_config(server_args: ServerArgs):
+ # Set global environments
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+ os.environ["NCCL_CUMEM_ENABLE"] = "0"
+ os.environ["NCCL_NVLS_ENABLE"] = "0"
+ os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+ os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
+
+ # Set prometheus env vars
+ if server_args.enable_metrics:
+ set_prometheus_multiproc_dir()
+
+ # Set ulimit
+ set_ulimit()
+
+ # Fix triton bugs
+ if server_args.tp_size * server_args.dp_size > 1:
+ # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+ maybe_set_triton_cache_manager()
+
+ # Check flashinfer version
+ if server_args.attention_backend == "flashinfer":
+ assert_pkg_version(
+ "flashinfer",
+ "0.1.6",
+ "Please uninstall the old version and "
+ "reinstall the latest version by following the instructions "
+ "at https://docs.flashinfer.ai/installation.html.",
+ )
+
+ # Register the signal handler.
+    # The child processes will send SIGQUIT to this process when any error happens.
+    # This process then cleans up the whole process tree.
+ def sigquit_handler(signum, frame):
+ logger.error(
+            "Received SIGQUIT from a child process. It usually means the child failed."
+ )
+ kill_process_tree(os.getpid())
+
+ signal.signal(signal.SIGQUIT, sigquit_handler)
+
+ # Set mp start method
+ mp.set_start_method("spawn", force=True)
+
+
+def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
+ """
+ Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
+ """
+ # Configure global environment
+ configure_logger(server_args)
+ server_args.check_server_args()
+ _set_envs_and_config(server_args)
+
+ # Allocate ports for inter-process communications
+ port_args = PortArgs.init_new(server_args)
+ logger.info(f"{server_args=}")
+
+    # If using a model from www.modelscope.cn, download the model first.
+ server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
+ server_args.model_path, server_args.tokenizer_path
+ )
+
+ scheduler_procs = []
+ if server_args.dp_size == 1:
+ # Launch tensor parallel scheduler processes
+ memory_saver_adapter = TorchMemorySaverAdapter.create(
+ enable=server_args.enable_memory_saver
+ )
+
+ scheduler_pipe_readers = []
+ tp_size_per_node = server_args.tp_size // server_args.nnodes
+ tp_rank_range = range(
+ tp_size_per_node * server_args.node_rank,
+ tp_size_per_node * (server_args.node_rank + 1),
+ )
+ for tp_rank in tp_rank_range:
+ reader, writer = mp.Pipe(duplex=False)
+ gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
+ proc = mp.Process(
+ target=run_scheduler_process,
+ args=(server_args, port_args, gpu_id, tp_rank, None, writer),
+ )
+ with memory_saver_adapter.configure_subprocess():
+ proc.start()
+ scheduler_procs.append(proc)
+ scheduler_pipe_readers.append(reader)
+ else:
+ # Launch the data parallel controller
+ reader, writer = mp.Pipe(duplex=False)
+ scheduler_pipe_readers = [reader]
+ proc = mp.Process(
+ target=run_data_parallel_controller_process,
+ args=(server_args, port_args, writer),
+ )
+ proc.start()
+ scheduler_procs.append(proc)
+
+ if server_args.node_rank >= 1:
+ # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
+ # so they can just wait here.
+
+ for reader in scheduler_pipe_readers:
+ data = reader.recv()
+ assert data["status"] == "ready"
+
+ if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
+ # When using `Engine` as a Python API, we don't want to block here.
+ return
+
+ for proc in scheduler_procs:
+ proc.join()
+ logger.error(
+ f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
+ )
+ return
+
+ # Launch detokenizer process
+ detoken_proc = mp.Process(
+ target=run_detokenizer_process,
+ args=(
+ server_args,
+ port_args,
+ ),
+ )
+ detoken_proc.start()
+
+    # Create the TokenizerManager in the main process
+ tokenizer_manager = TokenizerManager(server_args, port_args)
+ if server_args.chat_template:
+ load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
+
+ # Wait for the model to finish loading
+ scheduler_infos = []
+ for i in range(len(scheduler_pipe_readers)):
+ try:
+ data = scheduler_pipe_readers[i].recv()
+ except EOFError:
+ logger.error(
+ f"Rank {i} scheduler is dead. Please check if there are relevant logs."
+ )
+ scheduler_procs[i].join()
+ logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
+ raise
+
+ if data["status"] != "ready":
+ raise RuntimeError(
+ "Initialization failed. Please see the error messages above."
+ )
+ scheduler_infos.append(data)
+
+ # Assume all schedulers have the same scheduler_info
+ scheduler_info = scheduler_infos[0]
+ tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
+ return tokenizer_manager, scheduler_info
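
For orientation, here is a minimal usage sketch of the `Engine` API introduced above. It is illustrative only and not part of the patch: the model path is a placeholder, and the streaming chunks are assumed to be dicts carrying a "text" field.

    # Hypothetical example, not part of this patch.
    import sglang as sgl

    if __name__ == "__main__":
        # ServerArgs fields are forwarded through **kwargs; model_path is a placeholder.
        engine = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct")

        # Non-streaming: returns a dict containing the generated text.
        out = engine.generate(
            prompt="The capital city of France is",
            sampling_params={"temperature": 0, "max_new_tokens": 8},
        )
        print(out["text"])

        # Streaming: returns a plain iterator of chunk dicts.
        for chunk in engine.generate(prompt="Hello,", stream=True):
            print(chunk.get("text", ""), end="", flush=True)

        engine.shutdown()
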
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
new file mode 100644
index 00000000000..0ebce1a85d5
--- /dev/null
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -0,0 +1,579 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+The entry point of the inference server (SRT = SGLang Runtime).
+
+This file implements HTTP APIs for the inference engine via FastAPI.
+"""
+
+import asyncio
+import dataclasses
+import logging
+import multiprocessing as multiprocessing
+import os
+import threading
+import time
+from http import HTTPStatus
+from typing import AsyncIterator, Dict, Optional
+
+# Work around a bug in Python threading
+setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
+
+import orjson
+import requests
+import uvicorn
+import uvloop
+from fastapi import FastAPI, File, Form, Request, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import ORJSONResponse, Response, StreamingResponse
+
+from sglang.srt.entrypoints.engine import _launch_subprocesses
+from sglang.srt.managers.io_struct import (
+ CloseSessionReqInput,
+ ConfigureLoggingReq,
+ EmbeddingReqInput,
+ GenerateReqInput,
+ GetWeightsByNameReqInput,
+ InitWeightsUpdateGroupReqInput,
+ OpenSessionReqInput,
+ ReleaseMemoryOccupationReqInput,
+ ResumeMemoryOccupationReqInput,
+ UpdateWeightFromDiskReqInput,
+ UpdateWeightsFromDistributedReqInput,
+)
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.metrics.func_timer import enable_func_timer
+from sglang.srt.openai_api.adapter import (
+ v1_batches,
+ v1_cancel_batch,
+ v1_chat_completions,
+ v1_completions,
+ v1_delete_file,
+ v1_embeddings,
+ v1_files_create,
+ v1_retrieve_batch,
+ v1_retrieve_file,
+ v1_retrieve_file_content,
+)
+from sglang.srt.openai_api.protocol import ModelCard, ModelList
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import (
+ add_api_key_middleware,
+ add_prometheus_middleware,
+ delete_directory,
+ kill_process_tree,
+ set_uvicorn_logging_configs,
+)
+from sglang.utils import get_exception_traceback
+from sglang.version import __version__
+
+logger = logging.getLogger(__name__)
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+# Fast API
+app = FastAPI()
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+
+# Store global states
+@dataclasses.dataclass
+class _GlobalState:
+ tokenizer_manager: TokenizerManager
+ scheduler_info: Dict
+
+
+_global_state: Optional[_GlobalState] = None
+
+
+def set_global_state(global_state: _GlobalState):
+ global _global_state
+ _global_state = global_state
+
+
+##### Native API endpoints #####
+
+
+@app.get("/health")
+async def health() -> Response:
+ """Check the health of the http server."""
+ return Response(status_code=200)
+
+
+@app.get("/health_generate")
+async def health_generate(request: Request) -> Response:
+ """Check the health of the inference server by generating one token."""
+
+ sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
+
+ if _global_state.tokenizer_manager.is_generation:
+ gri = GenerateReqInput(
+ input_ids=[0], sampling_params=sampling_params, log_metrics=False
+ )
+ else:
+ gri = EmbeddingReqInput(
+ input_ids=[0], sampling_params=sampling_params, log_metrics=False
+ )
+
+ try:
+ async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
+ break
+ return Response(status_code=200)
+ except Exception as e:
+ logger.exception(e)
+ return Response(status_code=503)
+
+
+@app.get("/get_model_info")
+async def get_model_info():
+ """Get the model information."""
+ result = {
+ "model_path": _global_state.tokenizer_manager.model_path,
+ "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
+ "is_generation": _global_state.tokenizer_manager.is_generation,
+ }
+ return result
+
+
+@app.get("/get_server_info")
+async def get_server_info():
+ return {
+ **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
+ **_global_state.scheduler_info,
+ "version": __version__,
+ }
+
+
+# FastAPI implicitly converts the request JSON into the dataclass object
+@app.api_route("/generate", methods=["POST", "PUT"])
+async def generate_request(obj: GenerateReqInput, request: Request):
+ """Handle a generate request."""
+ if obj.stream:
+
+ async def stream_results() -> AsyncIterator[bytes]:
+ try:
+ async for out in _global_state.tokenizer_manager.generate_request(
+ obj, request
+ ):
+ yield b"data: " + orjson.dumps(
+ out, option=orjson.OPT_NON_STR_KEYS
+ ) + b"\n\n"
+ except ValueError as e:
+ out = {"error": {"message": str(e)}}
+ yield b"data: " + orjson.dumps(
+ out, option=orjson.OPT_NON_STR_KEYS
+ ) + b"\n\n"
+ yield b"data: [DONE]\n\n"
+
+ return StreamingResponse(
+ stream_results(),
+ media_type="text/event-stream",
+ background=_global_state.tokenizer_manager.create_abort_task(obj),
+ )
+ else:
+ try:
+ ret = await _global_state.tokenizer_manager.generate_request(
+ obj, request
+ ).__anext__()
+ return ret
+ except ValueError as e:
+ logger.error(f"Error: {e}")
+ return _create_error_response(e)
+
+
+@app.api_route("/encode", methods=["POST", "PUT"])
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+ """Handle an embedding request."""
+ try:
+ ret = await _global_state.tokenizer_manager.generate_request(
+ obj, request
+ ).__anext__()
+ return ret
+ except ValueError as e:
+ return _create_error_response(e)
+
+
+@app.api_route("/classify", methods=["POST", "PUT"])
+async def classify_request(obj: EmbeddingReqInput, request: Request):
+ """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
+ try:
+ ret = await _global_state.tokenizer_manager.generate_request(
+ obj, request
+ ).__anext__()
+ return ret
+ except ValueError as e:
+ return _create_error_response(e)
+
+
+@app.post("/flush_cache")
+async def flush_cache():
+ """Flush the radix cache."""
+ _global_state.tokenizer_manager.flush_cache()
+ return Response(
+ content="Cache flushed.\nPlease check backend logs for more details. "
+ "(When there are running or waiting requests, the operation will not be performed.)\n",
+ status_code=200,
+ )
+
+
+@app.api_route("/start_profile", methods=["GET", "POST"])
+async def start_profile_async():
+ """Start profiling."""
+ _global_state.tokenizer_manager.start_profile()
+ return Response(
+ content="Start profiling.\n",
+ status_code=200,
+ )
+
+
+@app.api_route("/stop_profile", methods=["GET", "POST"])
+async def stop_profile_async():
+ """Stop profiling."""
+ _global_state.tokenizer_manager.stop_profile()
+ return Response(
+ content="Stop profiling. This will take some time.\n",
+ status_code=200,
+ )
+
+
+@app.post("/update_weights_from_disk")
+async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
+ """Update the weights from disk in-place without re-launching the server."""
+ success, message = await _global_state.tokenizer_manager.update_weights_from_disk(
+ obj, request
+ )
+ content = {"success": success, "message": message}
+ if success:
+ return ORJSONResponse(
+ content,
+ status_code=HTTPStatus.OK,
+ )
+ else:
+ return ORJSONResponse(
+ content,
+ status_code=HTTPStatus.BAD_REQUEST,
+ )
+
+
+@app.post("/init_weights_update_group")
+async def init_weights_update_group(
+ obj: InitWeightsUpdateGroupReqInput, request: Request
+):
+ """Initialize the parameter update group."""
+ success, message = await _global_state.tokenizer_manager.init_weights_update_group(
+ obj, request
+ )
+ content = {"success": success, "message": message}
+ if success:
+ return ORJSONResponse(content, status_code=200)
+ else:
+ return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.post("/update_weights_from_distributed")
+async def update_weights_from_distributed(
+ obj: UpdateWeightsFromDistributedReqInput, request: Request
+):
+    """Update model parameters online from a distributed source."""
+ success, message = (
+ await _global_state.tokenizer_manager.update_weights_from_distributed(
+ obj, request
+ )
+ )
+ content = {"success": success, "message": message}
+ if success:
+ return ORJSONResponse(content, status_code=200)
+ else:
+ return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
+
+
+@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
+async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
+ """Get model parameter by name."""
+ try:
+ ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request)
+ if ret is None:
+ return _create_error_response("Get parameter by name failed")
+ else:
+ return ORJSONResponse(ret, status_code=200)
+ except Exception as e:
+ return _create_error_response(e)
+
+
+@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
+async def release_memory_occupation(
+ obj: ReleaseMemoryOccupationReqInput, request: Request
+):
+ """Release GPU occupation temporarily"""
+ try:
+ await _global_state.tokenizer_manager.release_memory_occupation(obj, request)
+ except Exception as e:
+ return _create_error_response(e)
+
+
+@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
+async def resume_memory_occupation(
+ obj: ResumeMemoryOccupationReqInput, request: Request
+):
+ """Resume GPU occupation"""
+ try:
+ await _global_state.tokenizer_manager.resume_memory_occupation(obj, request)
+ except Exception as e:
+ return _create_error_response(e)
+
+
+@app.api_route("/open_session", methods=["GET", "POST"])
+async def open_session(obj: OpenSessionReqInput, request: Request):
+ """Open a session, and return its unique session id."""
+ try:
+ session_id = await _global_state.tokenizer_manager.open_session(obj, request)
+ if session_id is None:
+ raise Exception(
+ "Failed to open the session. Check if a session with the same id is still open."
+ )
+ return session_id
+ except Exception as e:
+ return _create_error_response(e)
+
+
+@app.api_route("/close_session", methods=["GET", "POST"])
+async def close_session(obj: CloseSessionReqInput, request: Request):
+ """Close the session"""
+ try:
+ await _global_state.tokenizer_manager.close_session(obj, request)
+ return Response(status_code=200)
+ except Exception as e:
+ return _create_error_response(e)
+
+
+@app.api_route("/configure_logging", methods=["GET", "POST"])
+async def configure_logging(obj: ConfigureLoggingReq, request: Request):
+    """Configure the request logging options."""
+ _global_state.tokenizer_manager.configure_logging(obj)
+ return Response(status_code=200)
+
+
+##### OpenAI-compatible API endpoints #####
+
+
+@app.post("/v1/completions")
+async def openai_v1_completions(raw_request: Request):
+ return await v1_completions(_global_state.tokenizer_manager, raw_request)
+
+
+@app.post("/v1/chat/completions")
+async def openai_v1_chat_completions(raw_request: Request):
+ return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
+
+
+@app.post("/v1/embeddings", response_class=ORJSONResponse)
+async def openai_v1_embeddings(raw_request: Request):
+ response = await v1_embeddings(_global_state.tokenizer_manager, raw_request)
+ return response
+
+
+@app.get("/v1/models", response_class=ORJSONResponse)
+def available_models():
+ """Show available models."""
+ served_model_names = [_global_state.tokenizer_manager.served_model_name]
+ model_cards = []
+ for served_model_name in served_model_names:
+ model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+ return ModelList(data=model_cards)
+
+
+@app.post("/v1/files")
+async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
+ return await v1_files_create(
+ file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth
+ )
+
+
+@app.delete("/v1/files/{file_id}")
+async def delete_file(file_id: str):
+ # https://platform.openai.com/docs/api-reference/files/delete
+ return await v1_delete_file(file_id)
+
+
+@app.post("/v1/batches")
+async def openai_v1_batches(raw_request: Request):
+ return await v1_batches(_global_state.tokenizer_manager, raw_request)
+
+
+@app.post("/v1/batches/{batch_id}/cancel")
+async def cancel_batches(batch_id: str):
+ # https://platform.openai.com/docs/api-reference/batch/cancel
+ return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id)
+
+
+@app.get("/v1/batches/{batch_id}")
+async def retrieve_batch(batch_id: str):
+ return await v1_retrieve_batch(batch_id)
+
+
+@app.get("/v1/files/{file_id}")
+async def retrieve_file(file_id: str):
+ # https://platform.openai.com/docs/api-reference/files/retrieve
+ return await v1_retrieve_file(file_id)
+
+
+@app.get("/v1/files/{file_id}/content")
+async def retrieve_file_content(file_id: str):
+ # https://platform.openai.com/docs/api-reference/files/retrieve-contents
+ return await v1_retrieve_file_content(file_id)
+
+
+def _create_error_response(e):
+ return ORJSONResponse(
+ {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+ )
+
+
+def launch_server(
+ server_args: ServerArgs,
+ pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
+):
+ """
+ Launch SRT (SGLang Runtime) Server.
+
+ The SRT server consists of an HTTP server and an SRT engine.
+
+ - HTTP server: A FastAPI server that routes requests to the engine.
+ - The engine consists of three components:
+ 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
+ 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
+ 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+
+ Note:
+    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
+    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
+ """
+ tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+ set_global_state(
+ _GlobalState(
+ tokenizer_manager=tokenizer_manager,
+ scheduler_info=scheduler_info,
+ )
+ )
+
+ # Add api key authorization
+ if server_args.api_key:
+ add_api_key_middleware(app, server_args.api_key)
+
+ # Add prometheus middleware
+ if server_args.enable_metrics:
+ add_prometheus_middleware(app)
+ enable_func_timer()
+
+ # Send a warmup request
+ t = threading.Thread(
+ target=_wait_and_warmup,
+ args=(
+ server_args,
+ pipe_finish_writer,
+ _global_state.tokenizer_manager.image_token_id,
+ ),
+ )
+ t.start()
+
+ try:
+ # Update logging configs
+ set_uvicorn_logging_configs()
+
+ # Listen for HTTP requests
+ uvicorn.run(
+ app,
+ host=server_args.host,
+ port=server_args.port,
+ log_level=server_args.log_level_http or server_args.log_level,
+ timeout_keep_alive=5,
+ loop="uvloop",
+ )
+ finally:
+ t.join()
+
+
+def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
+ headers = {}
+ url = server_args.url()
+ if server_args.api_key:
+ headers["Authorization"] = f"Bearer {server_args.api_key}"
+
+ # Wait until the server is launched
+ success = False
+ for _ in range(120):
+ time.sleep(1)
+ try:
+ res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
+ assert res.status_code == 200, f"{res=}, {res.text=}"
+ success = True
+ break
+ except (AssertionError, requests.exceptions.RequestException):
+ last_traceback = get_exception_traceback()
+ pass
+
+ if not success:
+ if pipe_finish_writer is not None:
+ pipe_finish_writer.send(last_traceback)
+ logger.error(f"Initialization failed. warmup error: {last_traceback}")
+ kill_process_tree(os.getpid())
+ return
+
+ model_info = res.json()
+
+ # Send a warmup request
+ request_name = "/generate" if model_info["is_generation"] else "/encode"
+ max_new_tokens = 8 if model_info["is_generation"] else 1
+ json_data = {
+ "sampling_params": {
+ "temperature": 0,
+ "max_new_tokens": max_new_tokens,
+ },
+ }
+ if server_args.skip_tokenizer_init:
+ json_data["input_ids"] = [10, 11, 12]
+ else:
+ json_data["text"] = "The capital city of France is"
+
+ try:
+ for _ in range(server_args.dp_size):
+ res = requests.post(
+ url + request_name,
+ json=json_data,
+ headers=headers,
+ timeout=600,
+ )
+ assert res.status_code == 200, f"{res}"
+ except Exception:
+ last_traceback = get_exception_traceback()
+ if pipe_finish_writer is not None:
+ pipe_finish_writer.send(last_traceback)
+ logger.error(f"Initialization failed. warmup error: {last_traceback}")
+ kill_process_tree(os.getpid())
+ return
+
+ # Debug print
+ # logger.info(f"{res.json()=}")
+
+ logger.info("The server is fired up and ready to roll!")
+ if pipe_finish_writer is not None:
+ pipe_finish_writer.send("ready")
+
+ if server_args.delete_ckpt_after_loading:
+ delete_directory(server_args.model_path)
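
As a companion to the endpoints defined in this new file, a small client-side sketch follows. It is illustrative only: the base URL is a placeholder and it assumes a server is already listening (for example one started with `python -m sglang.launch_server`).

    # Hypothetical client example, not part of this patch.
    import requests

    base_url = "http://127.0.0.1:30000"  # placeholder host/port

    # Native, non-streaming generation via /generate.
    resp = requests.post(
        base_url + "/generate",
        json={
            "text": "The capital city of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        },
        timeout=60,
    )
    print(resp.json())

    # /get_server_info merges ServerArgs, scheduler_info, and the package version.
    print(requests.get(base_url + "/get_server_info", timeout=10).json()["version"])
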
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5a803dd997a..9183239838d 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -22,7 +22,6 @@
from typing import Dict, List, Optional, Union
from sglang.srt.managers.schedule_batch import BaseFinishReason
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_params import SamplingParams
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index d6178a959d0..162f10624f9 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -176,7 +176,7 @@ def __init__(
)
# Store states
- self.to_create_loop = True
+ self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {}
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000
@@ -684,7 +684,6 @@ async def open_session(
async def close_session(
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
- assert not self.to_create_loop, "close session should not be the first request"
await self.send_to_scheduler.send_pyobj(obj)
def configure_logging(self, obj: ConfigureLoggingReq):
@@ -713,10 +712,10 @@ async def abort_request():
return background_tasks
def auto_create_handle_loop(self):
- if not self.to_create_loop:
+ if self.no_create_loop:
return
- self.to_create_loop = False
+ self.no_create_loop = True
loop = asyncio.get_event_loop()
self.asyncio_tasks.add(
loop.create_task(print_exception_wrapper(self.handle_loop))
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 0b4d9c37218..8b0c5618622 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -11,949 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""
-The entry point of inference server.
-SRT = SGLang Runtime.
-"""
-import asyncio
-import atexit
-import dataclasses
-import json
-import logging
-import multiprocessing as mp
-import os
-import signal
-import threading
-import time
-from http import HTTPStatus
-from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
-
-import torch
-
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
-
-# Fix a bug of Python threading
-setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
-
-import aiohttp
-import orjson
-import requests
-import uvicorn
-import uvloop
-from fastapi import FastAPI, File, Form, Request, UploadFile
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import ORJSONResponse, Response, StreamingResponse
-
-from sglang.srt.managers.data_parallel_controller import (
- run_data_parallel_controller_process,
-)
-from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
-from sglang.srt.managers.io_struct import (
- CloseSessionReqInput,
- ConfigureLoggingReq,
- EmbeddingReqInput,
- GenerateReqInput,
- GetWeightsByNameReqInput,
- InitWeightsUpdateGroupReqInput,
- OpenSessionReqInput,
- ReleaseMemoryOccupationReqInput,
- ResumeMemoryOccupationReqInput,
- UpdateWeightFromDiskReqInput,
- UpdateWeightsFromDistributedReqInput,
- UpdateWeightsFromTensorReqInput,
-)
-from sglang.srt.managers.scheduler import run_scheduler_process
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency
-from sglang.srt.openai_api.adapter import (
- load_chat_template_for_openai_api,
- v1_batches,
- v1_cancel_batch,
- v1_chat_completions,
- v1_completions,
- v1_delete_file,
- v1_embeddings,
- v1_files_create,
- v1_retrieve_batch,
- v1_retrieve_file,
- v1_retrieve_file_content,
-)
-from sglang.srt.openai_api.protocol import ModelCard, ModelList
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
- MultiprocessingSerializer,
- add_api_key_middleware,
- add_prometheus_middleware,
- assert_pkg_version,
- configure_logger,
- delete_directory,
- kill_process_tree,
- maybe_set_triton_cache_manager,
- prepare_model_and_tokenizer,
- set_prometheus_multiproc_dir,
- set_ulimit,
- set_uvicorn_logging_configs,
-)
-from sglang.utils import get_exception_traceback
-from sglang.version import __version__
-
-logger = logging.getLogger(__name__)
-
-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
-# Fast API
-app = FastAPI()
-app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"],
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
-)
-
-tokenizer_manager: TokenizerManager = None
-scheduler_info: Dict = None
-
-
-##### Native API endpoints #####
-
-
-@app.get("/health")
-async def health() -> Response:
- """Check the health of the http server."""
- return Response(status_code=200)
-
-
-@app.get("/health_generate")
-async def health_generate(request: Request) -> Response:
- """Check the health of the inference server by generating one token."""
-
- sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
-
- if tokenizer_manager.is_generation:
- gri = GenerateReqInput(
- input_ids=[0], sampling_params=sampling_params, log_metrics=False
- )
- else:
- gri = EmbeddingReqInput(
- input_ids=[0], sampling_params=sampling_params, log_metrics=False
- )
-
- try:
- async for _ in tokenizer_manager.generate_request(gri, request):
- break
- return Response(status_code=200)
- except Exception as e:
- logger.exception(e)
- return Response(status_code=503)
-
-
-@app.get("/get_model_info")
-async def get_model_info():
- """Get the model information."""
- result = {
- "model_path": tokenizer_manager.model_path,
- "tokenizer_path": tokenizer_manager.server_args.tokenizer_path,
- "is_generation": tokenizer_manager.is_generation,
- }
- return result
-
-
-@app.get("/get_server_info")
-async def get_server_info():
- return {
- **dataclasses.asdict(tokenizer_manager.server_args),
- **scheduler_info,
- "version": __version__,
- }
-
-
-# fastapi implicitly converts json in the request to obj (dataclass)
-@app.api_route("/generate", methods=["POST", "PUT"])
-@time_func_latency
-async def generate_request(obj: GenerateReqInput, request: Request):
- """Handle a generate request."""
- if obj.stream:
-
- async def stream_results() -> AsyncIterator[bytes]:
- try:
- async for out in tokenizer_manager.generate_request(obj, request):
- yield b"data: " + orjson.dumps(
- out, option=orjson.OPT_NON_STR_KEYS
- ) + b"\n\n"
- except ValueError as e:
- out = {"error": {"message": str(e)}}
- yield b"data: " + orjson.dumps(
- out, option=orjson.OPT_NON_STR_KEYS
- ) + b"\n\n"
- yield b"data: [DONE]\n\n"
-
- return StreamingResponse(
- stream_results(),
- media_type="text/event-stream",
- background=tokenizer_manager.create_abort_task(obj),
- )
- else:
- try:
- ret = await tokenizer_manager.generate_request(obj, request).__anext__()
- return ret
- except ValueError as e:
- logger.error(f"Error: {e}")
- return _create_error_response(e)
-
-
-@app.api_route("/encode", methods=["POST", "PUT"])
-@time_func_latency
-async def encode_request(obj: EmbeddingReqInput, request: Request):
- """Handle an embedding request."""
- try:
- ret = await tokenizer_manager.generate_request(obj, request).__anext__()
- return ret
- except ValueError as e:
- return _create_error_response(e)
-
-
-@app.api_route("/classify", methods=["POST", "PUT"])
-@time_func_latency
-async def classify_request(obj: EmbeddingReqInput, request: Request):
- """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
- try:
- ret = await tokenizer_manager.generate_request(obj, request).__anext__()
- return ret
- except ValueError as e:
- return _create_error_response(e)
-
-
-@app.post("/flush_cache")
-async def flush_cache():
- """Flush the radix cache."""
- tokenizer_manager.flush_cache()
- return Response(
- content="Cache flushed.\nPlease check backend logs for more details. "
- "(When there are running or waiting requests, the operation will not be performed.)\n",
- status_code=200,
- )
-
-
-@app.api_route("/start_profile", methods=["GET", "POST"])
-async def start_profile_async():
- """Start profiling."""
- tokenizer_manager.start_profile()
- return Response(
- content="Start profiling.\n",
- status_code=200,
- )
-
-
-@app.api_route("/stop_profile", methods=["GET", "POST"])
-async def stop_profile_async():
- """Stop profiling."""
- tokenizer_manager.stop_profile()
- return Response(
- content="Stop profiling. This will take some time.\n",
- status_code=200,
- )
-
-
-@app.post("/update_weights_from_disk")
-@time_func_latency
-async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
- """Update the weights from disk in-place without re-launching the server."""
- success, message = await tokenizer_manager.update_weights_from_disk(obj, request)
- content = {"success": success, "message": message}
- if success:
- return ORJSONResponse(
- content,
- status_code=HTTPStatus.OK,
- )
- else:
- return ORJSONResponse(
- content,
- status_code=HTTPStatus.BAD_REQUEST,
- )
-
-
-@app.post("/init_weights_update_group")
-async def init_weights_update_group(
- obj: InitWeightsUpdateGroupReqInput, request: Request
-):
- """Initialize the parameter update group."""
- success, message = await tokenizer_manager.init_weights_update_group(obj, request)
- content = {"success": success, "message": message}
- if success:
- return ORJSONResponse(content, status_code=200)
- else:
- return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
-
-
-@app.post("/update_weights_from_distributed")
-async def update_weights_from_distributed(
- obj: UpdateWeightsFromDistributedReqInput, request: Request
-):
- """Update model parameter from distributed online."""
- success, message = await tokenizer_manager.update_weights_from_distributed(
- obj, request
- )
- content = {"success": success, "message": message}
- if success:
- return ORJSONResponse(content, status_code=200)
- else:
- return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
-
-
-@app.api_route("/get_weights_by_name", methods=["GET", "POST"])
-async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
- """Get model parameter by name."""
- try:
- ret = await tokenizer_manager.get_weights_by_name(obj, request)
- if ret is None:
- return _create_error_response("Get parameter by name failed")
- else:
- return ORJSONResponse(ret, status_code=200)
- except Exception as e:
- return _create_error_response(e)
-
-
-@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
-async def release_memory_occupation(
- obj: ReleaseMemoryOccupationReqInput, request: Request
-):
- """Release GPU occupation temporarily"""
- try:
- await tokenizer_manager.release_memory_occupation(obj, request)
- except Exception as e:
- return _create_error_response(e)
-
-
-@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
-async def resume_memory_occupation(
- obj: ResumeMemoryOccupationReqInput, request: Request
-):
- """Resume GPU occupation"""
- try:
- await tokenizer_manager.resume_memory_occupation(obj, request)
- except Exception as e:
- return _create_error_response(e)
-
-
-@app.api_route("/open_session", methods=["GET", "POST"])
-async def open_session(obj: OpenSessionReqInput, request: Request):
- """Open a session, and return its unique session id."""
- try:
- session_id = await tokenizer_manager.open_session(obj, request)
- if session_id is None:
- raise Exception(
- "Failed to open the session. Check if a session with the same id is still open."
- )
- return session_id
- except Exception as e:
- return _create_error_response(e)
-
-
-@app.api_route("/close_session", methods=["GET", "POST"])
-async def close_session(obj: CloseSessionReqInput, request: Request):
- """Close the session"""
- try:
- await tokenizer_manager.close_session(obj, request)
- return Response(status_code=200)
- except Exception as e:
- return _create_error_response(e)
-
-
-@app.api_route("/configure_logging", methods=["GET", "POST"])
-async def configure_logging(obj: ConfigureLoggingReq, request: Request):
- """Close the session"""
- tokenizer_manager.configure_logging(obj)
- return Response(status_code=200)
-
-
-##### OpenAI-compatible API endpoints #####
-
-
-@app.post("/v1/completions")
-@time_func_latency
-async def openai_v1_completions(raw_request: Request):
- return await v1_completions(tokenizer_manager, raw_request)
-
-
-@app.post("/v1/chat/completions")
-@time_func_latency
-async def openai_v1_chat_completions(raw_request: Request):
- return await v1_chat_completions(tokenizer_manager, raw_request)
-
-
-@app.post("/v1/embeddings", response_class=ORJSONResponse)
-@time_func_latency
-async def openai_v1_embeddings(raw_request: Request):
- response = await v1_embeddings(tokenizer_manager, raw_request)
- return response
-
-
-@app.get("/v1/models", response_class=ORJSONResponse)
-def available_models():
- """Show available models."""
- served_model_names = [tokenizer_manager.served_model_name]
- model_cards = []
- for served_model_name in served_model_names:
- model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
- return ModelList(data=model_cards)
-
-
-@app.post("/v1/files")
-async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
- return await v1_files_create(
- file, purpose, tokenizer_manager.server_args.file_storage_pth
- )
-
-
-@app.delete("/v1/files/{file_id}")
-async def delete_file(file_id: str):
- # https://platform.openai.com/docs/api-reference/files/delete
- return await v1_delete_file(file_id)
-
-
-@app.post("/v1/batches")
-async def openai_v1_batches(raw_request: Request):
- return await v1_batches(tokenizer_manager, raw_request)
-
-
-@app.post("/v1/batches/{batch_id}/cancel")
-async def cancel_batches(batch_id: str):
- # https://platform.openai.com/docs/api-reference/batch/cancel
- return await v1_cancel_batch(tokenizer_manager, batch_id)
-
-
-@app.get("/v1/batches/{batch_id}")
-async def retrieve_batch(batch_id: str):
- return await v1_retrieve_batch(batch_id)
-
-
-@app.get("/v1/files/{file_id}")
-async def retrieve_file(file_id: str):
- # https://platform.openai.com/docs/api-reference/files/retrieve
- return await v1_retrieve_file(file_id)
-
-
-@app.get("/v1/files/{file_id}/content")
-async def retrieve_file_content(file_id: str):
- # https://platform.openai.com/docs/api-reference/files/retrieve-contents
- return await v1_retrieve_file_content(file_id)
-
-
-def _create_error_response(e):
- return ORJSONResponse(
- {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
- )
-
-
-def launch_engine(
- server_args: ServerArgs,
-):
- """
- Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
- """
-
- global tokenizer_manager
- global scheduler_info
-
- # Configure global environment
- configure_logger(server_args)
- server_args.check_server_args()
- _set_envs_and_config(server_args)
-
- # Allocate ports for inter-process communications
- port_args = PortArgs.init_new(server_args)
- logger.info(f"{server_args=}")
-
- # If using model from www.modelscope.cn, first download the model.
- server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
- server_args.model_path, server_args.tokenizer_path
- )
-
- scheduler_procs = []
- if server_args.dp_size == 1:
- # Launch tensor parallel scheduler processes
- memory_saver_adapter = TorchMemorySaverAdapter.create(
- enable=server_args.enable_memory_saver
- )
-
- scheduler_pipe_readers = []
- tp_size_per_node = server_args.tp_size // server_args.nnodes
- tp_rank_range = range(
- tp_size_per_node * server_args.node_rank,
- tp_size_per_node * (server_args.node_rank + 1),
- )
- for tp_rank in tp_rank_range:
- reader, writer = mp.Pipe(duplex=False)
- gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
- proc = mp.Process(
- target=run_scheduler_process,
- args=(server_args, port_args, gpu_id, tp_rank, None, writer),
- )
- with memory_saver_adapter.configure_subprocess():
- proc.start()
- scheduler_procs.append(proc)
- scheduler_pipe_readers.append(reader)
- else:
- # Launch the data parallel controller
- reader, writer = mp.Pipe(duplex=False)
- scheduler_pipe_readers = [reader]
- proc = mp.Process(
- target=run_data_parallel_controller_process,
- args=(server_args, port_args, writer),
- )
- proc.start()
- scheduler_procs.append(proc)
-
- if server_args.node_rank >= 1:
- # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer,
- # so they can just wait here.
-
- for reader in scheduler_pipe_readers:
- data = reader.recv()
- assert data["status"] == "ready"
-
- if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
- # When using `Engine` as a Python API, we don't want to block here.
- return
-
- for proc in scheduler_procs:
- proc.join()
- logger.error(
- f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
- )
- return
-
- # Launch detokenizer process
- detoken_proc = mp.Process(
- target=run_detokenizer_process,
- args=(
- server_args,
- port_args,
- ),
- )
- detoken_proc.start()
-
- # Launch tokenizer process
- tokenizer_manager = TokenizerManager(server_args, port_args)
- if server_args.chat_template:
- load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
-
- # Wait for model to finish loading
- scheduler_infos = []
- for i in range(len(scheduler_pipe_readers)):
- try:
- data = scheduler_pipe_readers[i].recv()
- except EOFError as e:
- logger.exception(e)
- logger.error(
- f"Rank {i} scheduler is dead. Please check if there are relevant logs."
- )
- scheduler_procs[i].join()
- logger.error(f"Exit code: {scheduler_procs[i].exitcode}")
- raise
-
- if data["status"] != "ready":
- raise RuntimeError(
- "Initialization failed. Please see the error messages above."
- )
- scheduler_infos.append(data)
-
- # Assume all schedulers have same scheduler_info
- scheduler_info = scheduler_infos[0]
- tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
-
-
-def launch_server(
- server_args: ServerArgs,
- pipe_finish_writer: Optional[mp.connection.Connection] = None,
-):
- """
- Launch SRT (SGLang Runtime) Server
-
- The SRT server consists of an HTTP server and the SRT engine.
-
- 1. HTTP server: A FastAPI server that routes requests to the engine.
- 2. SRT engine:
- 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
- 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
- 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
-
- Note:
- 1. The HTTP server and TokenizerManager both run in the main process.
- 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
- """
- launch_engine(server_args=server_args)
-
- # Add api key authorization
- if server_args.api_key:
- add_api_key_middleware(app, server_args.api_key)
-
- # Add prometheus middleware
- if server_args.enable_metrics:
- add_prometheus_middleware(app)
- enable_func_timer()
-
- # Send a warmup request
- t = threading.Thread(
- target=_wait_and_warmup,
- args=(
- server_args,
- pipe_finish_writer,
- tokenizer_manager.image_token_id,
- ),
- )
- t.start()
-
- try:
- # Update logging configs
- set_uvicorn_logging_configs()
-
- # Listen for HTTP requests
- uvicorn.run(
- app,
- host=server_args.host,
- port=server_args.port,
- log_level=server_args.log_level_http or server_args.log_level,
- timeout_keep_alive=5,
- loop="uvloop",
- )
- finally:
- t.join()
-
-
-def _set_envs_and_config(server_args: ServerArgs):
- # Set global environments
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
- os.environ["NCCL_CUMEM_ENABLE"] = "0"
- os.environ["NCCL_NVLS_ENABLE"] = "0"
- os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
- os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
-
- # Set prometheus env vars
- if server_args.enable_metrics:
- set_prometheus_multiproc_dir()
-
- # Set ulimit
- set_ulimit()
-
- # Fix triton bugs
- if server_args.tp_size * server_args.dp_size > 1:
- # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
- maybe_set_triton_cache_manager()
-
- # Check flashinfer version
- if server_args.attention_backend == "flashinfer":
- assert_pkg_version(
- "flashinfer",
- "0.1.6",
- "Please uninstall the old version and "
- "reinstall the latest version by following the instructions "
- "at https://docs.flashinfer.ai/installation.html.",
- )
-
- # Register the signal handler.
- # The child processes will send SIGQUIT to this process when any error happens
- # This process then clean up the whole process tree
- def sigquit_handler(signum, frame):
- logger.error(
- "Received sigquit from a child proces. It usually means the child failed."
- )
- kill_process_tree(os.getpid())
-
- signal.signal(signal.SIGQUIT, sigquit_handler)
-
- # Set mp start method
- mp.set_start_method("spawn", force=True)
-
-
-def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
- headers = {}
- url = server_args.url()
- if server_args.api_key:
- headers["Authorization"] = f"Bearer {server_args.api_key}"
-
- # Wait until the server is launched
- success = False
- for _ in range(120):
- time.sleep(1)
- try:
- res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
- assert res.status_code == 200, f"{res=}, {res.text=}"
- success = True
- break
- except (AssertionError, requests.exceptions.RequestException):
- last_traceback = get_exception_traceback()
- pass
-
- if not success:
- if pipe_finish_writer is not None:
- pipe_finish_writer.send(last_traceback)
- logger.error(f"Initialization failed. warmup error: {last_traceback}")
- kill_process_tree(os.getpid())
- return
-
- model_info = res.json()
-
- # Send a warmup request
- request_name = "/generate" if model_info["is_generation"] else "/encode"
- max_new_tokens = 8 if model_info["is_generation"] else 1
- json_data = {
- "sampling_params": {
- "temperature": 0,
- "max_new_tokens": max_new_tokens,
- },
- }
- if server_args.skip_tokenizer_init:
- json_data["input_ids"] = [10, 11, 12]
- else:
- json_data["text"] = "The capital city of France is"
-
- try:
- for _ in range(server_args.dp_size):
- res = requests.post(
- url + request_name,
- json=json_data,
- headers=headers,
- timeout=600,
- )
- assert res.status_code == 200, f"{res}"
- except Exception:
- last_traceback = get_exception_traceback()
- if pipe_finish_writer is not None:
- pipe_finish_writer.send(last_traceback)
- logger.error(f"Initialization failed. warmup error: {last_traceback}")
- kill_process_tree(os.getpid())
- return
-
- # Debug print
- # logger.info(f"{res.json()=}")
-
- logger.info("The server is fired up and ready to roll!")
- if pipe_finish_writer is not None:
- pipe_finish_writer.send("ready")
-
- if server_args.delete_ckpt_after_loading:
- delete_directory(server_args.model_path)
-
-
-STREAM_END_SYMBOL = b"data: [DONE]"
-STREAM_CHUNK_START_SYMBOL = b"data:"
-
-
-class Engine:
- """
- SRT Engine without an HTTP server layer.
-
- This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where
- launching the HTTP server adds unnecessary complexity or overhead,
- """
-
- def __init__(self, log_level: str = "error", *args, **kwargs):
- """See the arguments in server_args.py::ServerArgs"""
-
- # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
- atexit.register(self.shutdown)
-
- server_args = ServerArgs(*args, log_level=log_level, **kwargs)
- launch_engine(server_args=server_args)
-
- def generate(
- self,
- # The input prompt. It can be a single prompt or a batch of prompts.
- prompt: Optional[Union[List[str], str]] = None,
- sampling_params: Optional[Union[List[Dict], Dict]] = None,
- # The token ids for text; one can either specify text or input_ids.
- input_ids: Optional[Union[List[List[int]], List[int]]] = None,
- return_logprob: Optional[Union[List[bool], bool]] = False,
- logprob_start_len: Optional[Union[List[int], int]] = None,
- top_logprobs_num: Optional[Union[List[int], int]] = None,
- lora_path: Optional[List[Optional[str]]] = None,
- custom_logit_processor: Optional[Union[List[str], str]] = None,
- stream: bool = False,
- ):
- obj = GenerateReqInput(
- text=prompt,
- input_ids=input_ids,
- sampling_params=sampling_params,
- return_logprob=return_logprob,
- logprob_start_len=logprob_start_len,
- top_logprobs_num=top_logprobs_num,
- lora_path=lora_path,
- stream=stream,
- custom_logit_processor=custom_logit_processor,
- )
-
- # get the current event loop
- loop = asyncio.get_event_loop()
- ret = loop.run_until_complete(generate_request(obj, None))
-
- if stream is True:
-
- def generator_wrapper():
- offset = 0
- loop = asyncio.get_event_loop()
- generator = ret.body_iterator
- while True:
- chunk = loop.run_until_complete(generator.__anext__())
-
- if chunk.startswith(STREAM_END_SYMBOL):
- break
- else:
- data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
- data["text"] = data["text"][offset:]
- offset += len(data["text"])
- yield data
-
- # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
- # however, it allows to wrap the generator as a subfunction and return
- return generator_wrapper()
- else:
- return ret
-
- async def async_generate(
- self,
- # The input prompt. It can be a single prompt or a batch of prompts.
- prompt: Optional[Union[List[str], str]] = None,
- sampling_params: Optional[Dict] = None,
- # The token ids for text; one can either specify text or input_ids.
- input_ids: Optional[Union[List[List[int]], List[int]]] = None,
- return_logprob: Optional[Union[List[bool], bool]] = False,
- logprob_start_len: Optional[Union[List[int], int]] = None,
- top_logprobs_num: Optional[Union[List[int], int]] = None,
- lora_path: Optional[List[Optional[str]]] = None,
- custom_logit_processor: Optional[Union[str, List[str]]] = None,
- stream: bool = False,
- ):
- obj = GenerateReqInput(
- text=prompt,
- input_ids=input_ids,
- sampling_params=sampling_params,
- return_logprob=return_logprob,
- logprob_start_len=logprob_start_len,
- top_logprobs_num=top_logprobs_num,
- lora_path=lora_path,
- stream=stream,
- custom_logit_processor=custom_logit_processor,
- )
-
- ret = await generate_request(obj, None)
-
- if stream is True:
- generator = ret.body_iterator
-
- async def generator_wrapper():
- offset = 0
-
- while True:
- chunk = await generator.__anext__()
-
- if chunk.startswith(STREAM_END_SYMBOL):
- break
- else:
- data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
- data["text"] = data["text"][offset:]
- offset += len(data["text"])
- yield data
-
- return generator_wrapper()
- else:
- return ret
-
- def shutdown(self):
- kill_process_tree(os.getpid(), include_parent=False)
-
- def get_tokenizer(self):
- global tokenizer_manager
-
- if tokenizer_manager is None:
- raise ReferenceError("Tokenizer Manager is not initialized.")
- else:
- return tokenizer_manager.tokenizer
-
- def encode(
- self,
- prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
- ):
- obj = EmbeddingReqInput(text=prompt)
-
- # get the current event loop
- loop = asyncio.get_event_loop()
- return loop.run_until_complete(encode_request(obj, None))
-
- def start_profile(self):
- tokenizer_manager.start_profile()
-
- def stop_profile(self):
- tokenizer_manager.stop_profile()
-
- def get_server_info(self):
- return {
- **dataclasses.asdict(tokenizer_manager.server_args), # server args
- **scheduler_info,
- "version": __version__,
- }
-
- def init_weights_update_group(
- self,
- master_address: str,
- master_port: int,
- rank_offset: int,
- world_size: int,
- group_name: str,
- backend: str = "nccl",
- ):
- """Initialize parameter update group."""
- obj = InitWeightsUpdateGroupReqInput(
- master_address=master_address,
- master_port=master_port,
- rank_offset=rank_offset,
- world_size=world_size,
- group_name=group_name,
- backend=backend,
- )
- loop = asyncio.get_event_loop()
- return loop.run_until_complete(
- tokenizer_manager.init_weights_update_group(obj, None)
- )
-
- def update_weights_from_distributed(self, name, dtype, shape):
- """Update weights from distributed source."""
- obj = UpdateWeightsFromDistributedReqInput(
- name=name,
- dtype=dtype,
- shape=shape,
- )
- loop = asyncio.get_event_loop()
- return loop.run_until_complete(
- tokenizer_manager.update_weights_from_distributed(obj, None)
- )
-
- def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
- """Update weights from distributed source."""
- obj = UpdateWeightsFromTensorReqInput(
- serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
- )
- loop = asyncio.get_event_loop()
- return loop.run_until_complete(
- tokenizer_manager.update_weights_from_tensor(obj, None)
- )
-
- def get_weights_by_name(self, name, truncate_size=100):
- """Get weights by parameter name."""
- obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
- loop = asyncio.get_event_loop()
- return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
-
- def release_memory_occupation(self):
- """Release GPU occupation temporarily"""
- obj = ReleaseMemoryOccupationReqInput()
- loop = asyncio.get_event_loop()
- loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))
-
- def resume_memory_occupation(self):
- """Resume GPU occupation"""
- obj = ResumeMemoryOccupationReqInput()
- loop = asyncio.get_event_loop()
- loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
+# Some shortcuts for backward compatibility.
+# They will be removed in future versions.
+from sglang.srt.entrypoints.engine import Engine
+from sglang.srt.entrypoints.http_server import launch_server
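
Because these shims simply re-export the relocated objects, the old and new import paths resolve to the same classes and functions; a quick illustrative check:

    # Both paths point at the same objects after this patch.
    from sglang.srt.server import Engine as LegacyEngine, launch_server as legacy_launch_server
    from sglang.srt.entrypoints.engine import Engine
    from sglang.srt.entrypoints.http_server import launch_server

    assert LegacyEngine is Engine
    assert legacy_launch_server is launch_server
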
diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py
index fc9a9793715..bae0fcf2a49 100644
--- a/python/sglang/test/runners.py
+++ b/python/sglang/test/runners.py
@@ -12,7 +12,6 @@
# limitations under the License.
# ==============================================================================
-import json
import multiprocessing as mp
import os
from dataclasses import dataclass
@@ -22,8 +21,8 @@
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
+from sglang.srt.entrypoints.engine import Engine
from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server import Engine
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
DEFAULT_PROMPTS = [
diff --git a/sgl-router/py_src/sglang_router/launch_server.py b/sgl-router/py_src/sglang_router/launch_server.py
index 2f433269efa..93bc2345d18 100644
--- a/sgl-router/py_src/sglang_router/launch_server.py
+++ b/sgl-router/py_src/sglang_router/launch_server.py
@@ -13,7 +13,7 @@
from setproctitle import setproctitle
from sglang_router.launch_router import RouterArgs, launch_router
-from sglang.srt.server import launch_server
+from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_port_available
diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py
index 69babf795f0..2837107a1e6 100644
--- a/test/srt/test_metrics.py
+++ b/test/srt/test_metrics.py
@@ -56,7 +56,6 @@ def test_metrics_enabled(self):
"sglang:gen_throughput",
"sglang:num_queue_reqs",
"sglang:cache_hit_rate",
- "sglang:func_latency_seconds",
"sglang:prompt_tokens_total",
"sglang:generation_tokens_total",
"sglang:num_requests_total",
diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py
index 2e379c11179..06c83048f39 100644
--- a/test/srt/test_nightly_gsm8k_eval.py
+++ b/test/srt/test_nightly_gsm8k_eval.py
@@ -45,7 +45,7 @@ def parse_models(model_string):
return [model.strip() for model in model_string.split(",") if model.strip()]
-def launch_server(base_url, model, is_fp8, is_tp2):
+def popen_launch_server_wrapper(base_url, model, is_fp8, is_tp2):
other_args = ["--log-level-http", "warning", "--trust-remote-code"]
if is_fp8:
if "Llama-3" in model or "gemma-2" in model:
@@ -148,7 +148,9 @@ def test_mgsm_en_all_models(self):
for model_group, is_fp8, is_tp2 in self.model_groups:
for model in model_group:
with self.subTest(model=model):
- process = launch_server(self.base_url, model, is_fp8, is_tp2)
+ process = popen_launch_server_wrapper(
+ self.base_url, model, is_fp8, is_tp2
+ )
args = SimpleNamespace(
base_url=self.base_url,
diff --git a/test/srt/test_nightly_human_eval.py b/test/srt/test_nightly_human_eval.py
index 0b682937a82..6558b9effb9 100644
--- a/test/srt/test_nightly_human_eval.py
+++ b/test/srt/test_nightly_human_eval.py
@@ -4,7 +4,7 @@
import subprocess
import unittest
-from test_nightly_gsm8k_eval import launch_server, parse_models
+from test_nightly_gsm8k_eval import parse_models, popen_launch_server_wrapper
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
@@ -93,7 +93,7 @@ def test_human_eval_all_models(self):
# NOTE: only Llama for now
if "Llama" in model:
with self.subTest(model=model):
- self.process = launch_server(
+ self.process = popen_launch_server_wrapper(
self.base_url, model, is_fp8, is_tp2
)
self.run_evalplus(model)
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index 7479b646837..c535d5c0686 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -1,6 +1,6 @@
"""
Usage:
-python3 -m unittest test_srt_engine.TestSRTEngine.test_3_sync_streaming_combination
+python3 -m unittest test_srt_engine.TestSRTEngine.test_4_sync_async_stream_combination
"""
import asyncio
@@ -44,64 +44,97 @@ def test_1_engine_runtime_consistency(self):
print(out2)
self.assertEqual(out1, out2)
- def test_2_engine_multiple_generate(self):
+ def test_2_engine_runtime_encode_consistency(self):
+ prompt = "Today is a sunny day and I like"
+ model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
+
+ engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)
+ out1 = torch.tensor(engine.encode(prompt)["embedding"])
+ engine.shutdown()
+
+ runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42)
+ out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"])
+ runtime.shutdown()
+
+ self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
+
+ def test_3_engine_token_ids_consistency(self):
# just to ensure there is no issue running multiple generate calls
prompt = "Today is a sunny day and I like"
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
-
sampling_params = {"temperature": 0, "max_new_tokens": 8}
- engine = sgl.Engine(model_path=model_path, random_seed=42)
- engine.generate(prompt, sampling_params)
- engine.generate(prompt, sampling_params)
- engine.shutdown()
+ engine = sgl.Engine(
+ model_path=model_path, random_seed=42, disable_radix_cache=True
+ )
+ out1 = engine.generate(prompt, sampling_params)["text"]
- def test_3_sync_streaming_combination(self):
+ tokenizer = get_tokenizer(model_path)
+ token_ids = tokenizer.encode(prompt)
+ out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[
+ "text"
+ ]
- prompt = "AI safety is..."
- sampling_params = {"temperature": 0.8, "top_p": 0.95}
+ engine.shutdown()
- async def async_streaming(engine):
+ print("==== Answer 1 ====")
+ print(out1)
- generator = await engine.async_generate(
- prompt, sampling_params, stream=True
- )
+ print("==== Answer 2 ====")
+ print(out2)
+ self.assertEqual(out1, out2)
- async for output in generator:
- print(output["text"], end="", flush=True)
- print()
+ def test_4_sync_async_stream_combination(self):
+ prompt = "AI safety is"
+ sampling_params = {"temperature": 0.8, "top_p": 0.95}
# Create an LLM.
llm = sgl.Engine(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
)
- # 1. sync + non streaming
- print("\n\n==== 1. sync + non streaming ====")
- output = llm.generate(prompt, sampling_params)
+ if True:
+ # 1. sync + non streaming
+ print("\n\n==== 1. sync + non streaming ====")
+ output = llm.generate(prompt, sampling_params)
+ print(output["text"])
+
+ # 2. sync + streaming
+ print("\n\n==== 2. sync + streaming ====")
+ output_generator = llm.generate(prompt, sampling_params, stream=True)
+ offset = 0
+ for output in output_generator:
+ print(output["text"][offset:], end="", flush=True)
+ offset = len(output["text"])
+ print()
- print(output["text"])
+ if True:
+ loop = asyncio.get_event_loop()
+ # 3. async + non_streaming
+ print("\n\n==== 3. async + non streaming ====")
+ output = loop.run_until_complete(
+ llm.async_generate(prompt, sampling_params)
+ )
+ print(output["text"])
- # 2. sync + streaming
- print("\n\n==== 2. sync + streaming ====")
- output_generator = llm.generate(prompt, sampling_params, stream=True)
- for output in output_generator:
- print(output["text"], end="", flush=True)
- print()
+ # 4. async + streaming
+ async def async_streaming(engine):
+ generator = await engine.async_generate(
+ prompt, sampling_params, stream=True
+ )
- loop = asyncio.get_event_loop()
- # 3. async + non_streaming
- print("\n\n==== 3. async + non streaming ====")
- output = loop.run_until_complete(llm.async_generate(prompt, sampling_params))
- print(output["text"])
+ offset = 0
+ async for output in generator:
+ print(output["text"][offset:], end="", flush=True)
+ offset = len(output["text"])
+ print()
- # 4. async + streaming
- print("\n\n==== 4. async + streaming ====")
- loop.run_until_complete(async_streaming(llm))
+ print("\n\n==== 4. async + streaming ====")
+ loop.run_until_complete(async_streaming(llm))
llm.shutdown()
- def test_4_gsm8k(self):
+ def test_5_gsm8k(self):
args = SimpleNamespace(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
@@ -113,46 +146,7 @@ def test_4_gsm8k(self):
metrics = run_eval(args)
self.assertGreater(metrics["accuracy"], 0.3)
- def test_5_prompt_input_ids_consistency(self):
- prompt = "The capital of UK is"
-
- model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
- engine = sgl.Engine(
- model_path=model_path, random_seed=42, disable_radix_cache=True
- )
- sampling_params = {"temperature": 0, "max_new_tokens": 8}
- out1 = engine.generate(prompt, sampling_params)["text"]
-
- tokenizer = get_tokenizer(model_path)
- token_ids = tokenizer.encode(prompt)
- out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[
- "text"
- ]
-
- engine.shutdown()
-
- print("==== Answer 1 ====")
- print(out1)
-
- print("==== Answer 2 ====")
- print(out2)
- self.assertEqual(out1, out2)
-
- def test_6_engine_runtime_encode_consistency(self):
- prompt = "Today is a sunny day and I like"
- model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
-
- engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42)
- out1 = torch.tensor(engine.encode(prompt)["embedding"])
- engine.shutdown()
-
- runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42)
- out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"])
- runtime.shutdown()
-
- self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3))
-
- def test_7_engine_cpu_offload(self):
+ def test_6_engine_cpu_offload(self):
prompt = "Today is a sunny day and I like"
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -182,7 +176,7 @@ def test_7_engine_cpu_offload(self):
print(out2)
self.assertEqual(out1, out2)
- def test_8_engine_offline_throughput(self):
+ def test_7_engine_offline_throughput(self):
server_args = ServerArgs(
model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
)
From 09bcbe0123ba33e5487b1e86505de04c3749ada4 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 19 Jan 2025 23:37:27 -0800
Subject: [PATCH 026/147] Update TypeBasedDispatcher and balance CI tests
(#3001)
---
.github/workflows/pr-test.yml | 2 +-
python/sglang/srt/managers/tokenizer_manager.py | 7 ++++---
2 files changed, 5 insertions(+), 4 deletions(-)
diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index b910683e7da..8b8d7c56e7f 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -52,7 +52,7 @@ jobs:
runs-on: 1-gpu-runner
strategy:
matrix:
- range: [0-6, 6-15, 15-22, 22-32, 32-37, 37-100]
+ range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100]
steps:
- name: Checkout code
uses: actions/checkout@v3
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 162f10624f9..2be2e532d07 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -226,9 +226,10 @@ def __init__(
self._result_dispatcher = TypeBasedDispatcher(
[
- (BatchStrOut, self._handle_batch_output),
- (BatchEmbeddingOut, self._handle_batch_output),
- (BatchTokenIDOut, self._handle_batch_output),
+ (
+ (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
+ self._handle_batch_output,
+ ),
(OpenSessionReqOutput, self._handle_open_session_req_output),
(
UpdateWeightFromDiskReqOutput,
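To make the tuple-of-types registration above concrete, here is a minimal sketch of how such a dispatcher can be written; the real TypeBasedDispatcher lives in python/sglang/utils.py and may differ in detail:

    from typing import Any, Callable, List, Tuple, Type, Union

    class TypeBasedDispatcher:
        """Route an object to the handler registered for its type."""

        def __init__(self, mapping: List[Tuple[Union[Type, Tuple[Type, ...]], Callable]]):
            self._mapping = mapping

        def __call__(self, obj: Any):
            for ty, fn in self._mapping:
                # isinstance accepts a single type or a tuple of types, which is
                # what lets (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
                # share one handler.
                if isinstance(obj, ty):
                    return fn(obj)
            raise ValueError(f"Invalid object: {obj}")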
From 51e87f6f216d7a5f0f16f1050b3974da8238d96c Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 20 Jan 2025 00:28:47 -0800
Subject: [PATCH 027/147] Skip flaky custom_logit_processor tests (#3004)
---
test/srt/test_srt_endpoint.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index 7afdc9bf41c..cddd75fa6d6 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -301,10 +301,14 @@ def __call__(self, logits, custom_param_list):
def test_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
+ # Temporarily skipped due to buggy implementation
+ return
self.run_custom_logit_processor(target_token_id=5)
def test_custom_logit_processor_batch(self):
"""Test custom logit processor with a batch of requests."""
+ # Temporarily skipped due to buggy implementation
+ return
target_token_ids = list(range(32))
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
From 2584f6d94487645c48696762194457f6296c5ea7 Mon Sep 17 00:00:00 2001
From: Chayenne
Date: Mon, 20 Jan 2025 01:00:52 -0800
Subject: [PATCH 028/147] Docs: Add Performance Demonstration for DPA (#3005)
---
docs/references/deepseek.md | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md
index 913395357e1..2bdceb90478 100644
--- a/docs/references/deepseek.md
+++ b/docs/references/deepseek.md
@@ -34,6 +34,10 @@ Overall, with these optimizations, we have achieved up to a 7x acceleration in o
 **Usage**: This optimization is aimed at improving throughput and should be used for scenarios with high QPS (Queries Per Second). Data Parallelism Attention optimization can be enabled by `--enable-dp-attention` for DeepSeek Series Models.
+
+
+
+
**Reference**: Check [Blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models).
## Multi Node Tensor Parallelism
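A hedged launch sketch for the flag mentioned above, assuming `--enable-dp-attention` maps to an `enable_dp_attention` field on ServerArgs; the model path and tp_size are placeholders:

    from sglang.srt.entrypoints.http_server import launch_server
    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(
        model_path="deepseek-ai/DeepSeek-V3",  # placeholder model path
        tp_size=8,                             # placeholder parallelism degree
        enable_dp_attention=True,              # assumed field behind --enable-dp-attention
    )
    launch_server(server_args)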
From 583697cd71faa65a2e132a014743f5ff5c63890a Mon Sep 17 00:00:00 2001
From: Hongpeng Guo
Date: Mon, 20 Jan 2025 02:00:35 -0800
Subject: [PATCH 029/147] [Enhancement] Custom Logit Processor Improvement
(#2998)
Signed-off-by: Hongpeng Guo
---
python/sglang/bench_one_batch.py | 1 +
python/sglang/srt/layers/sampler.py | 10 ++++
python/sglang/srt/managers/schedule_batch.py | 6 +++
python/sglang/srt/managers/scheduler.py | 2 +
.../srt/sampling/sampling_batch_info.py | 53 ++++++++++++-------
test/srt/test_srt_endpoint.py | 35 ++++++++----
6 files changed, 79 insertions(+), 28 deletions(-)
diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index 473f478ad5c..e01919399b5 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -232,6 +232,7 @@ def extend(reqs, model_runner):
model_config=model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
+ enable_custom_logit_processor=False,
)
batch.prepare_for_extend()
model_worker_batch = batch.get_model_worker_batch()
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index e8b25da0704..ebaa1aa0e7e 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -132,6 +132,11 @@ def _apply_custom_logit_processor(
"""Apply custom logit processors to the logits.
This function will modify the logits in-place."""
+ assert logits.shape[0] == len(sampling_batch_info), (
+ f"The batch size of logits ({logits.shape[0]}) does not match the batch size of "
+ f"sampling_batch_info ({len(sampling_batch_info)})"
+ )
+
for _, (
processor,
batch_mask,
@@ -139,6 +144,11 @@ def _apply_custom_logit_processor(
# Get the batch indices that need to be processed
batch_indices = batch_mask.nonzero(as_tuple=True)[0]
+ assert batch_mask.shape[0] == len(sampling_batch_info), (
+ f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of "
+ f"sampling_batch_info ({len(sampling_batch_info)})"
+ )
+
# Apply the processor to the logits
logits[batch_mask] = processor(
logits[batch_mask],
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index a09810a3871..040afe3d324 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -595,6 +595,9 @@ class ScheduleBatch:
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[SpecInfo] = None
+ # Enable custom logit processor
+ enable_custom_logit_processor: bool = False
+
@classmethod
def init_new(
cls,
@@ -605,6 +608,7 @@ def init_new(
model_config: ModelConfig,
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
+ enable_custom_logit_processor: bool,
):
return cls(
reqs=reqs,
@@ -618,6 +622,7 @@ def init_new(
has_grammar=any(req.grammar for req in reqs),
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
+ enable_custom_logit_processor=enable_custom_logit_processor,
)
def batch_size(self):
@@ -1201,6 +1206,7 @@ def copy(self):
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
spec_algorithm=self.spec_algorithm,
+ enable_custom_logit_processor=self.enable_custom_logit_processor,
)
def __str__(self):
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 416abe21cd3..fba8a67ecf4 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -966,6 +966,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
self.model_config,
self.enable_overlap,
self.spec_algorithm,
+ self.server_args.enable_custom_logit_processor,
)
new_batch.prepare_for_extend()
@@ -1520,6 +1521,7 @@ def get_idle_batch(self):
self.model_config,
self.enable_overlap,
self.spec_algorithm,
+ self.server_args.enable_custom_logit_processor,
)
idle_batch.prepare_for_idle()
return idle_batch
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index d4c5c32386a..a27ff1ad2a3 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -89,7 +89,10 @@ def from_schedule_batch(
).to(device, non_blocking=True)
# Check if any request has custom logit processor
- has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
+ has_custom_logit_processor = (
+ batch.enable_custom_logit_processor # check the flag first.
+ and any(r.custom_logit_processor for r in reqs) # then check the requests.
+ )
if has_custom_logit_processor:
# Merge the same type of custom logit processors together
@@ -247,8 +250,7 @@ def _filter_batch_custom_logit_processor(
self, unfinished_indices: List[int], new_indices: torch.Tensor
):
"""Filter the custom logit processor and custom params"""
- if not self.custom_logit_processor:
- return
+
self.custom_logit_processor = {
k: (p, mask[new_indices])
for k, (p, mask) in self.custom_logit_processor.items()
@@ -258,7 +260,9 @@ def _filter_batch_custom_logit_processor(
}
self.custom_params = [self.custom_params[i] for i in unfinished_indices]
- if len(self) == 0:
+ # If the custom logit processor is an empty dict, set the flag to False,
+ # and set the custom logit processor and custom params to None.
+ if len(self.custom_logit_processor) == 0:
self.custom_logit_processor = None
self.custom_params = None
self.has_custom_logit_processor = False
@@ -290,8 +294,8 @@ def merge_bias_tensor(
@staticmethod
def merge_custom_logit_processor(
- lhs: Optional[Dict[str, torch.Tensor]],
- rhs: Optional[Dict[str, torch.Tensor]],
+ lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+ rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
bs1: int,
bs2: int,
device: str,
@@ -319,27 +323,22 @@ def merge_custom_logit_processor(
)
merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+ assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+ f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+ f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+ f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+ f"\n{lhs=}\n{rhs=}"
+ )
+
return merged_dict
def merge_batch(self, other: "SamplingBatchInfo"):
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
- for item in [
- "temperatures",
- "top_ps",
- "top_ks",
- "min_ps",
- ]:
- self_val = getattr(self, item, None)
- other_val = getattr(other, item, None)
- setattr(self, item, torch.concat([self_val, other_val]))
-
- self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+ # Merge the logit bias tensor
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
- self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
-
# Merge the custom logit processors and custom params lists
if self.has_custom_logit_processor or other.has_custom_logit_processor:
# Merge the custom logit processors
@@ -360,6 +359,22 @@ def merge_batch(self, other: "SamplingBatchInfo"):
# Set the flag to True if any of the two has custom logit processor
self.has_custom_logit_processor = True
+    # Note: because the __len__() operator is defined on the temperatures tensor,
+ # please make sure any merge operation with len(self) or len(other) is done before
+ # the merge operation of the temperatures tensor below.
+ for item in [
+ "temperatures",
+ "top_ps",
+ "top_ks",
+ "min_ps",
+ ]:
+ self_val = getattr(self, item, None)
+ other_val = getattr(other, item, None)
+ setattr(self, item, torch.concat([self_val, other_val]))
+
+ self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
+ self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
+
def apply_logits_bias(self, logits: torch.Tensor):
# Apply logit_bias
if self.logit_bias is not None:
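The ordering note in merge_batch above exists because len() on the batch object is backed by the temperatures tensor. A standalone toy sketch of the pitfall (a hypothetical class, not the real SamplingBatchInfo):

    import torch

    class MiniBatchInfo:
        """Toy stand-in whose __len__ is defined on the temperatures tensor."""

        def __init__(self, temperatures: torch.Tensor):
            self.temperatures = temperatures

        def __len__(self) -> int:
            return len(self.temperatures)

        def merge(self, other: "MiniBatchInfo") -> torch.Tensor:
            # Work that needs the *pre-merge* batch sizes (e.g. building the
            # per-request masks for custom logit processors) must happen here ...
            bs1, bs2 = len(self), len(other)
            mask = torch.cat(
                [torch.ones(bs1, dtype=torch.bool), torch.zeros(bs2, dtype=torch.bool)]
            )
            # ... because concatenating temperatures changes len(self) immediately.
            self.temperatures = torch.cat([self.temperatures, other.temperatures])
            assert len(self) == bs1 + bs2
            return mask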
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index cddd75fa6d6..7c57c13e251 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -4,8 +4,10 @@
"""
import json
+import random
import unittest
from concurrent.futures import ThreadPoolExecutor
+from typing import Optional
import numpy as np
import requests
@@ -253,8 +255,11 @@ def test_logprob_grammar(self):
self.assertTrue(all(x is not None for x in logprobs))
- def run_custom_logit_processor(self, target_token_id: int):
- """Test custom logit processor with custom params."""
+ def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
+ """Test custom logit processor with custom params.
+
+ If target_token_id is None, the custom logit processor won't be passed in.
+ """
custom_params = {"token_id": target_token_id}
@@ -285,8 +290,12 @@ def __call__(self, logits, custom_param_list):
# Custom json data with custom logit processor and params.
custom_json = base_json.copy()
- custom_json["custom_logit_processor"] = DeterministicLogitProcessor().to_str()
- custom_json["sampling_params"]["custom_params"] = custom_params
+ # Only set the custom logit processor if target_token_id is not None.
+ if target_token_id is not None:
+ custom_json["custom_logit_processor"] = (
+ DeterministicLogitProcessor().to_str()
+ )
+ custom_json["sampling_params"]["custom_params"] = custom_params
custom_response = requests.post(
self.base_url + "/generate",
@@ -297,22 +306,30 @@ def __call__(self, logits, custom_param_list):
sampled_tokens = [x[1] for x in output_token_logprobs]
# The logit processor should always sample the given token as the logits is deterministic.
- self.assertTrue(all(x == custom_params["token_id"] for x in sampled_tokens))
+ if target_token_id is not None:
+ self.assertTrue(
+ all(x == custom_params["token_id"] for x in sampled_tokens),
+ # Print the detailed test case info if the test fails.
+ f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
+ )
def test_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
- # Temporarily skipped due to buggy implementation
- return
self.run_custom_logit_processor(target_token_id=5)
def test_custom_logit_processor_batch(self):
"""Test custom logit processor with a batch of requests."""
- # Temporarily skipped due to buggy implementation
- return
target_token_ids = list(range(32))
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
+ def test_custom_logit_processor_batch_mixed(self):
+ """Test a batch of requests mixed of requests with and without custom logit processor."""
+ target_token_ids = list(range(32)) + [None] * 16
+ random.shuffle(target_token_ids)
+ with ThreadPoolExecutor(len(target_token_ids)) as executor:
+ list(executor.map(self.run_custom_logit_processor, target_token_ids))
+
def test_get_server_info(self):
response = requests.get(self.base_url + "/get_server_info")
response_json = response.json()
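For readers who want to exercise the endpoint above, here is a hedged, self-contained sketch of a processor with the same __call__(logits, custom_param_list) shape the test uses; in the real test the class subclasses CustomLogitProcessor and is sent as a string via to_str(), while the class below is purely illustrative:

    import torch

    class ForceTokenLogitProcessor:
        """Hypothetical processor: make every request in the batch sample the
        token id given in its custom_params."""

        def __call__(self, logits: torch.Tensor, custom_param_list):
            for i, params in enumerate(custom_param_list):
                target = params["token_id"]
                forced = torch.full_like(logits[i], float("-inf"))
                forced[target] = 0.0  # only the target token keeps finite mass
                logits[i] = forced
            return logits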
From 10bfce71b35300b61cb9016a544eb79d61352f77 Mon Sep 17 00:00:00 2001
From: yiakwy-xpu-ml-framework-team
<89890040+yiakwy-xpu-ml-framework-team@users.noreply.github.com>
Date: Mon, 20 Jan 2025 19:33:29 +0800
Subject: [PATCH 030/147] fix moe align blocks benchmark (#3003)
---
.../benchmark_deepseekv3_moe_align_blocks.py | 36 +++++++++++++++----
1 file changed, 30 insertions(+), 6 deletions(-)
diff --git a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
index 0a6049a1200..d00f4985ad2 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
+++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
@@ -7,6 +7,8 @@
import triton.language as tl
from sgl_kernel import moe_align_block_size
+USE_RANDOM_PERM = False
+
def ceil_div(a, b):
return (a + b - 1) // b
@@ -141,8 +143,13 @@ def moe_align_block_size_triton(
def calculate_diff(batch_size, seq_len):
num_experts = 256
block_size = 128
- topk_ids = torch.randint(
- 0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda"
+ topk = 8
+
+ topk_ids = torch.stack(
+ [
+ torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
+ for _ in range(batch_size * seq_len)
+ ]
)
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
@@ -169,7 +176,7 @@ def calculate_diff(batch_size, seq_len):
expert_ids_triton = torch.empty_like(expert_ids_cuda)
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
- # 运行两个实现
+    # Run both the CUDA and Triton implementations and compare their outputs
moe_align_block_size(
topk_ids,
num_experts,
@@ -206,6 +213,15 @@ def calculate_diff(batch_size, seq_len):
configs = list(itertools.product(batch_size_range, seq_length_range))
+def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
+ topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda")
+ for i in range(num_tokens):
+ topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[
+ :topk
+ ]
+ return topk_ids
+
+
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
@@ -223,9 +239,17 @@ def benchmark(batch_size, seq_len, provider):
num_experts = 256
block_size = 128
topk = 8
- topk_ids = torch.randint(
- 0, num_experts, (batch_size * seq_len, topk), dtype=torch.int32, device="cuda"
- )
+
+ if USE_RANDOM_PERM:
+ topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk)
+ else:
+ topk_ids = torch.randint(
+ 0,
+ num_experts,
+ (batch_size * seq_len, topk),
+ dtype=torch.int32,
+ device="cuda",
+ )
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty(
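The USE_RANDOM_PERM switch above matters because the two generators produce differently shaped workloads. A small CPU-only sketch of the difference, with illustrative sizes:

    import torch

    num_experts, topk = 256, 8

    # randint can assign the same expert to one token more than once ...
    with_repeats = torch.randint(0, num_experts, (1, topk))

    # ... while randperm (the USE_RANDOM_PERM path) guarantees topk distinct
    # expert ids per token, which mimics real top-k routing more closely.
    distinct = torch.randperm(num_experts)[:topk]
    assert distinct.unique().numel() == topk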
From dc1881326f61734a4160620b6e12a5542b756066 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 20 Jan 2025 03:39:49 -0800
Subject: [PATCH 031/147] Fix perf regression on small batch sizes (#3008)
---
python/sglang/srt/layers/radix_attention.py | 4 ++--
python/sglang/srt/mem_cache/memory_pool.py | 14 +++++++++-----
2 files changed, 11 insertions(+), 7 deletions(-)
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index a449d7188a4..0d46e7bba9a 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -47,8 +47,8 @@ def __init__(
self.logit_cap = logit_cap
self.sliding_window_size = sliding_window_size or -1
self.is_cross_attention = is_cross_attention
- self.k_scale = 1.0
- self.v_scale = 1.0
+ self.k_scale = None
+ self.v_scale = None
def forward(
self,
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index e307367223a..7b9b35611d8 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -27,7 +27,7 @@
import threading
from enum import IntEnum
from functools import wraps
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union
import numpy as np
import psutil
@@ -270,13 +270,17 @@ def set_kv_buffer(
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
- k_scale: float = 1.0,
- v_scale: float = 1.0,
+ k_scale: Optional[float] = None,
+ v_scale: Optional[float] = None,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
- cache_k = (cache_k / k_scale).to(self.dtype)
- cache_v = (cache_v / v_scale).to(self.dtype)
+ if k_scale is not None:
+ cache_k.div_(k_scale)
+ if v_scale is not None:
+ cache_v.div_(v_scale)
+ cache_k = cache_k.to(self.dtype)
+ cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
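A standalone sketch of the new optional-scale behavior shown above; this is a hypothetical helper, not the real set_kv_buffer, which also writes into the per-layer buffers:

    from typing import Optional

    import torch

    def scale_and_cast(cache_k: torch.Tensor, target_dtype: torch.dtype,
                       k_scale: Optional[float] = None) -> torch.Tensor:
        # Only divide when a scale was actually provided, and do it in place
        # to avoid an extra allocation on the decode hot path.
        if cache_k.dtype != target_dtype:
            if k_scale is not None:
                cache_k.div_(k_scale)
            cache_k = cache_k.to(target_dtype)
        return cache_k

    k = torch.randn(4, 8, dtype=torch.float32)
    print(scale_and_cast(k, torch.bfloat16).dtype)  # torch.bfloat16, no scaling applied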
From 89cd923581fec16d70ed536eceac7212dc6e0898 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 20 Jan 2025 04:03:15 -0800
Subject: [PATCH 032/147] Roll back to use vllm custom allreduce (#3006)
---
python/sglang/srt/_custom_ops.py | 2 +-
python/sglang/srt/distributed/__init__.py | 6 +-
.../srt/distributed/communication_op.py | 2 +-
.../custom_all_reduce_utils.py | 1 -
.../device_communicators/pynccl_wrapper.py | 2 +-
.../device_communicators/shm_broadcast.py | 2 +-
python/sglang/srt/layers/attention/vision.py | 4 +-
.../srt/model_executor/cuda_graph_runner.py | 3 -
.../sglang/srt/model_executor/model_runner.py | 5 +-
python/sglang/srt/utils.py | 56 ++-----------------
10 files changed, 18 insertions(+), 65 deletions(-)
diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py
index 3c00a8552ff..3cb313b9133 100644
--- a/python/sglang/srt/_custom_ops.py
+++ b/python/sglang/srt/_custom_ops.py
@@ -12,7 +12,7 @@
from sglang.srt.utils import is_hpu
logger = logging.getLogger(__name__)
-use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=False)
+use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
if not is_hpu():
if use_vllm_custom_allreduce:
diff --git a/python/sglang/srt/distributed/__init__.py b/python/sglang/srt/distributed/__init__.py
index db325cfabf5..12f802055c5 100644
--- a/python/sglang/srt/distributed/__init__.py
+++ b/python/sglang/srt/distributed/__init__.py
@@ -1,3 +1,3 @@
-from .communication_op import *
-from .parallel_state import *
-from .utils import *
+from sglang.srt.distributed.communication_op import *
+from sglang.srt.distributed.parallel_state import *
+from sglang.srt.distributed.utils import *
diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py
index ddf3b8ef568..7895508cd09 100644
--- a/python/sglang/srt/distributed/communication_op.py
+++ b/python/sglang/srt/distributed/communication_op.py
@@ -4,7 +4,7 @@
import torch
import torch.distributed
-from .parallel_state import get_tp_group
+from sglang.srt.distributed.parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
index d807dfd5ce5..64cf9a78d83 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
@@ -7,7 +7,6 @@
import subprocess
import sys
import tempfile
-from functools import lru_cache
from itertools import product
from typing import Dict, List, Optional, Sequence
diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py
index e72284f5117..a2eacd741f8 100644
--- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py
+++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py
@@ -57,7 +57,7 @@ def find_nccl_library() -> str:
so_file = "librccl.so.1"
else:
raise ValueError("NCCL only supports CUDA and ROCm backends.")
- logger.info("Found nccl from library %s", so_file)
+ logger.debug("Found nccl from library %s", so_file)
return so_file
diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py
index 1afe6fca526..c9f329fb274 100644
--- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py
+++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py
@@ -313,7 +313,7 @@ def __init__(
remote_subscribe_port=remote_subscribe_port,
)
- logger.info("vLLM message queue communication handle: %s", self.handle)
+ logger.debug("Message queue communication handle: %s", self.handle)
def export_handle(self) -> Handle:
return self.handle
diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py
index f66456b0437..4fcfaad5625 100644
--- a/python/sglang/srt/layers/attention/vision.py
+++ b/python/sglang/srt/layers/attention/vision.py
@@ -5,9 +5,9 @@
import torch
import torch.nn as nn
from einops import rearrange, repeat
-from vllm.distributed import parallel_state
-from vllm.distributed import utils as dist_utils
+from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import utils as dist_utils
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 9fdf7a8ac78..762dac140fb 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -33,7 +33,6 @@
ForwardBatch,
ForwardMode,
)
-from sglang.srt.utils import monkey_patch_vllm_all_gather
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
@@ -72,7 +71,6 @@ def patch_model(
try:
if enable_compile:
_to_torch(model, reverse=False, batch_size=batch_size)
- monkey_patch_vllm_all_gather()
backup_ca_comm = tp_group.ca_comm
# Use custom-allreduce here.
# We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -88,7 +86,6 @@ def patch_model(
finally:
if enable_compile:
_to_torch(model, reverse=True, batch_size=batch_size)
- monkey_patch_vllm_all_gather(reverse=True)
tp_group.ca_comm = backup_ca_comm
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 46920d92249..d5cdcf2beb0 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -63,8 +63,8 @@
init_custom_process_group,
is_cuda,
is_hip,
+ monkey_patch_p2p_access_check,
monkey_patch_vllm_gguf_config,
- monkey_patch_vllm_p2p_access_check,
set_cpu_offload_max_bytes,
)
@@ -229,7 +229,8 @@ def init_torch_distributed(self):
backend = "gloo"
if not self.server_args.enable_p2p_check:
- monkey_patch_vllm_p2p_access_check(self.gpu_id)
+ monkey_patch_p2p_access_check()
+
if self.server_args.dist_init_addr:
dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
else:
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index c67b6635b30..cf74f1d0f08 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -518,68 +518,24 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
pass
-def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+def monkey_patch_p2p_access_check():
"""
- Monkey patch the slow p2p access check in vllm.
+ Monkey patch the slow p2p access check.
NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
"""
- import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+ import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt
setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
# Suppress the warnings from this delete function when using sglang.bench_one_batch
- from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+ from sglang.srt.distributed.device_communicators.custom_all_reduce import (
+ CustomAllreduce,
+ )
setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)
-vllm_all_gather_backup = None
-
-
-def monkey_patch_vllm_all_gather(reverse: bool = False):
- """Monkey patch all-gather to remove in-place operations."""
- from torch.distributed import _functional_collectives as funcol
- from vllm.distributed.parallel_state import GroupCoordinator
-
- global vllm_all_gather_backup
- if vllm_all_gather_backup is None:
- vllm_all_gather_backup = GroupCoordinator.all_gather
-
- def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
- world_size = self.world_size
- # Bypass the function if we are using only 1 GPU.
- if world_size == 1:
- return input_
- assert (
- -input_.dim() <= dim < input_.dim()
- ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
- if dim < 0:
- # Convert negative dim to positive.
- dim += input_.dim()
- input_size = input_.size()
- # Allocate output tensor.
- output_tensor = torch.empty(
- (world_size,) + input_size, dtype=input_.dtype, device=input_.device
- )
-
- output_tensor = funcol.all_gather_tensor(
- input_, gather_dim=0, group=self.device_group
- ).view((world_size,) + input_size)
-
- # Reshape
- output_tensor = output_tensor.movedim(0, dim)
- output_tensor = output_tensor.reshape(
- input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
- )
- return output_tensor
-
- if reverse:
- setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
- else:
- setattr(GroupCoordinator, "all_gather", all_gather)
-
-
def monkey_patch_vllm_gguf_config():
from vllm.model_executor.layers.quantization.gguf import (
GGUFConfig,
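Because the default flip above is read at import time and the raw environment string is used as a truth value, a hedged sketch of overriding it follows; the truthiness semantics simply mirror the os.environ.get call shown in _custom_ops.py:

    import os

    # Must be set before sglang.srt._custom_ops is imported, since the flag is
    # evaluated at module import time.
    os.environ["USE_VLLM_CUSTOM_ALLREDUCE"] = ""  # empty string is falsy -> take the non-vLLM branch
    # Leaving it unset (or setting any non-empty string) keeps the vLLM custom allreduce.

    import sglang.srt._custom_ops  # noqa: E402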
From 73401fd0161caef9681e34f36dfead3134edd549 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 20 Jan 2025 04:57:14 -0800
Subject: [PATCH 033/147] Sync distributed package from vllm 0.6.4.post1
(#3010)
---
.../srt/distributed/communication_op.py | 5 +-
.../device_communicators/__init__.py | 0
.../device_communicators/cuda_wrapper.py | 3 +-
.../device_communicators/custom_all_reduce.py | 3 +-
.../custom_all_reduce_utils.py | 3 +-
.../device_communicators/hpu_communicator.py | 3 +-
.../device_communicators/pynccl.py | 81 ++++++++++++-
.../device_communicators/pynccl_wrapper.py | 112 +++++++++++++++++-
.../device_communicators/shm_broadcast.py | 75 +-----------
.../device_communicators/xpu_communicator.py | 3 +-
.../sglang/srt/distributed/parallel_state.py | 2 +-
python/sglang/srt/distributed/utils.py | 3 +-
python/sglang/srt/server_args.py | 4 +-
python/sglang/srt/utils.py | 75 ++++++++++--
python/sglang/utils.py | 1 -
15 files changed, 280 insertions(+), 93 deletions(-)
delete mode 100644 python/sglang/srt/distributed/device_communicators/__init__.py
diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py
index 7895508cd09..95600edfb41 100644
--- a/python/sglang/srt/distributed/communication_op.py
+++ b/python/sglang/srt/distributed/communication_op.py
@@ -1,10 +1,11 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py
+
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed
-from sglang.srt.distributed.parallel_state import get_tp_group
+from .parallel_state import get_tp_group
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
diff --git a/python/sglang/srt/distributed/device_communicators/__init__.py b/python/sglang/srt/distributed/device_communicators/__init__.py
deleted file mode 100644
index e69de29bb2d..00000000000
diff --git a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py
index ab4ee33fcfc..c902f314112 100644
--- a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py
+++ b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py
+
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index d4506b9f04c..c3cbc41fe63 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py
+
import ctypes
import logging
import os
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
index 64cf9a78d83..4073491aa62 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+
import ctypes
import json
import logging
diff --git a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py
index 72ef3889e01..722e494cf77 100644
--- a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py
+++ b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py
+
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py
index baee270da90..9f65939f6d9 100644
--- a/python/sglang/srt/distributed/device_communicators/pynccl.py
+++ b/python/sglang/srt/distributed/device_communicators/pynccl.py
@@ -1,8 +1,10 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
+
import logging
from contextlib import contextmanager
from typing import Optional, Union
+# ===================== import region =====================
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup, ReduceOp
@@ -143,6 +145,57 @@ def all_reduce(
cudaStream_t(stream.cuda_stream),
)
+ def all_gather(
+ self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
+ ):
+ if self.disabled:
+ return
+ # nccl communicator created on a specific device
+ # will only work on tensors on the same device
+ # otherwise it will cause "illegal memory access"
+ assert input_tensor.device == self.device, (
+ f"this nccl communicator is created to work on {self.device}, "
+ f"but the input tensor is on {input_tensor.device}"
+ )
+ if stream is None:
+ stream = self.stream
+ self.nccl.ncclAllGather(
+ buffer_type(input_tensor.data_ptr()),
+ buffer_type(output_tensor.data_ptr()),
+ input_tensor.numel(),
+ ncclDataTypeEnum.from_torch(input_tensor.dtype),
+ self.comm,
+ cudaStream_t(stream.cuda_stream),
+ )
+
+ def reduce_scatter(
+ self,
+ output_tensor: torch.Tensor,
+ input_tensor: torch.Tensor,
+ op: ReduceOp = ReduceOp.SUM,
+ stream=None,
+ ):
+ if self.disabled:
+ return
+ # nccl communicator created on a specific device
+ # will only work on tensors on the same device
+ # otherwise it will cause "illegal memory access"
+ assert input_tensor.device == self.device, (
+ f"this nccl communicator is created to work on {self.device}, "
+ f"but the input tensor is on {input_tensor.device}"
+ )
+ if stream is None:
+ stream = self.stream
+ self.nccl.ncclReduceScatter(
+ buffer_type(input_tensor.data_ptr()),
+ buffer_type(output_tensor.data_ptr()),
+ output_tensor.numel(),
+ ncclDataTypeEnum.from_torch(input_tensor.dtype),
+ ncclRedOpTypeEnum.from_torch(op),
+ self.comm,
+ cudaStream_t(stream.cuda_stream),
+ )
+
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
@@ -179,6 +232,32 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None):
cudaStream_t(stream.cuda_stream),
)
+ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
+ if self.disabled:
+ return
+ assert tensor.device == self.device, (
+ f"this nccl communicator is created to work on {self.device}, "
+ f"but the input tensor is on {tensor.device}"
+ )
+ if stream is None:
+ stream = self.stream
+ if src == self.rank:
+ sendbuff = buffer_type(tensor.data_ptr())
+ # NCCL requires the sender also to have a receive buffer
+ recvbuff = buffer_type(tensor.data_ptr())
+ else:
+ sendbuff = buffer_type()
+ recvbuff = buffer_type(tensor.data_ptr())
+ self.nccl.ncclBroadcast(
+ sendbuff,
+ recvbuff,
+ tensor.numel(),
+ ncclDataTypeEnum.from_torch(tensor.dtype),
+ src,
+ self.comm,
+ cudaStream_t(stream.cuda_stream),
+ )
+
@contextmanager
def change_state(
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py
index a2eacd741f8..afb47733476 100644
--- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py
+++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
@@ -187,6 +187,43 @@ class NCCLLibrary:
cudaStream_t,
],
),
+ # ncclResult_t ncclAllGather(
+ # const void* sendbuff, void* recvbuff, size_t count,
+ # ncclDataType_t datatype, ncclComm_t comm,
+ # cudaStream_t stream);
+ # note that cudaStream_t is a pointer type, so the last argument
+ # is a pointer
+ Function(
+ "ncclAllGather",
+ ncclResult_t,
+ [
+ buffer_type,
+ buffer_type,
+ ctypes.c_size_t,
+ ncclDataType_t,
+ ncclComm_t,
+ cudaStream_t,
+ ],
+ ),
+ # ncclResult_t ncclReduceScatter(
+ # const void* sendbuff, void* recvbuff, size_t count,
+ # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
+ # cudaStream_t stream);
+ # note that cudaStream_t is a pointer type, so the last argument
+ # is a pointer
+ Function(
+ "ncclReduceScatter",
+ ncclResult_t,
+ [
+ buffer_type,
+ buffer_type,
+ ctypes.c_size_t,
+ ncclDataType_t,
+ ncclRedOp_t,
+ ncclComm_t,
+ cudaStream_t,
+ ],
+ ),
# ncclResult_t ncclSend(
# const void* sendbuff, size_t count, ncclDataType_t datatype,
# int dest, ncclComm_t comm, cudaStream_t stream);
@@ -217,6 +254,23 @@ class NCCLLibrary:
cudaStream_t,
],
),
+ # ncclResult_t ncclBroadcast(
+ # const void* sendbuff, void* recvbuff, size_t count,
+ # ncclDataType_t datatype, int root, ncclComm_t comm,
+ # cudaStream_t stream);
+ Function(
+ "ncclBroadcast",
+ ncclResult_t,
+ [
+ buffer_type,
+ buffer_type,
+ ctypes.c_size_t,
+ ncclDataType_t,
+ ctypes.c_int,
+ ncclComm_t,
+ cudaStream_t,
+ ],
+ ),
# be cautious! this is a collective call, it will block until all
# processes in the communicator have called this function.
# because Python object destruction can happen in random order,
@@ -321,6 +375,46 @@ def ncclAllReduce(
)
)
+ def ncclReduceScatter(
+ self,
+ sendbuff: buffer_type,
+ recvbuff: buffer_type,
+ count: int,
+ datatype: int,
+ op: int,
+ comm: ncclComm_t,
+ stream: cudaStream_t,
+ ) -> None:
+ # `datatype` actually should be `ncclDataType_t`
+ # and `op` should be `ncclRedOp_t`
+ # both are aliases of `ctypes.c_int`
+ # when we pass int to a function, it will be converted to `ctypes.c_int`
+ # by ctypes automatically
+ self.NCCL_CHECK(
+ self._funcs["ncclReduceScatter"](
+ sendbuff, recvbuff, count, datatype, op, comm, stream
+ )
+ )
+
+ def ncclAllGather(
+ self,
+ sendbuff: buffer_type,
+ recvbuff: buffer_type,
+ count: int,
+ datatype: int,
+ comm: ncclComm_t,
+ stream: cudaStream_t,
+ ) -> None:
+ # `datatype` actually should be `ncclDataType_t`
+        # which is an alias of `ctypes.c_int`
+ # when we pass int to a function, it will be converted to `ctypes.c_int`
+ # by ctypes automatically
+ self.NCCL_CHECK(
+ self._funcs["ncclAllGather"](
+ sendbuff, recvbuff, count, datatype, comm, stream
+ )
+ )
+
def ncclSend(
self,
sendbuff: buffer_type,
@@ -347,6 +441,22 @@ def ncclRecv(
self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)
)
+ def ncclBroadcast(
+ self,
+ sendbuff: buffer_type,
+ recvbuff: buffer_type,
+ count: int,
+ datatype: int,
+ root: int,
+ comm: ncclComm_t,
+ stream: cudaStream_t,
+ ) -> None:
+ self.NCCL_CHECK(
+ self._funcs["ncclBroadcast"](
+ sendbuff, recvbuff, count, datatype, root, comm, stream
+ )
+ )
+
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py
index c9f329fb274..7a3b22e27a8 100644
--- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py
+++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py
@@ -1,11 +1,9 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py
-import ipaddress
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/shm_broadcast.py
+
import logging
import os
import pickle
-import socket
import time
-import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
@@ -18,6 +16,8 @@
from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
+from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address
+
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60")
@@ -26,73 +26,6 @@
logger = logging.getLogger(__name__)
-def get_ip() -> str:
- # SGLANG_HOST_IP env can be ignore
- host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
- if host_ip:
- return host_ip
-
- # IP is not set, try to get it from the network interface
-
- # try ipv4
- s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- try:
- s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
- return s.getsockname()[0]
- except Exception:
- pass
-
- # try ipv6
- try:
- s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
- # Google's public DNS server, see
- # https://developers.google.com/speed/public-dns/docs/using#addresses
- s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
- return s.getsockname()[0]
- except Exception:
- pass
-
- warnings.warn(
- "Failed to get the IP address, using 0.0.0.0 by default."
- "The value can be set by the environment variable"
- " SGLANG_HOST_IP or HOST_IP.",
- stacklevel=2,
- )
- return "0.0.0.0"
-
-
-def get_open_port() -> int:
-
- port = os.getenv("SGLANG_PORT")
- if port is not None:
- while True:
- try:
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(("", port))
- return port
- except OSError:
- port += 1 # Increment port number if already in use
- logger.info("Port %d is already in use, trying port %d", port - 1, port)
- # try ipv4
- try:
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(("", 0))
- return s.getsockname()[1]
- except OSError:
- # try ipv6
- with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
- s.bind(("", 0))
- return s.getsockname()[1]
-
-
-def is_valid_ipv6_address(address: str) -> bool:
- try:
- ipaddress.IPv6Address(address)
- return True
- except ValueError:
- return False
-
-
class ShmRingBuffer:
def __init__(
diff --git a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py
index ff0981b80bc..532279f70c3 100644
--- a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py
+++ b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/xpu_communicator.py
+
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index 26d04b04ce9..c6d1a830781 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/parallel_state.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/parallel_state.py
# Copyright 2023 The vLLM team.
# Adapted from
diff --git a/python/sglang/srt/distributed/utils.py b/python/sglang/srt/distributed/utils.py
index a225fbb9182..e117aa30d07 100644
--- a/python/sglang/srt/distributed/utils.py
+++ b/python/sglang/srt/distributed/utils.py
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/utils.py
+
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6dd0b945654..4a7a28751db 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -29,8 +29,8 @@
get_nvgpu_memory_capacity,
is_flashinfer_available,
is_hip,
- is_ipv6,
is_port_available,
+ is_valid_ipv6_address,
nullable_str,
)
@@ -883,7 +883,7 @@ def from_cli_args(cls, args: argparse.Namespace):
return cls(**{attr: getattr(args, attr) for attr in attrs})
def url(self):
- if is_ipv6(self.host):
+ if is_valid_ipv6_address(self.host):
return f"http://[{self.host}]:{self.port}"
else:
return f"http://{self.host}:{self.port}"
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index cf74f1d0f08..4614114b41d 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -102,14 +102,6 @@ def is_cuda_available():
return torch.cuda.is_available() and torch.version.cuda
-def is_ipv6(address):
- try:
- ipaddress.IPv6Address(address)
- return True
- except ipaddress.AddressValueError:
- return False
-
-
def enable_show_time_cost():
global show_time_cost
show_time_cost = True
@@ -1383,3 +1375,70 @@ def set_uvicorn_logging_configs():
"fmt"
] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+
+def get_ip() -> str:
+    # The SGLANG_HOST_IP / HOST_IP env vars, if set, take precedence over auto-detection
+ host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+ if host_ip:
+ return host_ip
+
+ # IP is not set, try to get it from the network interface
+
+ # try ipv4
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ try:
+ s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
+ return s.getsockname()[0]
+ except Exception:
+ pass
+
+ # try ipv6
+ try:
+ s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+ # Google's public DNS server, see
+ # https://developers.google.com/speed/public-dns/docs/using#addresses
+ s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
+ return s.getsockname()[0]
+ except Exception:
+ pass
+
+ warnings.warn(
+ "Failed to get the IP address, using 0.0.0.0 by default."
+ "The value can be set by the environment variable"
+ " SGLANG_HOST_IP or HOST_IP.",
+ stacklevel=2,
+ )
+ return "0.0.0.0"
+
+
+def get_open_port() -> int:
+
+ port = os.getenv("SGLANG_PORT")
+ if port is not None:
+ port = int(port)  # env vars are strings; bind() needs an int
+ while True:
+ try:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", port))
+ return port
+ except OSError:
+ port += 1 # Increment port number if already in use
+ logger.info("Port %d is already in use, trying port %d", port - 1, port)
+ # try ipv4
+ try:
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+ except OSError:
+ # try ipv6
+ with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+ s.bind(("", 0))
+ return s.getsockname()[1]
+
+
+def is_valid_ipv6_address(address: str) -> bool:
+ try:
+ ipaddress.IPv6Address(address)
+ return True
+ except ValueError:
+ return False
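For context, the three helpers added above are meant to compose when building a server URL. Below is a minimal usage sketch, not part of the patch itself; it only assumes the functions are importable from `sglang.srt.utils`, which is where this hunk defines them.
```python
from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address

host = get_ip()         # honors SGLANG_HOST_IP / HOST_IP, otherwise probes a UDP socket
port = get_open_port()  # honors SGLANG_PORT, otherwise asks the OS for a free port

# IPv6 literals must be bracketed in URLs, mirroring ServerArgs.url() above.
if is_valid_ipv6_address(host):
    url = f"http://[{host}]:{port}"
else:
    url = f"http://{host}:{port}"
print(url)
```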
diff --git a/python/sglang/utils.py b/python/sglang/utils.py
index 98942fbb39c..742eebc3bc9 100644
--- a/python/sglang/utils.py
+++ b/python/sglang/utils.py
@@ -1,7 +1,6 @@
"""Common utilities"""
import base64
-import gc
import importlib
import json
import logging
From b5caa22dfbdada1753011ef26d44b3da6028d2ad Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Mon, 20 Jan 2025 04:58:51 -0800
Subject: [PATCH 034/147] [kernel] port rope cuda kernel to sgl-kernel (#2993)
Co-authored-by: Yineng Zhang
---
.gitignore | 3 +
sgl-kernel/pyproject.toml | 2 +-
sgl-kernel/setup.py | 1 +
sgl-kernel/src/sgl-kernel/__init__.py | 2 +
.../src/sgl-kernel/csrc/rotary_embedding.cu | 119 ++++++++++++++++++
.../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 6 +
sgl-kernel/src/sgl-kernel/ops/__init__.py | 5 +
sgl-kernel/tests/test_rotary_embedding.py | 118 +++++++++++++++++
8 files changed, 255 insertions(+), 1 deletion(-)
create mode 100644 sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
create mode 100644 sgl-kernel/tests/test_rotary_embedding.py
diff --git a/.gitignore b/.gitignore
index 73fd52992c2..91966c664b5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -222,3 +222,6 @@ work_dirs/
compile_commands.json
*.iml
+
+# VSCode
+.vscode
diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index b0554bd8fed..ab9d68b44c8 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sgl-kernel"
-version = "0.0.2.post14"
+version = "0.0.2.post15"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 33e4abe1b23..25319af7a65 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -53,6 +53,7 @@ def update_wheel_platform_tag():
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
+ "src/sgl-kernel/csrc/rotary_embedding.cu",
],
include_dirs=include_dirs,
extra_compile_args={
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 0c744982dd8..480bec71f36 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -6,6 +6,7 @@
int8_scaled_mm,
moe_align_block_size,
register_graph_buffers,
+ rotary_embedding,
sampling_scaling_penalties,
)
@@ -18,4 +19,5 @@
"sampling_scaling_penalties",
"get_graph_buffer_ipc_meta",
"register_graph_buffers",
+ "rotary_embedding",
]
diff --git a/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu b/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
new file mode 100644
index 00000000000..1dd4c4c5244
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
@@ -0,0 +1,119 @@
+// Reference: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/extension.h>
+
+template <typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_token_rotary_embedding(scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
+ const scalar_t* __restrict__ sin_ptr, int rot_offset,
+ int embed_dim) {
+ int x_index, y_index;
+ scalar_t cos, sin;
+ if (IS_NEOX) {
+ // GPT-NeoX style rotary embedding.
+ x_index = rot_offset;
+ y_index = embed_dim + rot_offset;
+ cos = __ldg(cos_ptr + x_index);
+ sin = __ldg(sin_ptr + x_index);
+ } else {
+ // GPT-J style rotary embedding.
+ x_index = 2 * rot_offset;
+ y_index = 2 * rot_offset + 1;
+ cos = __ldg(cos_ptr + x_index / 2);
+ sin = __ldg(sin_ptr + x_index / 2);
+ }
+
+ const scalar_t x = arr[x_index];
+ const scalar_t y = arr[y_index];
+ arr[x_index] = x * cos - y * sin;
+ arr[y_index] = y * cos + x * sin;
+}
+
+template <typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
+ // head_size] or [num_tokens, num_heads,
+ // head_size]
+ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
+ // head_size] or [num_tokens, num_kv_heads,
+ // head_size]
+ const scalar_t* cache_ptr, const int head_size, const int num_heads,
+ const int num_kv_heads, const int rot_dim, const int token_idx,
+ const int64_t query_stride, const int64_t key_stride) {
+ const int embed_dim = rot_dim / 2;
+ const scalar_t* cos_ptr = cache_ptr;
+ const scalar_t* sin_ptr = cache_ptr + embed_dim;
+
+ const int nq = num_heads * embed_dim;
+ for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+ const int head_idx = i / embed_dim;
+ const int64_t token_head = token_idx * query_stride + head_idx * head_size;
+ const int rot_offset = i % embed_dim;
+ apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+ }
+
+ const int nk = num_kv_heads * embed_dim;
+ for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+ const int head_idx = i / embed_dim;
+ const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+ const int rot_offset = i % embed_dim;
+ apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+ }
+}
+
+template <typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(const int64_t* __restrict__ positions, // [batch_size, seq_len] or
+ // [num_tokens]
+ scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads,
+ // head_size] or [num_tokens, num_heads,
+ // head_size]
+ scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads,
+ // head_size] or [num_tokens, num_kv_heads,
+ // head_size]
+ const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
+ // 2]
+ const int rot_dim, const int64_t query_stride, const int64_t key_stride,
+ const int num_heads, const int num_kv_heads, const int head_size) {
+ // Each thread block is responsible for one token.
+ const int token_idx = blockIdx.x;
+ int64_t pos = positions[token_idx];
+ const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+
+ apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
+ token_idx, query_stride, key_stride);
+}
+
+void rotary_embedding(torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
+ torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
+ // [num_tokens, num_heads * head_size]
+ torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
+ // [num_tokens, num_kv_heads * head_size]
+ int64_t head_size,
+ torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
+ bool is_neox) {
+ int64_t num_tokens = query.numel() / query.size(-1);
+ int rot_dim = cos_sin_cache.size(1);
+ int num_heads = query.size(-1) / head_size;
+ int num_kv_heads = key.size(-1) / head_size;
+ int64_t query_stride = query.stride(-2);
+ int64_t key_stride = key.stride(-2);
+
+ dim3 grid(num_tokens);
+ dim3 block(std::min(num_heads * rot_dim / 2, 512));
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::BFloat16, at::ScalarType::Half, query.scalar_type(), "rotary_embedding", [&] {
+ if (is_neox) {
+ rotary_embedding_kernel<scalar_t, true>
+ <<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+ key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
+ query_stride, key_stride, num_heads, num_kv_heads, head_size);
+ } else {
+ rotary_embedding_kernel<scalar_t, false>
+ <<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
+ key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
+ query_stride, key_stride, num_heads, num_kv_heads, head_size);
+ }
+ });
+}
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index 99d0326cf07..f2ae95d7f79 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -26,6 +26,10 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
+// rotary embedding
+void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
+ torch::Tensor& cos_sin_cache, bool is_neox);
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// trt_reduce
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -39,4 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
// int8_scaled_mm
m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
+ // rotary embedding
+ m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
}
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index 6b35f78a490..b8abd57d39d 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -7,6 +7,7 @@
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
+from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
from sgl_kernel.ops._kernels import (
sampling_scaling_penalties as _sampling_scaling_penalties,
)
@@ -71,3 +72,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
out_dtype,
bias,
)
+
+
+def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
+ return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
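The op mutates `query` and `key` in place and returns nothing, so callers keep their own references. The snippet below is a rough standalone sketch of a direct call; it requires a CUDA build of sgl-kernel, and the cos/sin cache construction here is illustrative rather than the library's own helper.
```python
import torch
from sgl_kernel import rotary_embedding

# Shapes follow the comments in rotary_embedding.cu:
#   query/key: [num_tokens, num_heads * head_size], cos_sin_cache: [max_position, rot_dim]
# with cos values in the first rot_dim // 2 columns and sin values in the rest.
num_tokens, num_heads, head_size, rot_dim, max_pos = 8, 4, 64, 64, 128
positions = torch.arange(num_tokens, device="cuda")  # int64, one position per token
query = torch.randn(num_tokens, num_heads * head_size, device="cuda", dtype=torch.float16)
key = torch.randn(num_tokens, num_heads * head_size, device="cuda", dtype=torch.float16)

inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2, device="cuda", dtype=torch.float32) / rot_dim))
freqs = torch.outer(torch.arange(max_pos, device="cuda", dtype=torch.float32), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).to(torch.float16)

rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)  # in-place, NeoX style
```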
diff --git a/sgl-kernel/tests/test_rotary_embedding.py b/sgl-kernel/tests/test_rotary_embedding.py
new file mode 100644
index 00000000000..1bbe8f1bfeb
--- /dev/null
+++ b/sgl-kernel/tests/test_rotary_embedding.py
@@ -0,0 +1,118 @@
+from typing import Optional, Tuple
+
+import torch
+from vllm.model_executor.layers.rotary_embedding import (
+ RotaryEmbedding as VLLMRotaryEmbedding,
+)
+
+
+class SGLRotaryEmbedding(VLLMRotaryEmbedding):
+
+ def forward_cuda(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ offsets: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ from sgl_kernel import rotary_embedding
+
+ self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+
+ rotary_embedding(
+ positions,
+ query,
+ key,
+ self.head_size,
+ self.cos_sin_cache,
+ self.is_neox_style,
+ )
+ return query, key
+
+
+# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native
+
+
+def test_rotary_embedding():
+ # Test case 1: FP32
+ def run_test(
+ head_size,
+ rotary_dim,
+ max_position,
+ base,
+ is_neox_style,
+ dtype,
+ batch_size,
+ seq_len,
+ num_heads,
+ test_name,
+ ):
+ print(f"\nRunning {test_name}...")
+ # Initialize both implementations
+ sgl_rope = SGLRotaryEmbedding(
+ head_size, rotary_dim, max_position, base, is_neox_style, dtype
+ ).to("cuda")
+ vllm_rope = VLLMRotaryEmbedding(
+ head_size, rotary_dim, max_position, base, is_neox_style, dtype
+ ).to("cuda")
+
+ # Regular forward pass
+ positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
+ query = torch.randn(
+ batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
+ )
+ key = torch.randn(
+ batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
+ )
+
+ # Make copies for both implementations
+ query_sgl = query.clone()
+ key_sgl = key.clone()
+ query_vllm = query.clone()
+ key_vllm = key.clone()
+
+ # Run both implementations
+ query_sgl_out, key_sgl_out = sgl_rope.forward_cuda(
+ positions, query_sgl, key_sgl
+ )
+ query_vllm_out, key_vllm_out = vllm_rope.forward_native(
+ positions, query_vllm, key_vllm
+ )
+
+ # Compare outputs
+ torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3)
+ torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3)
+
+ print(f"{test_name} passed!")
+
+ # Test Case 1: FP32 with larger dimensions
+ run_test(
+ head_size=128,
+ rotary_dim=64,
+ max_position=4096,
+ base=10000,
+ is_neox_style=True,
+ dtype=torch.float32,
+ batch_size=4,
+ seq_len=32,
+ num_heads=8,
+ test_name="FP32 Test",
+ )
+
+ # Test Case 2: BF16 with smaller dimensions
+ run_test(
+ head_size=64,
+ rotary_dim=32,
+ max_position=2048,
+ base=8000,
+ is_neox_style=True,
+ dtype=torch.bfloat16,
+ batch_size=2,
+ seq_len=16,
+ num_heads=4,
+ test_name="BF16 Test",
+ )
+
+
+if __name__ == "__main__":
+ test_rotary_embedding()
From e94fb7cb1094f1210c0ab92a31bcc848e2c2cf7a Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Mon, 20 Jan 2025 21:50:55 +0800
Subject: [PATCH 035/147] chore: bump v0.4.1.post7 (#3009)
---
docker/Dockerfile.rocm | 2 +-
docs/developer/setup_github_runner.md | 4 ++--
docs/start/install.md | 10 +++++-----
python/pyproject.toml | 2 +-
python/sglang/version.py | 2 +-
5 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index 5a6e9770b72..2a55504e612 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -1,5 +1,5 @@
# Usage (to build SGLang ROCm docker image):
-# docker build --build-arg SGL_BRANCH=v0.4.1.post6 -t v0.4.1.post6-rocm620 -f Dockerfile.rocm .
+# docker build --build-arg SGL_BRANCH=v0.4.1.post7 -t v0.4.1.post7-rocm620 -f Dockerfile.rocm .
# default base image
ARG BASE_IMAGE="rocmshared/vllm-rocm:20250114-tuned-elementwise-layernorm"
diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md
index edc03d66183..e805cfce7da 100644
--- a/docs/developer/setup_github_runner.md
+++ b/docs/developer/setup_github_runner.md
@@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04
# Nvidia
docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash
# AMD
-docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.1.post6-rocm620 /bin/bash
+docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.1.post7-rocm620 /bin/bash
# AMD just the last 2 GPUs
-docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.1.post6-rocm620 /bin/bash
+docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.1.post7-rocm620 /bin/bash
```
### Step 2: Configure the runner by `config.sh`
diff --git a/docs/start/install.md b/docs/start/install.md
index 8b84527c4ff..81e2345a673 100644
--- a/docs/start/install.md
+++ b/docs/start/install.md
@@ -13,7 +13,7 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/
## Method 2: From source
```
# Use the last release branch
-git clone -b v0.4.1.post6 https://github.com/sgl-project/sglang.git
+git clone -b v0.4.1.post7 https://github.com/sgl-project/sglang.git
cd sglang
pip install --upgrade pip
@@ -26,7 +26,7 @@ Note: To AMD ROCm system with Instinct/MI GPUs, do following instead:
```
# Use the last release branch
-git clone -b v0.4.1.post6 https://github.com/sgl-project/sglang.git
+git clone -b v0.4.1.post7 https://github.com/sgl-project/sglang.git
cd sglang
pip install --upgrade pip
@@ -51,7 +51,7 @@ docker run --gpus all \
Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below:
```bash
-docker build --build-arg SGL_BRANCH=v0.4.1.post6 -t v0.4.1.post6-rocm620 -f Dockerfile.rocm .
+docker build --build-arg SGL_BRANCH=v0.4.1.post7 -t v0.4.1.post7-rocm620 -f Dockerfile.rocm .
alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \
--shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
@@ -60,11 +60,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d
drun -p 30000:30000 \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HF_TOKEN=" \
- v0.4.1.post6-rocm620 \
+ v0.4.1.post7-rocm620 \
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000
# Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default
-drun v0.4.1.post6-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8
+drun v0.4.1.post7-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8
```
## Method 4: Using docker compose
diff --git a/python/pyproject.toml b/python/pyproject.toml
index f97c9c26679..80cc0e9dc60 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sglang"
-version = "0.4.1.post6"
+version = "0.4.1.post7"
description = "SGLang is yet another fast serving framework for large language models and vision language models."
readme = "README.md"
requires-python = ">=3.8"
diff --git a/python/sglang/version.py b/python/sglang/version.py
index 3a906dbcfff..18ca924974b 100644
--- a/python/sglang/version.py
+++ b/python/sglang/version.py
@@ -1 +1 @@
-__version__ = "0.4.1.post6"
+__version__ = "0.4.1.post7"
From 41a0ccd4f1714ea57b532d7a5f3abe655db2f04e Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Mon, 20 Jan 2025 23:22:19 +0800
Subject: [PATCH 036/147] Add clang-format check to sgl-kernel ci (#3012)
---
.github/workflows/pr-test-sgl-kernel.yml | 14 ++++++++++++++
1 file changed, 14 insertions(+)
diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml
index 4115677dcb0..cacf938a330 100644
--- a/.github/workflows/pr-test-sgl-kernel.yml
+++ b/.github/workflows/pr-test-sgl-kernel.yml
@@ -16,6 +16,20 @@ concurrency:
cancel-in-progress: true
jobs:
+ lint:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Check clang-format
+ uses: DoozyX/clang-format-lint-action@v0.18.1
+ with:
+ source: sgl-kernel
+ extensions: h,c,cpp,hpp,cu,cuh,cc
+ clangFormatVersion: 16
+ style: file
+
unit-test:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 1-gpu-runner
From 5dfcacfcb186ca7d35ce535d0ab3c34df6eff7ea Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Tue, 21 Jan 2025 00:04:12 +0800
Subject: [PATCH 037/147] Add compile flags for cutlass 3.x (#3013)
Co-authored-by: HandH1998 <1335248067@qq.com>
---
sgl-kernel/setup.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 25319af7a65..9f986711338 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -30,6 +30,7 @@ def update_wheel_platform_tag():
root / "src" / "sgl-kernel" / "csrc",
]
nvcc_flags = [
+ "-DNDEBUG",
"-O3",
"-Xcompiler",
"-fPIC",
@@ -37,6 +38,7 @@ def update_wheel_platform_tag():
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
+ "-gencode=arch=compute_90a,code=sm_90a",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
]
From 0311ce8e1ccda984f1afe5a90e1208902ed923fc Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Mon, 20 Jan 2025 12:45:13 -0800
Subject: [PATCH 038/147] [router] Expose worker startup secs & Return error
instead of panic for router init (#3016)
---
.../py_src/sglang_router/launch_router.py | 20 ++++---
.../py_src/sglang_router/launch_server.py | 32 +++++++++--
sgl-router/py_src/sglang_router/router.py | 3 ++
sgl-router/py_test/test_launch_router.py | 1 +
sgl-router/src/lib.rs | 20 ++++---
sgl-router/src/router.rs | 54 +++++++++++++++----
sgl-router/src/server.rs | 41 +++++++-------
7 files changed, 124 insertions(+), 47 deletions(-)
diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py
index 28cd5d11fbb..384e3666db0 100644
--- a/sgl-router/py_src/sglang_router/launch_router.py
+++ b/sgl-router/py_src/sglang_router/launch_router.py
@@ -33,6 +33,7 @@ class RouterArgs:
# Routing policy
policy: str = "cache_aware"
+ worker_startup_timeout_secs: int = 300
cache_threshold: float = 0.5
balance_abs_threshold: int = 32
balance_rel_threshold: float = 1.0001
@@ -87,6 +88,12 @@ def add_cli_args(
choices=["random", "round_robin", "cache_aware"],
help="Load balancing policy to use",
)
+ parser.add_argument(
+ f"--{prefix}worker-startup-timeout-secs",
+ type=int,
+ default=RouterArgs.worker_startup_timeout_secs,
+ help="Timeout in seconds for worker startup",
+ )
parser.add_argument(
f"--{prefix}cache-threshold",
type=float,
@@ -147,6 +154,9 @@ def from_cli_args(
host=args.host,
port=args.port,
policy=getattr(args, f"{prefix}policy"),
+ worker_startup_timeout_secs=getattr(
+ args, f"{prefix}worker_startup_timeout_secs"
+ ),
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
@@ -188,9 +198,10 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
router = Router(
worker_urls=router_args.worker_urls,
- policy=policy_from_str(router_args.policy),
host=router_args.host,
port=router_args.port,
+ policy=policy_from_str(router_args.policy),
+ worker_startup_timeout_secs=router_args.worker_startup_timeout_secs,
cache_threshold=router_args.cache_threshold,
balance_abs_threshold=router_args.balance_abs_threshold,
balance_rel_threshold=router_args.balance_rel_threshold,
@@ -205,7 +216,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
except Exception as e:
logger.error(f"Error starting router: {e}")
- return None
+ raise e
class CustomHelpFormatter(
@@ -239,10 +250,7 @@ def parse_router_args(args: List[str]) -> RouterArgs:
def main() -> None:
router_args = parse_router_args(sys.argv[1:])
- router = launch_router(router_args)
-
- if router is None:
- sys.exit(1)
+ launch_router(router_args)
if __name__ == "__main__":
diff --git a/sgl-router/py_src/sglang_router/launch_server.py b/sgl-router/py_src/sglang_router/launch_server.py
index 93bc2345d18..74353c21edb 100644
--- a/sgl-router/py_src/sglang_router/launch_server.py
+++ b/sgl-router/py_src/sglang_router/launch_server.py
@@ -68,7 +68,7 @@ def run_server(server_args, dp_rank):
# create new process group
os.setpgrp()
- setproctitle(f"sglang::server")
+ setproctitle("sglang::server")
# Set SGLANG_DP_RANK environment variable
os.environ["SGLANG_DP_RANK"] = str(dp_rank)
@@ -120,9 +120,26 @@ def find_available_ports(base_port: int, count: int) -> List[int]:
def cleanup_processes(processes: List[mp.Process]):
for process in processes:
- logger.info(f"Terminating process {process.pid}")
- process.terminate()
- logger.info("All processes terminated")
+ logger.info(f"Terminating process group {process.pid}")
+ try:
+ os.killpg(process.pid, signal.SIGTERM)
+ except ProcessLookupError:
+ # Process group may already be terminated
+ pass
+
+ # Wait for processes to terminate
+ for process in processes:
+ process.join(timeout=5)
+ if process.is_alive():
+ logger.warning(
+ f"Process {process.pid} did not terminate gracefully, forcing kill"
+ )
+ try:
+ os.killpg(process.pid, signal.SIGKILL)
+ except ProcessLookupError:
+ pass
+
+ logger.info("All process groups terminated")
def main():
@@ -173,7 +190,12 @@ def main():
]
# Start the router
- router = launch_router(router_args)
+ try:
+ launch_router(router_args)
+ except Exception as e:
+ logger.error(f"Failed to start router: {e}")
+ cleanup_processes(server_processes)
+ sys.exit(1)
if __name__ == "__main__":
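The cleanup change above assumes each server process called `os.setpgrp()` at startup (as `run_server` does), so the child's pid doubles as its process-group id and `os.killpg` also reaches any grandchildren. A condensed, standalone sketch of the same terminate-then-escalate pattern (POSIX only; the worker here is a hypothetical stand-in):
```python
import multiprocessing as mp
import os
import signal
import time

def _worker():
    os.setpgrp()  # start a new process group; its pgid equals this child's pid
    while True:
        time.sleep(1)

def terminate_group(process: mp.Process, grace_secs: float = 5.0) -> None:
    try:
        os.killpg(process.pid, signal.SIGTERM)  # polite request to the whole group
    except ProcessLookupError:
        pass  # group already gone
    process.join(timeout=grace_secs)
    if process.is_alive():  # escalate if SIGTERM was ignored
        try:
            os.killpg(process.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass

if __name__ == "__main__":
    p = mp.Process(target=_worker)
    p.start()
    time.sleep(0.5)  # give the child time to call os.setpgrp()
    terminate_group(p)
    print("worker exit code:", p.exitcode)
```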
diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py
index 5ce21c3d78e..1665f8a67be 100644
--- a/sgl-router/py_src/sglang_router/router.py
+++ b/sgl-router/py_src/sglang_router/router.py
@@ -17,6 +17,7 @@ class Router:
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001
+ worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5
@@ -37,6 +38,7 @@ def __init__(
policy: PolicyType = PolicyType.RoundRobin,
host: str = "127.0.0.1",
port: int = 3001,
+ worker_startup_timeout_secs: int = 300,
cache_threshold: float = 0.50,
balance_abs_threshold: int = 32,
balance_rel_threshold: float = 1.0001,
@@ -50,6 +52,7 @@ def __init__(
policy=policy,
host=host,
port=port,
+ worker_startup_timeout_secs=worker_startup_timeout_secs,
cache_threshold=cache_threshold,
balance_abs_threshold=balance_abs_threshold,
balance_rel_threshold=balance_rel_threshold,
diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py
index 94912f69491..15549cae72f 100644
--- a/sgl-router/py_test/test_launch_router.py
+++ b/sgl-router/py_test/test_launch_router.py
@@ -28,6 +28,7 @@ def setUp(self):
host="127.0.0.1",
port=30000,
policy="cache_aware",
+ worker_startup_timeout_secs=600,
cache_threshold=0.5,
balance_abs_threshold=32,
balance_rel_threshold=1.0001,
diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs
index 2d8cf4c0c8d..8355f135216 100644
--- a/sgl-router/src/lib.rs
+++ b/sgl-router/src/lib.rs
@@ -17,6 +17,7 @@ struct Router {
port: u16,
worker_urls: Vec<String>,
policy: PolicyType,
+ worker_startup_timeout_secs: u64,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
@@ -34,6 +35,7 @@ impl Router {
policy = PolicyType::RoundRobin,
host = String::from("127.0.0.1"),
port = 3001,
+ worker_startup_timeout_secs = 300,
cache_threshold = 0.50,
balance_abs_threshold = 32,
balance_rel_threshold = 1.0001,
@@ -47,6 +49,7 @@ impl Router {
policy: PolicyType,
host: String,
port: u16,
+ worker_startup_timeout_secs: u64,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
@@ -60,6 +63,7 @@ impl Router {
port,
worker_urls,
policy,
+ worker_startup_timeout_secs,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
@@ -72,9 +76,14 @@ impl Router {
fn start(&self) -> PyResult<()> {
let policy_config = match &self.policy {
- PolicyType::Random => router::PolicyConfig::RandomConfig,
- PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig,
+ PolicyType::Random => router::PolicyConfig::RandomConfig {
+ timeout_secs: self.worker_startup_timeout_secs,
+ },
+ PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
+ timeout_secs: self.worker_startup_timeout_secs,
+ },
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
+ timeout_secs: self.worker_startup_timeout_secs,
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
@@ -93,10 +102,9 @@ impl Router {
max_payload_size: self.max_payload_size,
})
.await
- .unwrap();
- });
-
- Ok(())
+ .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
+ Ok(())
+ })
}
}
diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs
index 08f6cdefa75..6ea791685d4 100644
--- a/sgl-router/src/router.rs
+++ b/sgl-router/src/router.rs
@@ -3,7 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt};
-use log::{debug, info, warn};
+use log::{debug, error, info, warn};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
@@ -17,9 +17,11 @@ pub enum Router {
RoundRobin {
worker_urls: Arc<RwLock<Vec<String>>>,
current_index: AtomicUsize,
+ timeout_secs: u64,
},
Random {
worker_urls: Arc<RwLock<Vec<String>>>,
+ timeout_secs: u64,
},
CacheAware {
/*
@@ -89,36 +91,51 @@ pub enum Router {
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
+ timeout_secs: u64,
_eviction_thread: Option<thread::JoinHandle<()>>,
},
}
#[derive(Debug, Clone)]
pub enum PolicyConfig {
- RandomConfig,
- RoundRobinConfig,
+ RandomConfig {
+ timeout_secs: u64,
+ },
+ RoundRobinConfig {
+ timeout_secs: u64,
+ },
CacheAwareConfig {
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
eviction_interval_secs: u64,
max_tree_size: usize,
+ timeout_secs: u64,
},
}
impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
+ // Get timeout from policy config
+ let timeout_secs = match &policy_config {
+ PolicyConfig::RandomConfig { timeout_secs } => *timeout_secs,
+ PolicyConfig::RoundRobinConfig { timeout_secs } => *timeout_secs,
+ PolicyConfig::CacheAwareConfig { timeout_secs, .. } => *timeout_secs,
+ };
+
// Wait until all workers are healthy
- Self::wait_for_healthy_workers(&worker_urls, 300, 10)?;
+ Self::wait_for_healthy_workers(&worker_urls, timeout_secs, 10)?;
// Create router based on policy...
Ok(match policy_config {
- PolicyConfig::RandomConfig => Router::Random {
+ PolicyConfig::RandomConfig { timeout_secs } => Router::Random {
worker_urls: Arc::new(RwLock::new(worker_urls)),
+ timeout_secs,
},
- PolicyConfig::RoundRobinConfig => Router::RoundRobin {
+ PolicyConfig::RoundRobinConfig { timeout_secs } => Router::RoundRobin {
worker_urls: Arc::new(RwLock::new(worker_urls)),
current_index: std::sync::atomic::AtomicUsize::new(0),
+ timeout_secs,
},
PolicyConfig::CacheAwareConfig {
cache_threshold,
@@ -126,6 +143,7 @@ impl Router {
balance_rel_threshold,
eviction_interval_secs,
max_tree_size,
+ timeout_secs,
} => {
let mut running_queue = HashMap::new();
for url in &worker_urls {
@@ -176,6 +194,7 @@ impl Router {
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
+ timeout_secs,
_eviction_thread: Some(eviction_thread),
}
}
@@ -192,6 +211,10 @@ impl Router {
loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
+ error!(
+ "Timeout {}s waiting for workers to become healthy",
+ timeout_secs
+ );
return Err(format!(
"Timeout {}s waiting for workers to become healthy",
timeout_secs
@@ -238,7 +261,7 @@ impl Router {
fn select_first_worker(&self) -> Result<String, String> {
match self {
Router::RoundRobin { worker_urls, .. }
- | Router::Random { worker_urls }
+ | Router::Random { worker_urls, .. }
| Router::CacheAware { worker_urls, .. } => {
if worker_urls.read().unwrap().is_empty() {
Err("No workers are available".to_string())
@@ -349,6 +372,7 @@ impl Router {
Router::RoundRobin {
worker_urls,
current_index,
+ ..
} => {
let idx = current_index
.fetch_update(
@@ -360,7 +384,7 @@ impl Router {
worker_urls.read().unwrap()[idx].clone()
}
- Router::Random { worker_urls } => worker_urls.read().unwrap()
+ Router::Random { worker_urls, .. } => worker_urls.read().unwrap()
[rand::random::<usize>() % worker_urls.read().unwrap().len()]
.clone(),
@@ -571,13 +595,21 @@ impl Router {
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
let interval_secs = 10; // check every 10 seconds
- let timeout_secs = 300; // 5 minutes
+ let timeout_secs = match self {
+ Router::Random { timeout_secs, .. } => *timeout_secs,
+ Router::RoundRobin { timeout_secs, .. } => *timeout_secs,
+ Router::CacheAware { timeout_secs, .. } => *timeout_secs,
+ };
let start_time = std::time::Instant::now();
let client = reqwest::Client::new();
loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
+ error!(
+ "Timeout {}s waiting for worker {} to become healthy",
+ timeout_secs, worker_url
+ );
return Err(format!(
"Timeout {}s waiting for worker {} to become healthy",
timeout_secs, worker_url
@@ -589,7 +621,7 @@ impl Router {
if res.status().is_success() {
match self {
Router::RoundRobin { worker_urls, .. }
- | Router::Random { worker_urls }
+ | Router::Random { worker_urls, .. }
| Router::CacheAware { worker_urls, .. } => {
info!("Worker {} health check passed", worker_url);
let mut urls = worker_urls.write().unwrap();
@@ -663,7 +695,7 @@ impl Router {
pub fn remove_worker(&self, worker_url: &str) {
match self {
Router::RoundRobin { worker_urls, .. }
- | Router::Random { worker_urls }
+ | Router::Random { worker_urls, .. }
| Router::CacheAware { worker_urls, .. } => {
let mut urls = worker_urls.write().unwrap();
if let Some(index) = urls.iter().position(|url| url == &worker_url) {
diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs
index 09878f07f8e..e3587389e9f 100644
--- a/sgl-router/src/server.rs
+++ b/sgl-router/src/server.rs
@@ -18,14 +18,10 @@ impl AppState {
worker_urls: Vec<String>,
client: reqwest::Client,
policy_config: PolicyConfig,
- ) -> Self {
+ ) -> Result<Self, String> {
// Create router based on policy
- let router = match Router::new(worker_urls, policy_config) {
- Ok(router) => router,
- Err(error) => panic!("Failed to create router: {}", error),
- };
-
- Self { router, client }
+ let router = Router::new(worker_urls, policy_config)?;
+ Ok(Self { router, client })
}
}
@@ -131,6 +127,7 @@ pub struct ServerConfig {
}
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
+ // Initialize logger
Builder::new()
.format(|buf, record| {
use chrono::Local;
@@ -152,24 +149,30 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
)
.init();
+ info!("🚧 Initializing router on {}:{}", config.host, config.port);
+ info!("🚧 Initializing workers on {:?}", config.worker_urls);
+ info!("🚧 Policy Config: {:?}", config.policy_config);
+ info!(
+ "🚧 Max payload size: {} MB",
+ config.max_payload_size / (1024 * 1024)
+ );
+
let client = reqwest::Client::builder()
.build()
.expect("Failed to create HTTP client");
- let app_state = web::Data::new(AppState::new(
- config.worker_urls.clone(),
- client,
- config.policy_config.clone(),
- ));
-
- info!("✅ Starting router on {}:{}", config.host, config.port);
- info!("✅ Serving Worker URLs: {:?}", config.worker_urls);
- info!("✅ Policy Config: {:?}", config.policy_config);
- info!(
- "✅ Max payload size: {} MB",
- config.max_payload_size / (1024 * 1024)
+ let app_state = web::Data::new(
+ AppState::new(
+ config.worker_urls.clone(),
+ client,
+ config.policy_config.clone(),
+ )
+ .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
);
+ info!("✅ Serving router on {}:{}", config.host, config.port);
+ info!("✅ Serving workers on {:?}", config.worker_urls);
+
HttpServer::new(move || {
App::new()
.app_data(app_state.clone())
From 3a8428ecaa6375996de0142afd73df2f98c4cc23 Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Mon, 20 Jan 2025 14:36:54 -0800
Subject: [PATCH 039/147] [router] Expose worker startup interval (#3019)
---
.../py_src/sglang_router/launch_router.py | 11 ++++
sgl-router/py_src/sglang_router/router.py | 3 +
sgl-router/py_test/test_launch_router.py | 1 +
sgl-router/src/lib.rs | 7 +++
sgl-router/src/router.rs | 63 +++++++++++++++----
5 files changed, 72 insertions(+), 13 deletions(-)
diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py
index 384e3666db0..38f1fbba2dc 100644
--- a/sgl-router/py_src/sglang_router/launch_router.py
+++ b/sgl-router/py_src/sglang_router/launch_router.py
@@ -34,6 +34,7 @@ class RouterArgs:
# Routing policy
policy: str = "cache_aware"
worker_startup_timeout_secs: int = 300
+ worker_startup_check_interval: int = 10
cache_threshold: float = 0.5
balance_abs_threshold: int = 32
balance_rel_threshold: float = 1.0001
@@ -94,6 +95,12 @@ def add_cli_args(
default=RouterArgs.worker_startup_timeout_secs,
help="Timeout in seconds for worker startup",
)
+ parser.add_argument(
+ f"--{prefix}worker-startup-check-interval",
+ type=int,
+ default=RouterArgs.worker_startup_check_interval,
+ help="Interval in seconds between checks for worker startup",
+ )
parser.add_argument(
f"--{prefix}cache-threshold",
type=float,
@@ -157,6 +164,9 @@ def from_cli_args(
worker_startup_timeout_secs=getattr(
args, f"{prefix}worker_startup_timeout_secs"
),
+ worker_startup_check_interval=getattr(
+ args, f"{prefix}worker_startup_check_interval"
+ ),
cache_threshold=getattr(args, f"{prefix}cache_threshold"),
balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
@@ -202,6 +212,7 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
port=router_args.port,
policy=policy_from_str(router_args.policy),
worker_startup_timeout_secs=router_args.worker_startup_timeout_secs,
+ worker_startup_check_interval=router_args.worker_startup_check_interval,
cache_threshold=router_args.cache_threshold,
balance_abs_threshold=router_args.balance_abs_threshold,
balance_rel_threshold=router_args.balance_rel_threshold,
diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py
index 1665f8a67be..b8757168b24 100644
--- a/sgl-router/py_src/sglang_router/router.py
+++ b/sgl-router/py_src/sglang_router/router.py
@@ -18,6 +18,7 @@ class Router:
host: Host address to bind the router server. Default: '127.0.0.1'
port: Port number to bind the router server. Default: 3001
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
+ worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10
cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker
if the match rate exceeds threshold, otherwise routes to the worker with the smallest
tree. Default: 0.5
@@ -39,6 +40,7 @@ def __init__(
host: str = "127.0.0.1",
port: int = 3001,
worker_startup_timeout_secs: int = 300,
+ worker_startup_check_interval: int = 10,
cache_threshold: float = 0.50,
balance_abs_threshold: int = 32,
balance_rel_threshold: float = 1.0001,
@@ -53,6 +55,7 @@ def __init__(
host=host,
port=port,
worker_startup_timeout_secs=worker_startup_timeout_secs,
+ worker_startup_check_interval=worker_startup_check_interval,
cache_threshold=cache_threshold,
balance_abs_threshold=balance_abs_threshold,
balance_rel_threshold=balance_rel_threshold,
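For illustration, the expanded constructor can be driven directly from Python. This is a sketch under the assumption that workers are already (or will soon be) serving at the placeholder URLs; the router blocks until every worker passes its health check or `worker_startup_timeout_secs` elapses, polling every `worker_startup_check_interval` seconds.
```python
# Placeholder worker URLs; the import path follows this repo's package layout.
from sglang_router.router import Router

router = Router(
    worker_urls=["http://127.0.0.1:30001", "http://127.0.0.1:30002"],
    host="127.0.0.1",
    port=3001,
    worker_startup_timeout_secs=600,   # allow slow model loads up to 10 minutes
    worker_startup_check_interval=10,  # probe worker health every 10 seconds
)
router.start()  # serves until interrupted
```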
diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py
index 15549cae72f..27ed64d6e66 100644
--- a/sgl-router/py_test/test_launch_router.py
+++ b/sgl-router/py_test/test_launch_router.py
@@ -29,6 +29,7 @@ def setUp(self):
port=30000,
policy="cache_aware",
worker_startup_timeout_secs=600,
+ worker_startup_check_interval=10,
cache_threshold=0.5,
balance_abs_threshold=32,
balance_rel_threshold=1.0001,
diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs
index 8355f135216..ba9aeac1fef 100644
--- a/sgl-router/src/lib.rs
+++ b/sgl-router/src/lib.rs
@@ -18,6 +18,7 @@ struct Router {
worker_urls: Vec<String>,
policy: PolicyType,
worker_startup_timeout_secs: u64,
+ worker_startup_check_interval: u64,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
@@ -36,6 +37,7 @@ impl Router {
host = String::from("127.0.0.1"),
port = 3001,
worker_startup_timeout_secs = 300,
+ worker_startup_check_interval = 10,
cache_threshold = 0.50,
balance_abs_threshold = 32,
balance_rel_threshold = 1.0001,
@@ -50,6 +52,7 @@ impl Router {
host: String,
port: u16,
worker_startup_timeout_secs: u64,
+ worker_startup_check_interval: u64,
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
@@ -64,6 +67,7 @@ impl Router {
worker_urls,
policy,
worker_startup_timeout_secs,
+ worker_startup_check_interval,
cache_threshold,
balance_abs_threshold,
balance_rel_threshold,
@@ -78,12 +82,15 @@ impl Router {
let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig {
timeout_secs: self.worker_startup_timeout_secs,
+ interval_secs: self.worker_startup_check_interval,
},
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
timeout_secs: self.worker_startup_timeout_secs,
+ interval_secs: self.worker_startup_check_interval,
},
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
timeout_secs: self.worker_startup_timeout_secs,
+ interval_secs: self.worker_startup_check_interval,
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs
index 6ea791685d4..5bbffc74ccf 100644
--- a/sgl-router/src/router.rs
+++ b/sgl-router/src/router.rs
@@ -18,10 +18,12 @@ pub enum Router {
worker_urls: Arc<RwLock<Vec<String>>>,
current_index: AtomicUsize,
timeout_secs: u64,
+ interval_secs: u64,
},
Random {
worker_urls: Arc<RwLock<Vec<String>>>,
timeout_secs: u64,
+ interval_secs: u64,
},
CacheAware {
/*
@@ -92,6 +94,7 @@ pub enum Router {
balance_abs_threshold: usize,
balance_rel_threshold: f32,
timeout_secs: u64,
+ interval_secs: u64,
_eviction_thread: Option<thread::JoinHandle<()>>,
},
}
@@ -100,9 +103,11 @@ pub enum Router {
pub enum PolicyConfig {
RandomConfig {
timeout_secs: u64,
+ interval_secs: u64,
},
RoundRobinConfig {
timeout_secs: u64,
+ interval_secs: u64,
},
CacheAwareConfig {
cache_threshold: f32,
@@ -111,31 +116,50 @@ pub enum PolicyConfig {
eviction_interval_secs: u64,
max_tree_size: usize,
timeout_secs: u64,
+ interval_secs: u64,
},
}
impl Router {
pub fn new(worker_urls: Vec<String>, policy_config: PolicyConfig) -> Result<Self, String> {
- // Get timeout from policy config
- let timeout_secs = match &policy_config {
- PolicyConfig::RandomConfig { timeout_secs } => *timeout_secs,
- PolicyConfig::RoundRobinConfig { timeout_secs } => *timeout_secs,
- PolicyConfig::CacheAwareConfig { timeout_secs, .. } => *timeout_secs,
+ // Get timeout and interval from policy config
+ let (timeout_secs, interval_secs) = match &policy_config {
+ PolicyConfig::RandomConfig {
+ timeout_secs,
+ interval_secs,
+ } => (*timeout_secs, *interval_secs),
+ PolicyConfig::RoundRobinConfig {
+ timeout_secs,
+ interval_secs,
+ } => (*timeout_secs, *interval_secs),
+ PolicyConfig::CacheAwareConfig {
+ timeout_secs,
+ interval_secs,
+ ..
+ } => (*timeout_secs, *interval_secs),
};
// Wait until all workers are healthy
- Self::wait_for_healthy_workers(&worker_urls, timeout_secs, 10)?;
+ Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
// Create router based on policy...
Ok(match policy_config {
- PolicyConfig::RandomConfig { timeout_secs } => Router::Random {
+ PolicyConfig::RandomConfig {
+ timeout_secs,
+ interval_secs,
+ } => Router::Random {
worker_urls: Arc::new(RwLock::new(worker_urls)),
timeout_secs,
+ interval_secs,
},
- PolicyConfig::RoundRobinConfig { timeout_secs } => Router::RoundRobin {
+ PolicyConfig::RoundRobinConfig {
+ timeout_secs,
+ interval_secs,
+ } => Router::RoundRobin {
worker_urls: Arc::new(RwLock::new(worker_urls)),
current_index: std::sync::atomic::AtomicUsize::new(0),
timeout_secs,
+ interval_secs,
},
PolicyConfig::CacheAwareConfig {
cache_threshold,
@@ -144,6 +168,7 @@ impl Router {
eviction_interval_secs,
max_tree_size,
timeout_secs,
+ interval_secs,
} => {
let mut running_queue = HashMap::new();
for url in &worker_urls {
@@ -195,6 +220,7 @@ impl Router {
balance_abs_threshold,
balance_rel_threshold,
timeout_secs,
+ interval_secs,
_eviction_thread: Some(eviction_thread),
}
}
@@ -594,11 +620,22 @@ impl Router {
}
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
- let interval_secs = 10; // check every 10 seconds
- let timeout_secs = match self {
- Router::Random { timeout_secs, .. } => *timeout_secs,
- Router::RoundRobin { timeout_secs, .. } => *timeout_secs,
- Router::CacheAware { timeout_secs, .. } => *timeout_secs,
+ let (timeout_secs, interval_secs) = match self {
+ Router::Random {
+ timeout_secs,
+ interval_secs,
+ ..
+ } => (*timeout_secs, *interval_secs),
+ Router::RoundRobin {
+ timeout_secs,
+ interval_secs,
+ ..
+ } => (*timeout_secs, *interval_secs),
+ Router::CacheAware {
+ timeout_secs,
+ interval_secs,
+ ..
+ } => (*timeout_secs, *interval_secs),
};
let start_time = std::time::Instant::now();
From 3ad4cd491575d9ac6f9faf7582b418c6d6bb34e6 Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Mon, 20 Jan 2025 14:38:06 -0800
Subject: [PATCH 040/147] bump router to 0.1.3 (#3020)
---
sgl-router/py_src/sglang_router/version.py | 2 +-
sgl-router/pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/sgl-router/py_src/sglang_router/version.py b/sgl-router/py_src/sglang_router/version.py
index b3f4756216d..ae7362549b3 100644
--- a/sgl-router/py_src/sglang_router/version.py
+++ b/sgl-router/py_src/sglang_router/version.py
@@ -1 +1 @@
-__version__ = "0.1.2"
+__version__ = "0.1.3"
diff --git a/sgl-router/pyproject.toml b/sgl-router/pyproject.toml
index 90e82cecf37..3a00d047200 100644
--- a/sgl-router/pyproject.toml
+++ b/sgl-router/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sglang-router"
-version = "0.1.2"
+version = "0.1.3"
description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances."
authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}]
requires-python = ">=3.8"
From af6c5357d5cb283341ef21f8aeb093753bd10a2b Mon Sep 17 00:00:00 2001
From: Enrique Shockwave <33002121+qeternity@users.noreply.github.com>
Date: Mon, 20 Jan 2025 22:40:12 +0000
Subject: [PATCH 041/147] deepseek v3 and r1 chat template (#3015)
---
python/sglang/lang/chat_template.py | 31 +++++++++++++++++++++++++++++
1 file changed, 31 insertions(+)
diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index 845e1e52dda..a2c91c561c2 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -354,6 +354,37 @@ def get_chat_template_by_model_path(model_path):
)
+register_chat_template(
+ ChatTemplate(
+ name="deepseek-v3",
+ default_system_prompt=None,
+ role_prefix_and_suffix={
+ "system": (
+ "",
+ "",
+ ),
+ "user": (
+ "<|User|>",
+ "",
+ ),
+ "assistant": (
+ "<|Assistant|>",
+ "<|end▁of▁sentence|>",
+ ),
+ },
+ stop_str=("<|end▁of▁sentence|>",),
+ )
+)
+
+
+@register_chat_template_matching_function
+def match_deepseek(model_path: str):
+ if (
+ "deepseek-v3" in model_path.lower() or "deepseek-r1" in model_path.lower()
+ ) and "base" not in model_path.lower():
+ return get_chat_template("deepseek-v3")
+
+
@register_chat_template_matching_function
def match_dbrx(model_path: str):
if "dbrx" in model_path.lower() and "instruct" in model_path.lower():
From da4e8b389280009833cd0a01e8cf5ce4746f77fc Mon Sep 17 00:00:00 2001
From: Hui Liu <96135754+hliuca@users.noreply.github.com>
Date: Mon, 20 Jan 2025 14:40:45 -0800
Subject: [PATCH 042/147] enable kv_scale remap (#3017)
---
python/sglang/srt/models/commandr.py | 10 +++++++++-
python/sglang/srt/models/dbrx.py | 10 +++++++++-
2 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py
index 151087732f0..6d2e6d2bb41 100644
--- a/python/sglang/srt/models/commandr.py
+++ b/python/sglang/srt/models/commandr.py
@@ -61,7 +61,10 @@
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+)
from sglang.srt.utils import get_compiler_backend, set_weight_attrs
@@ -372,6 +375,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py
index cedc9639220..92fc679391f 100644
--- a/python/sglang/srt/models/dbrx.py
+++ b/python/sglang/srt/models/dbrx.py
@@ -42,7 +42,10 @@
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+)
from sglang.srt.utils import set_weight_attrs
@@ -411,6 +414,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader(param, loaded_weight, weight_name)
break
else:
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
From 949b3fbfce7ce3bc3a2e971bc6fbcd501dcc6ece Mon Sep 17 00:00:00 2001
From: Hongpeng Guo
Date: Mon, 20 Jan 2025 16:50:25 -0800
Subject: [PATCH 043/147] [Doc] Update doc of custom logit processor (#3021)
Signed-off-by: Hongpeng Guo
---
docs/references/sampling_params.md | 68 +++++++++++++++++++++++++
python/sglang/srt/managers/io_struct.py | 11 ++--
2 files changed, 75 insertions(+), 4 deletions(-)
diff --git a/docs/references/sampling_params.md b/docs/references/sampling_params.md
index cdc53da61a4..77d7c9f82e7 100644
--- a/docs/references/sampling_params.md
+++ b/docs/references/sampling_params.md
@@ -32,6 +32,20 @@ class GenerateReqInput:
return_text_in_logprobs: bool = False
# Whether to stream output.
stream: bool = False
+ # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+ log_metrics: bool = True
+
+ # The modalities of the image data [image, multi-images, video]
+ modalities: Optional[List[str]] = None
+ # LoRA related
+ lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+
+ # Session info for continual prompting
+ session_params: Optional[Union[List[Dict], Dict]] = None
+ # Custom logit processor for advanced sampling control. Must be a serialized instance
+ # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+ # Use the processor's `to_str()` method to generate the serialized string.
+ custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
```
The `sampling_params` follows this format
@@ -90,6 +104,14 @@ repetition_penalty: float = 1.0,
# difficult to infer the correct token ID by given `stop` strings.
# Must be 0 <= value < max_new_tokens. Setting to 0 (default) will disable this penalty.
min_new_tokens: int = 0,
+
+
+## Custom Parameters for Custom Logit Processor.
+# A dictionary of custom parameters for the custom logit processor.
+# The custom logit processor takes a list of dictionaries as input, where each
+# dictionary is the custom parameters for one token in a batch of the input.
+# See also python/sglang/srt/sampling/custom_logit_processor.py
+custom_params: Optional[Dict[str, Any]] = None,
```
## Examples
@@ -253,3 +275,49 @@ response = requests.post(
)
print(response.json())
```
+### Custom Logit Processor
+Launch a server with `--enable-custom-logit-processor` flag on.
+```
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --enable-custom-logit-processor
+```
+
+Define a custom logit processor that will always sample a specific token id.
+```python
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+
+class DeterministicLogitProcessor(CustomLogitProcessor):
+ """A dummy logit processor that changes the logits to always
+ sample the given token id.
+ """
+
+ def __call__(self, logits, custom_param_list):
+ # Check that the number of logits matches the number of custom parameters
+ assert logits.shape[0] == len(custom_param_list)
+ key = "token_id"
+
+ for i, param_dict in enumerate(custom_param_list):
+ # Mask all other tokens
+ logits[i, :] = -float("inf")
+ # Assign highest probability to the specified token
+ logits[i, param_dict[key]] = 0.0
+ return logits
+```
+
+Send a request
+```python
+import requests
+
+response = requests.post(
+ "http://localhost:30000/generate",
+ json={
+ "text": "The capital of France is",
+ "custom_logit_processor": DeterministicLogitProcessor().to_str(),
+ "sampling_params": {
+ "temperature": 0.0,
+ "max_new_tokens": 32,
+ "custom_params": {"token_id": 5},
+ },
+ },
+)
+print(response.json())
+```
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 9183239838d..eee9b6722d4 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -69,8 +69,10 @@ class GenerateReqInput:
# Session info for continual prompting
session_params: Optional[Union[List[Dict], Dict]] = None
- # Custom logit processor (serialized function)
- custom_logit_processor: Optional[Union[List[Optional[str]], Optional[str]]] = None
+ # Custom logit processor for advanced sampling control. Must be a serialized instance
+ # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+ # Use the processor's `to_str()` method to generate the serialized string.
+ custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
def normalize_batch_and_arguments(self):
if (
@@ -248,8 +250,9 @@ class TokenizedGenerateReqInput:
# Session info for continual prompting
session_params: Optional[SessionParams] = None
- # Custom logit processor (serialized function)
- # TODO (hpguo): Add an example and update doc string here
+ # Custom logit processor for advanced sampling control. Must be a serialized instance
+ # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
+ # Use the processor's `to_str()` method to generate the serialized string.
custom_logit_processor: Optional[str] = None
From 60b2a44a80d1bb168dcf28a1980c09f2d3364153 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 20 Jan 2025 16:50:39 -0800
Subject: [PATCH 044/147] Fix flaky tests in test_programs.py (#3022)
---
python/sglang/test/test_programs.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py
index 219ed3cf6ec..361bbaed00c 100644
--- a/python/sglang/test/test_programs.py
+++ b/python/sglang/test/test_programs.py
@@ -535,7 +535,7 @@ def few_shot_hellaswag(s, question, choices):
# Compute accuracy
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
- assert np.abs(accuracy_gen - accuracy) < 0.05
+ assert np.abs(accuracy_gen - accuracy) < 0.1
assert np.abs(latency_gen - latency) < 1
return accuracy, latency
From b730aa6b9e577670f1967b65ea9f24a32e0aca8d Mon Sep 17 00:00:00 2001
From: 996_icu <85502239+josephydu@users.noreply.github.com>
Date: Tue, 21 Jan 2025 09:46:43 +0800
Subject: [PATCH 045/147] [EAGLE] Fix some boundary situation when retract reqs
and req's max token = 1 (#2939)
Co-authored-by: josephyou
---
python/sglang/srt/managers/schedule_batch.py | 2 ++
python/sglang/srt/speculative/eagle_utils.py | 8 ++++++++
2 files changed, 10 insertions(+)
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 040afe3d324..d9af8151534 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1112,6 +1112,8 @@ def filter_batch(
self.has_grammar = any(req.grammar for req in self.reqs)
self.sampling_info.filter_batch(keep_indices, new_indices)
+ if self.spec_info:
+ self.spec_info.filter_batch(new_indices)
def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index 1a324000cb2..ac16f6c532e 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -228,6 +228,14 @@ def prepare_for_extend(self, batch: ScheduleBatch):
assert len(batch.extend_lens) == 1
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
+ def filter_batch(
+ self,
+ new_indices: torch.Tensor,
+ ):
+ self.sample_output = self.sample_output[: len(new_indices)]
+ self.hidden_states = self.hidden_states[: len(new_indices)]
+ self.verified_id = self.verified_id[: len(new_indices)]
+
def prepare_for_decode(self, batch: ScheduleBatch):
prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
top = torch.topk(prob, self.topk, dim=-1)
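For context on the fix: the new `EAGLEDraftInput.filter_batch` drops speculative state for retracted requests by keeping only the first `len(new_indices)` rows, which assumes the surviving requests sit at the front of the batch. A standalone sketch of that truncation:

    import torch

    # Speculative state for a batch of 4 requests (shapes follow the comments in eagle_utils.py).
    hidden_states = torch.arange(12.0).reshape(4, 3)   # (b, hidden_size)
    verified_id = torch.tensor([11, 22, 33, 44])        # (b,)
    new_indices = torch.tensor([0, 1, 2])               # 3 requests survive the retract

    hidden_states = hidden_states[: len(new_indices)]   # now (3, 3)
    verified_id = verified_id[: len(new_indices)]       # now (3,)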
From d2571dd5c7b4cee5f690dc3403cfeef0ca7115b7 Mon Sep 17 00:00:00 2001
From: Hui Liu <96135754+hliuca@users.noreply.github.com>
Date: Mon, 20 Jan 2025 19:21:41 -0800
Subject: [PATCH 046/147] Enable Cohere2 Models (#3018)
---
python/sglang/srt/models/commandr.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py
index 6d2e6d2bb41..e4b291b66cb 100644
--- a/python/sglang/srt/models/commandr.py
+++ b/python/sglang/srt/models/commandr.py
@@ -386,4 +386,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loaded_params.add(name)
-EntryClass = CohereForCausalLM
+class Cohere2ForCausalLM(CohereForCausalLM):
+ pass
+
+
+EntryClass = [CohereForCausalLM, Cohere2ForCausalLM]
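A short note on why a two-element list is enough here: sglang resolves a checkpoint's `architectures` entry to a model class by name, and `EntryClass` may be a single class or a list of classes. A hedged sketch of that lookup (the registry details are an assumption, not shown in this patch):

    # Hypothetical registry-building loop, not code from this patch.
    from typing import Dict, List, Type, Union

    def register(entry_class: Union[Type, List[Type]], registry: Dict[str, Type]) -> None:
        classes = entry_class if isinstance(entry_class, list) else [entry_class]
        for cls in classes:
            # Keyed by class name, so configs listing "Cohere2ForCausalLM" resolve here.
            registry[cls.__name__] = cls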
From 287d07a669d3fd0b0650959b0e35c8e886513824 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 20 Jan 2025 20:25:13 -0800
Subject: [PATCH 047/147] Misc fixes for eagle (flush_cache, CPU overhead)
(#3014)
---
python/sglang/bench_offline_throughput.py | 28 +++---
python/sglang/bench_serving.py | 91 ++++++++++---------
python/sglang/srt/managers/scheduler.py | 11 ++-
.../srt/model_executor/forward_batch_info.py | 8 +-
python/sglang/srt/server.py | 4 +-
python/sglang/srt/speculative/eagle_utils.py | 43 ++++++---
python/sglang/srt/speculative/eagle_worker.py | 24 +++--
python/sglang/srt/utils.py | 7 ++
python/sglang/test/test_programs.py | 3 +-
python/sglang/test/test_utils.py | 7 +-
test/lang/test_srt_backend.py | 1 +
11 files changed, 132 insertions(+), 95 deletions(-)
diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py
index b0a715e61cc..9d56ff07c8b 100644
--- a/python/sglang/bench_offline_throughput.py
+++ b/python/sglang/bench_offline_throughput.py
@@ -49,12 +49,13 @@ class BenchArgs:
gsp_system_prompt_len: int = 2048
gsp_question_len: int = 128
gsp_output_len: int = 256
+ seed: int = 1
disable_ignore_eos: bool = False
extra_request_body: Optional[str] = None
- seed: int = 1
+ apply_chat_template: bool = False
+ profile: bool = False
skip_warmup: bool = False
do_not_exit: bool = False
- profile: bool = False
@staticmethod
def add_cli_args(parser: argparse.ArgumentParser):
@@ -141,20 +142,31 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=BenchArgs.gsp_output_len,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
+ parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--disable-ignore-eos",
- type=bool,
- default=BenchArgs.disable_ignore_eos,
+ action="store_true",
help="Disable ignore EOS token",
)
parser.add_argument(
"--extra-request-body",
metavar='{"key1": "value1", "key2": "value2"}',
type=str,
+ default=BenchArgs.extra_request_body,
help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.",
)
- parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+ parser.add_argument(
+ "--apply-chat-template",
+ action="store_true",
+ help="Apply chat template",
+ )
+ parser.add_argument(
+ "--profile",
+ action="store_true",
+ help="Use Torch Profiler. The endpoint must be launched with "
+ "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+ )
parser.add_argument(
"--skip-warmup",
action="store_true",
@@ -165,12 +177,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
)
- parser.add_argument(
- "--profile",
- action="store_true",
- help="Use Torch Profiler. The endpoint must be launched with "
- "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
- )
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 991b4ddcf1a..10ce965be74 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -453,6 +453,7 @@ def get_dataset(args, tokenizer):
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
context_len=args.sharegpt_context_len,
+ apply_chat_template=args.apply_chat_template,
)
elif args.dataset_name == "random":
input_requests = sample_random_requests(
@@ -517,6 +518,7 @@ class BenchmarkMetrics:
median_e2e_latency_ms: float
std_e2e_latency_ms: float
p99_e2e_latency_ms: float
+ concurrency: float
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
@@ -562,6 +564,7 @@ def sample_sharegpt_requests(
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
context_len: Optional[int] = None,
+ apply_chat_template=False,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
@@ -592,6 +595,15 @@ def sample_sharegpt_requests(
# Tokenize the prompts and completions.
prompt = dataset[i][0]
+
+ if apply_chat_template:
+ prompt = tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt}],
+ add_generation_prompt=True,
+ tokenize=False,
+ )
+ prompt = prompt.replace(tokenizer.bos_token, "")
+
prompt_token_ids = tokenizer.encode(prompt)
completion = dataset[i][1]
completion_token_ids = tokenizer.encode(completion)
@@ -600,7 +612,7 @@ def sample_sharegpt_requests(
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
)
- if prompt_len < 1 or output_len < 1:
+ if prompt_len < 2 or output_len < 2:
# Prune too short sequences.
continue
@@ -880,6 +892,7 @@ def calculate_metrics(
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
+ concurrency=np.sum(e2e_latencies) / dur_s,
)
return metrics, output_lens
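The new `concurrency` metric divides the summed per-request end-to-end latencies by the benchmark duration, which estimates the average number of requests in flight over the run (Little's law). A quick numeric check:

    import numpy as np

    e2e_latencies = [2.0, 2.0, 2.0, 2.0]   # seconds per request
    dur_s = 4.0                             # benchmark wall-clock duration
    concurrency = np.sum(e2e_latencies) / dur_s
    assert concurrency == 2.0               # on average, two requests were in flight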
@@ -1031,6 +1044,7 @@ async def limited_request_func(request_func_input, pbar):
"Total token throughput (tok/s):", metrics.total_throughput
)
)
+ print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
print(
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
@@ -1062,13 +1076,24 @@ async def limited_request_func(request_func_input, pbar):
and metrics.output_throughput is not None
):
result = {
+ # Arguments
"backend": args.backend,
"dataset_name": args.dataset_name,
"request_rate": request_rate,
"max_concurrency": max_concurrency,
+ "sharegpt_output_len": args.sharegpt_output_len,
+ "random_input_len": args.random_input_len,
+ "random_output_len": args.random_output_len,
+ "random_range_ratio": args.random_range_ratio,
+ # Results
+ "duration": benchmark_duration,
+ "completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"total_output_tokens_retokenized": metrics.total_output_retokenized,
+ "request_throughput": metrics.request_throughput,
+ "input_throughput": metrics.input_throughput,
+ "output_throughput": metrics.output_throughput,
"mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
"median_e2e_latency_ms": metrics.median_e2e_latency_ms,
"std_e2e_latency_ms": metrics.std_e2e_latency_ms,
@@ -1085,14 +1110,7 @@ async def limited_request_func(request_func_input, pbar):
"median_itl_ms": metrics.median_itl_ms,
"std_itl_ms": metrics.std_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
- "input_throughput": metrics.input_throughput,
- "output_throughput": metrics.output_throughput,
- "sharegpt_output_len": args.sharegpt_output_len,
- "random_input_len": args.random_input_len,
- "random_output_len": args.random_output_len,
- "random_range_ratio": args.random_range_ratio,
- "duration": benchmark_duration,
- "completed": metrics.completed,
+ "concurrency": metrics.concurrency,
}
else:
print(f"Error running benchmark for request rate: {request_rate}")
@@ -1112,36 +1130,16 @@ async def limited_request_func(request_func_input, pbar):
with open(output_file_name, "a") as file:
file.write(json.dumps(result) + "\n")
- result = {
- "duration": benchmark_duration,
- "completed": metrics.completed,
- "total_input_tokens": metrics.total_input,
- "total_output_tokens": metrics.total_output,
- "total_output_tokens_retokenized": metrics.total_output_retokenized,
- "request_throughput": metrics.request_throughput,
- "input_throughput": metrics.input_throughput,
- "output_throughput": metrics.output_throughput,
- "mean_ttft_ms": metrics.mean_ttft_ms,
- "median_ttft_ms": metrics.median_ttft_ms,
- "std_ttft_ms": metrics.std_ttft_ms,
- "p99_ttft_ms": metrics.p99_ttft_ms,
- "mean_tpot_ms": metrics.mean_tpot_ms,
- "median_tpot_ms": metrics.median_tpot_ms,
- "std_tpot_ms": metrics.std_tpot_ms,
- "p99_tpot_ms": metrics.p99_tpot_ms,
- "mean_itl_ms": metrics.mean_itl_ms,
- "median_itl_ms": metrics.median_itl_ms,
- "std_itl_ms": metrics.std_itl_ms,
- "p99_itl_ms": metrics.p99_itl_ms,
- "input_lens": [output.prompt_len for output in outputs],
- "output_lens": output_lens,
- "ttfts": [output.ttft for output in outputs],
- "itls": [output.itl for output in outputs],
- "generated_texts": [output.generated_text for output in outputs],
- "errors": [output.error for output in outputs],
- "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
- "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
- }
+ result.update(
+ {
+ "input_lens": [output.prompt_len for output in outputs],
+ "output_lens": output_lens,
+ "ttfts": [output.ttft for output in outputs],
+ "itls": [output.itl for output in outputs],
+ "generated_texts": [output.generated_text for output in outputs],
+ "errors": [output.error for output in outputs],
+ }
+ )
return result
@@ -1422,7 +1420,6 @@ def set_ulimit(target_soft_limit=65535):
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.",
)
- parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
"--multi",
action="store_true",
@@ -1446,14 +1443,15 @@ def set_ulimit(target_soft_limit=65535):
help="Disable streaming mode.",
)
parser.add_argument(
- "--disable-ignore-eos",
+ "--return-logprob",
action="store_true",
- help="Disable ignoring EOS.",
+ help="Return logprob.",
)
+ parser.add_argument("--seed", type=int, default=1, help="The random seed.")
parser.add_argument(
- "--return-logprob",
+ "--disable-ignore-eos",
action="store_true",
- help="Return logprob.",
+ help="Disable ignoring EOS.",
)
parser.add_argument(
"--extra-request-body",
@@ -1462,6 +1460,11 @@ def set_ulimit(target_soft_limit=65535):
help="Append given JSON object to the request payload. You can use this to specify"
"additional generate params like sampling params.",
)
+ parser.add_argument(
+ "--apply-chat-template",
+ action="store_true",
+ help="Apply chat template",
+ )
parser.add_argument(
"--profile",
action="store_true",
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index fba8a67ecf4..85bd1c2a4ad 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1023,7 +1023,7 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]:
)
# Check for jump-forward
- if not self.disable_jump_forward:
+ if not self.disable_jump_forward and batch.has_grammar:
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self.waiting_queue.extend(jump_forward_reqs)
if batch.is_empty():
@@ -1564,6 +1564,15 @@ def flush_cache(self):
self.grammar_backend.reset()
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()
+
+ if not self.spec_algorithm.is_none():
+ self.draft_worker.model_runner.req_to_token_pool.clear()
+ self.draft_worker.model_runner.token_to_kv_pool.clear()
+
+ self.num_generated_tokens = 0
+ self.forward_ct_decode = 0
+ self.spec_num_total_accepted_tokens = 0
+ self.spec_num_total_forward_ct = 0
torch.cuda.empty_cache()
logger.info("Cache flushed successfully!")
if_success = True
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 354408ab343..8ef5c57b891 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -282,6 +282,9 @@ def init_new(
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
+ req_to_token_pool=model_runner.req_to_token_pool,
+ token_to_kv_pool=model_runner.token_to_kv_pool,
+ attn_backend=model_runner.attn_backend,
spec_algorithm=batch.spec_algorithm,
spec_info=batch.spec_info,
capture_hidden_mode=batch.capture_hidden_mode,
@@ -336,11 +339,6 @@ def init_new(
if model_runner.model_is_mrope:
ret.compute_mrope_positions(model_runner, batch)
- # Init attention information
- ret.req_to_token_pool = model_runner.req_to_token_pool
- ret.token_to_kv_pool = model_runner.token_to_kv_pool
- ret.attn_backend = model_runner.attn_backend
-
# Init lora information
if model_runner.server_args.lora_paths is not None:
model_runner.lora_manager.prepare_lora_batch(ret)
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 8b0c5618622..869a984d0cf 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -12,7 +12,7 @@
# limitations under the License.
# ==============================================================================
-# Some shortcuts for backward compatbility.
+# Some shortcuts for backward compatibility.
# They will be removed in new versions.
from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server
diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index ac16f6c532e..049ba22750a 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
class EAGLEDraftInput(SpecInfo):
def __init__(self):
self.prev_mode = ForwardMode.DECODE
- self.sample_output = None
self.scores: torch.Tensor = None
self.score_list: List[torch.Tensor] = []
@@ -190,12 +189,16 @@ def __init__(self):
self.cache_list: List[torch.Tenor] = []
self.iter = 0
+ # shape: (b, hidden_size)
self.hidden_states: torch.Tensor = None
+ # shape: (b,)
self.verified_id: torch.Tensor = None
+ # shape: (b, vocab_size)
+ self.sample_output: torch.Tensor = None
+
self.positions: torch.Tensor = None
self.accept_length: torch.Tensor = None
- self.has_finished: bool = False
- self.unfinished_index: List[int] = None
+ self.accept_length_cpu: List[int] = None
def load_server_args(self, server_args: ServerArgs):
self.topk: int = server_args.speculative_eagle_topk
@@ -218,7 +221,7 @@ def prepare_for_extend(self, batch: ScheduleBatch):
:pre_len
] = req.prefix_indices
- batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+ batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len]
)
@@ -295,7 +298,9 @@ def prepare_for_decode(self, batch: ScheduleBatch):
self.cache_list.append(batch.out_cache_loc)
self.positions = (
batch.seq_lens[:, None]
- + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+ + torch.full(
+ [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
+ )
).flatten()
bs = len(batch.seq_lens)
@@ -312,24 +317,25 @@ def prepare_for_decode(self, batch: ScheduleBatch):
def prepare_extend_after_decode(self, batch: ScheduleBatch):
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
- batch.extend_lens = (self.accept_length + 1).tolist()
+ accept_length_cpu = batch.spec_info.accept_length_cpu
+ batch.extend_lens = [x + 1 for x in accept_length_cpu]
+ batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+ seq_lens_cpu = batch.seq_lens.tolist()
pt = 0
- seq_lens = batch.seq_lens.tolist()
-
i = 0
-
for req in batch.reqs:
if req.finished():
continue
# assert seq_len - pre_len == req.extend_input_len
- input_len = self.accept_length[i] + 1
- seq_len = seq_lens[i]
+ input_len = batch.extend_lens[i]
+ seq_len = seq_lens_cpu[i]
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
seq_len - input_len : seq_len
] = batch.out_cache_loc[pt : pt + input_len]
pt += input_len
i += 1
+ assert pt == batch.out_cache_loc.shape[0]
self.positions = torch.empty_like(self.verified_id)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
@@ -345,7 +351,7 @@ def prepare_extend_after_decode(self, batch: ScheduleBatch):
triton.next_power_of_2(self.spec_steps + 1),
)
- batch.seq_lens_sum = sum(batch.seq_lens)
+ batch.seq_lens_sum = sum(seq_lens_cpu)
batch.input_ids = self.verified_id
self.verified_id = new_verified_id
@@ -573,6 +579,8 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
finished_extend_len = {} # {rid:accept_length + 1}
accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist()
+ has_finished = False
+
# iterate every accepted token and check if req has finished after append the token
# should be checked BEFORE free kv cache slots
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
@@ -586,7 +594,7 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
finished_extend_len[req.rid] = j + 1
req.check_finished()
if req.finished():
- draft_input.has_finished = True
+ has_finished = True
# set all tokens after finished token to -1 and break
accept_index[i, j + 1 :] = -1
break
@@ -600,7 +608,6 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
accept_index = accept_index[accept_index != -1]
accept_length_cpu = accept_length.tolist()
verified_id = predict[accept_index]
- verified_id_cpu = verified_id.tolist()
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
@@ -622,7 +629,13 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten
draft_input.verified_id = predict[new_accept_index]
draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
draft_input.accept_length = accept_length[unfinished_index]
- draft_input.unfinished_index = unfinished_index
+ draft_input.accept_length_cpu = [
+ accept_length_cpu[i] for i in unfinished_index
+ ]
+ if has_finished:
+ draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
+ else:
+ draft_input.seq_lens_for_draft_extend = batch.seq_lens
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
return (
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 2a6ec96048b..06a4372fce2 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -13,6 +13,7 @@
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
+from sglang.srt.utils import rank0_print
class EAGLEWorker(TpModelWorker):
@@ -50,18 +51,18 @@ def __init__(
def forward_draft_decode(self, batch: ScheduleBatch):
batch.spec_info.prepare_for_decode(batch)
+ batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
def forward_draft_extend(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
batch.spec_info.prepare_for_extend(batch)
+ batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
@@ -134,26 +135,23 @@ def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
batch.req_to_token_pool = runner.req_to_token_pool
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+ seq_lens_backup = batch.seq_lens
+
self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
- if batch.spec_info.has_finished:
- index = batch.spec_info.unfinished_index
- seq_lens = batch.seq_lens
- batch.seq_lens = batch.seq_lens[index]
-
batch.spec_info.prepare_extend_after_decode(batch)
+ batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
logits_output = self.model_runner.forward(forward_batch)
-
- batch.spec_info.hidden_states = logits_output.hidden_states
self.capture_for_decode(logits_output, forward_batch)
- batch.forward_mode = ForwardMode.DECODE
- if batch.spec_info.has_finished:
- batch.seq_lens = seq_lens
self._set_mem_pool(batch, self.target_worker.model_runner)
+ # Restore backup.
+ # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+ batch.forward_mode = ForwardMode.DECODE
+ batch.seq_lens = seq_lens_backup
+
def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
):
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 4614114b41d..23dcb43d2d9 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1442,3 +1442,10 @@ def is_valid_ipv6_address(address: str) -> bool:
return True
except ValueError:
return False
+
+
+def rank0_print(msg: str):
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+ if get_tensor_model_parallel_rank() == 0:
+ print(msg, flush=True)
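The helper prints only on tensor-parallel rank 0, so multi-GPU runs do not repeat the same log line once per rank. Usage is a drop-in replacement for print:

    # Emits one line per TP group instead of one per rank.
    from sglang.srt.utils import rank0_print
    rank0_print("accept length stats: ...")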
diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py
index 361bbaed00c..088cb0d0af9 100644
--- a/python/sglang/test/test_programs.py
+++ b/python/sglang/test/test_programs.py
@@ -535,7 +535,8 @@ def few_shot_hellaswag(s, question, choices):
# Compute accuracy
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
- assert np.abs(accuracy_gen - accuracy) < 0.1
+ print(f"{accuracy=}, {accuracy_gen=}")
+ assert np.abs(accuracy_gen - accuracy) < 0.05
assert np.abs(latency_gen - latency) < 1
return accuracy, latency
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index c1437074f67..ad8ff6cbf4d 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -567,15 +567,16 @@ def run_bench_serving(
random_range_ratio=0.0,
request_rate=request_rate,
multi=None,
- seed=0,
output_file=None,
disable_tqdm=False,
disable_stream=disable_stream,
- disable_ignore_eos=False,
return_logprob=False,
- lora_name=None,
+ seed=0,
+ disable_ignore_eos=False,
extra_request_body=None,
+ apply_chat_template=False,
profile=None,
+ lora_name=None,
)
try:
diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py
index 0d7cc910557..a4b1b88a23d 100644
--- a/test/lang/test_srt_backend.py
+++ b/test/lang/test_srt_backend.py
@@ -1,6 +1,7 @@
"""
Usage:
python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens
+python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select
"""
import unittest
From 6c856b4f3a4e63a25f5adc3388bf79ac2a6e4f72 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Tue, 21 Jan 2025 13:08:15 +0800
Subject: [PATCH 048/147] minor: update Makefile for sgl-kernel (#3025)
---
.github/workflows/release-pypi-kernel.yml | 1 +
sgl-kernel/Makefile | 4 ++--
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/release-pypi-kernel.yml b/.github/workflows/release-pypi-kernel.yml
index 362088c47fd..466f2bdc70d 100644
--- a/.github/workflows/release-pypi-kernel.yml
+++ b/.github/workflows/release-pypi-kernel.yml
@@ -14,6 +14,7 @@ concurrency:
jobs:
build-wheels:
+ if: github.repository == 'sgl-project/sglang'
runs-on: ubuntu-latest
strategy:
matrix:
diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile
index fac4c5c56c8..c7641bb5fee 100644
--- a/sgl-kernel/Makefile
+++ b/sgl-kernel/Makefile
@@ -1,7 +1,7 @@
.PHONY: tree ln submodule install build clean test format
tree:
- @tree --prune -I "__pycache__|*.egg-info|*.so|build"
+ @tree --prune -I "__pycache__|*.egg-info|*.so|build|3rdparty|dist"
submodule:
@git submodule update --init --recursive
@@ -19,7 +19,7 @@ clean:
@rm -rf build dist *.egg-info
test:
- @pytest tests/
+ @find tests -name "test_*.py" | xargs -n 1 python3
format:
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
From ec1c21cdc4d9dcfc94f48b0dad182dc34b943553 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Tue, 21 Jan 2025 14:32:08 +0800
Subject: [PATCH 049/147] upgrade torch version for sgl-kernel (#3026)
---
.github/workflows/pr-test-sgl-kernel.yml | 16 ++++++++--------
sgl-kernel/build.sh | 2 +-
2 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml
index cacf938a330..31360c0a068 100644
--- a/.github/workflows/pr-test-sgl-kernel.yml
+++ b/.github/workflows/pr-test-sgl-kernel.yml
@@ -34,16 +34,16 @@ jobs:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 1-gpu-runner
steps:
- - name: Checkout code
- uses: actions/checkout@v3
+ - uses: actions/checkout@v4
+ with:
+ submodules: 'recursive'
- - name: Install dependencies
+ - name: Install
run: |
- bash scripts/ci_install_dependency.sh
-
+ pip3 install torch==2.5.1
+ pip3 uninstall sgl-kernel -y || true
cd sgl-kernel
- git submodule update --init --recursive
- pip3 install -e . --force-reinstall
+ pip3 install .
pip3 list | grep sgl-kernel
- name: Run test
@@ -57,7 +57,7 @@ jobs:
pip3 uninstall sgl-kernel -y
finish:
- needs: [unit-test]
+ needs: [unit-test, lint]
runs-on: ubuntu-latest
steps:
- name: Finish
diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh
index 55ce9df7f33..0d816957951 100755
--- a/sgl-kernel/build.sh
+++ b/sgl-kernel/build.sh
@@ -8,7 +8,7 @@ docker run --rm \
-v "$(pwd)":/sgl-kernel \
pytorch/manylinux-builder:cuda${CUDA_VERSION} \
bash -c "
- ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.4.0 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
+ ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
export CUDA_VERSION=${CUDA_VERSION} && \
mkdir -p /usr/lib/x86_64-linux-gnu/ && \
From a4331cd260c969ff08a0dbd7465c9b5d87b472b6 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 21 Jan 2025 02:55:14 -0800
Subject: [PATCH 050/147] Add accuracy and latency tests of eagle into CI
(#3027)
---
.github/workflows/pr-test.yml | 18 ++-
python/sglang/test/test_utils.py | 6 +-
test/srt/models/test_qwen_models.py | 6 +-
test/srt/test_bench_one_batch.py | 26 +++-
test/srt/test_bench_serving.py | 34 ++++-
test/srt/test_eagle_infer.py | 217 ++++++++++++++--------------
test/srt/test_torch_compile.py | 2 +-
7 files changed, 186 insertions(+), 123 deletions(-)
diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 8b8d7c56e7f..c5eeeee3c14 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -128,7 +128,7 @@ jobs:
timeout-minutes: 10
run: |
cd test/srt
- python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_default
+ python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1
- name: Benchmark online latency
timeout-minutes: 10
@@ -148,6 +148,13 @@ jobs:
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
+ - name: Benchmark online latency (EAGLE)
+ timeout-minutes: 10
+ run: |
+ cd test/srt
+ python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_eagle
+
+
performance-test-1-gpu-part-2:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 1-gpu-runner
@@ -196,7 +203,13 @@ jobs:
timeout-minutes: 10
run: |
cd test/srt
- python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default
+ python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1
+
+ - name: Benchmark single latency + torch.compile (TP=2)
+ timeout-minutes: 10
+ run: |
+ cd test/srt
+ python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_torch_compile_tp2_bs1
- name: Benchmark offline throughput (TP=2)
timeout-minutes: 10
@@ -210,6 +223,7 @@ jobs:
cd test/srt
python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache
+
accuracy-test-1-gpu:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 1-gpu-runner
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index ad8ff6cbf4d..ee5ae278d13 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -42,6 +42,9 @@
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4"
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct"
+DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
+DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B"
+
def is_in_ci():
"""Return whether it is in CI runner."""
@@ -538,6 +541,7 @@ def run_bench_serving(
random_input_len=4096,
random_output_len=2048,
disable_stream=False,
+ disable_ignore_eos=False,
need_warmup=False,
):
# Launch the server
@@ -572,7 +576,7 @@ def run_bench_serving(
disable_stream=disable_stream,
return_logprob=False,
seed=0,
- disable_ignore_eos=False,
+ disable_ignore_eos=disable_ignore_eos,
extra_request_body=None,
apply_chat_template=False,
profile=None,
diff --git a/test/srt/models/test_qwen_models.py b/test/srt/models/test_qwen_models.py
index 9e61930a76e..c7788fa8e50 100644
--- a/test/srt/models/test_qwen_models.py
+++ b/test/srt/models/test_qwen_models.py
@@ -37,8 +37,7 @@ def test_gsm8k(self):
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
- print(metrics)
-
+ print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.81)
@@ -69,8 +68,7 @@ def test_gsm8k(self):
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
- print(metrics)
-
+ print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.79)
diff --git a/test/srt/test_bench_one_batch.py b/test/srt/test_bench_one_batch.py
index c1bc98e8e04..c6562170d61 100644
--- a/test/srt/test_bench_one_batch.py
+++ b/test/srt/test_bench_one_batch.py
@@ -5,24 +5,46 @@
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
is_in_ci,
run_bench_one_batch,
+ write_github_step_summary,
)
class TestBenchOneBatch(unittest.TestCase):
- def test_default(self):
+ def test_bs1(self):
output_throughput = run_bench_one_batch(DEFAULT_MODEL_NAME_FOR_TEST, [])
if is_in_ci():
+ write_github_step_summary(
+ f"### test_bs1\n"
+ f"output_throughput : {output_throughput:.2f} token/s\n"
+ )
self.assertGreater(output_throughput, 135)
- def test_moe_default(self):
+ def test_moe_tp2_bs1(self):
output_throughput = run_bench_one_batch(
DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"]
)
if is_in_ci():
+ write_github_step_summary(
+ f"### test_moe_tp2_bs1\n"
+ f"output_throughput : {output_throughput:.2f} token/s\n"
+ )
self.assertGreater(output_throughput, 125)
+ def test_torch_compile_tp2_bs1(self):
+ output_throughput = run_bench_one_batch(
+ DEFAULT_MODEL_NAME_FOR_TEST,
+ ["--tp", "2", "--enable-torch-compile", "--cuda-graph-max-bs", "2"],
+ )
+
+ if is_in_ci():
+ write_github_step_summary(
+ f"### test_torch_compile_tp2_bs1\n"
+ f"output_throughput : {output_throughput:.2f} token/s\n"
+ )
+ self.assertGreater(output_throughput, 240)
+
if __name__ == "__main__":
unittest.main()
diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py
index b882f12f9df..b55260f71a6 100644
--- a/test/srt/test_bench_serving.py
+++ b/test/srt/test_bench_serving.py
@@ -1,6 +1,8 @@
import unittest
from sglang.test.test_utils import (
+ DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+ DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_FP8_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
@@ -47,7 +49,7 @@ def test_offline_throughput_non_stream_small_batch_size(self):
)
# There is a regression with torch 2.5
# This number was 950 for torch 2.4
- self.assertGreater(res["output_throughput"], 800)
+ self.assertGreater(res["output_throughput"], 850)
def test_offline_throughput_without_radix_cache(self):
res = run_bench_serving(
@@ -131,6 +133,36 @@ def test_online_latency_default(self):
self.assertLess(res["median_ttft_ms"], 86)
self.assertLess(res["median_itl_ms"], 10)
+ def test_online_latency_eagle(self):
+ res = run_bench_serving(
+ model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+ num_prompts=50,
+ request_rate=1,
+ disable_ignore_eos=True,
+ dataset_name="sharegpt",
+ other_server_args=[
+ "--speculative-algorithm",
+ "EAGLE",
+ "--speculative-draft-model-path",
+ DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+ "--speculative-num-steps",
+ "5",
+ "--speculative-eagle-topk",
+ "8",
+ "--speculative-num-draft-tokens",
+ "64",
+ "--mem-fraction-static",
+ "0.7",
+ ],
+ )
+
+ if is_in_ci():
+ write_github_step_summary(
+ f"### test_online_latency_eagle\n"
+ f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
+ )
+ self.assertLess(res["median_e2e_latency_ms"], 10000)
+
def test_moe_offline_throughput_default(self):
res = run_bench_serving(
model=DEFAULT_MOE_MODEL_NAME_FOR_TEST,
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 92127b8ef59..b01c260496a 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -1,14 +1,18 @@
-import multiprocessing
import random
+import threading
import time
import unittest
+from types import SimpleNamespace
import requests
-from transformers import AutoConfig, AutoTokenizer
import sglang as sgl
+from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
+ DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+ DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
@@ -19,60 +23,59 @@ class TestEAGLEEngine(unittest.TestCase):
def test_eagle_accuracy(self):
prompt = "Today is a sunny day and I like"
- target_model_path = "meta-llama/Llama-2-7b-chat-hf"
- speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B"
-
sampling_params = {"temperature": 0, "max_new_tokens": 8}
+ # Get the reference output
+ ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+ ref_output = ref_engine.generate(prompt, sampling_params)["text"]
+ ref_engine.shutdown()
+
+ # Launch EAGLE engine
engine = sgl.Engine(
- model_path=target_model_path,
- speculative_draft_model_path=speculative_draft_model_path,
+ model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
+ speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
speculative_algorithm="EAGLE",
- speculative_num_steps=3,
- speculative_eagle_topk=4,
- speculative_num_draft_tokens=16,
+ speculative_num_steps=5,
+ speculative_eagle_topk=8,
+ speculative_num_draft_tokens=64,
+ mem_fraction_static=0.7,
)
- out1 = engine.generate(prompt, sampling_params)["text"]
- engine.shutdown()
-
- engine = sgl.Engine(model_path=target_model_path)
- out2 = engine.generate(prompt, sampling_params)["text"]
- engine.shutdown()
- print("==== Answer 1 ====")
- print(out1)
-
- print("==== Answer 2 ====")
- print(out2)
- self.assertEqual(out1, out2)
+ # Case 1: Test the output of EAGLE engine is the same as normal engine
+ out1 = engine.generate(prompt, sampling_params)["text"]
+ print(f"{out1=}, {ref_output=}")
+ self.assertEqual(out1, ref_output)
- def test_eagle_end_check(self):
+ # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
prompt = "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like [/INST]"
- target_model_path = "meta-llama/Llama-2-7b-chat-hf"
- tokenizer = AutoTokenizer.from_pretrained(target_model_path)
- speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B"
-
sampling_params = {
"temperature": 0,
"max_new_tokens": 1024,
"skip_special_tokens": False,
}
- engine = sgl.Engine(
- model_path=target_model_path,
- speculative_draft_model_path=speculative_draft_model_path,
- speculative_algorithm="EAGLE",
- speculative_num_steps=3,
- speculative_eagle_topk=4,
- speculative_num_draft_tokens=16,
- )
- out1 = engine.generate(prompt, sampling_params)["text"]
- engine.shutdown()
- print("==== Answer 1 ====")
- print(repr(out1))
- tokens = tokenizer.encode(out1, truncation=False)
+ tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
+ out2 = engine.generate(prompt, sampling_params)["text"]
+ print(f"{out2=}")
+ tokens = tokenizer.encode(out2, truncation=False)
assert tokenizer.eos_token_id not in tokens
+ # Case 3: Batched prompts
+ prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+ ]
+ sampling_params = {"temperature": 0, "max_new_tokens": 30}
+ outputs = engine.generate(prompts, sampling_params)
+ for prompt, output in zip(prompts, outputs):
+ print("===============================")
+ print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
+
+ # Shutdown the engine
+ engine.shutdown()
+
prompts = [
"[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]"
@@ -83,64 +86,27 @@ def test_eagle_end_check(self):
]
-def process(server_url: str):
- time.sleep(random.uniform(0, 2))
- for prompt in prompts:
- url = server_url
- data = {
- "model": "base",
- "text": prompt,
- "sampling_params": {
- "temperature": 0,
- "max_new_tokens": 1024,
- },
- }
- response = requests.post(url, json=data)
- assert response.status_code == 200
-
-
-def abort_process(server_url: str):
- for prompt in prompts:
- try:
- time.sleep(1)
- url = server_url
- data = {
- "model": "base",
- "text": prompt,
- "sampling_params": {
- "temperature": 0,
- "max_new_tokens": 1024,
- },
- }
- # set timeout = 1s,mock disconnected
- requests.post(url, json=data, timeout=1)
- except:
- pass
-
-
-class TestEAGLELaunchServer(unittest.TestCase):
+class TestEAGLEServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
- speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B"
- cls.model = "meta-llama/Llama-2-7b-chat-hf"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
- cls.model,
+ DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
- speculative_draft_model_path,
+ DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
- "3",
+ "5",
"--speculative-eagle-topk",
- "4",
+ "8",
"--speculative-num-draft-tokens",
- "16",
- "--served-model-name",
- "base",
+ "64",
+ "--mem-fraction-static",
+ "0.7",
],
)
@@ -148,40 +114,67 @@ def setUpClass(cls):
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
- def test_eagle_server_concurrency(self):
- concurrency = 4
- processes = [
- multiprocessing.Process(
- target=process,
- kwargs={"server_url": self.base_url + "/generate"},
- )
- for _ in range(concurrency)
- ]
- for worker in processes:
- worker.start()
- for p in processes:
- p.join()
-
- def test_eagle_server_request_abort(self):
+ def send_request(self):
+ time.sleep(random.uniform(0, 2))
+ for prompt in prompts:
+ url = self.base_url + "/generate"
+ data = {
+ "text": prompt,
+ "sampling_params": {
+ "temperature": 0,
+ "max_new_tokens": 1024,
+ },
+ }
+ response = requests.post(url, json=data)
+ assert response.status_code == 200
+
+ def send_requests_abort(self):
+ for prompt in prompts:
+ try:
+ time.sleep(random.uniform(0, 2))
+ url = self.base_url + "/generate"
+ data = {
+ "model": "base",
+ "text": prompt,
+ "sampling_params": {
+ "temperature": 0,
+ "max_new_tokens": 1024,
+ },
+ }
+ # set timeout = 1s to mock a disconnected client
+ requests.post(url, json=data, timeout=1)
+ except Exception as e:
+ print(e)
+ pass
+
+ def test_request_abort(self):
concurrency = 4
- processes = [
- multiprocessing.Process(
- target=process,
- kwargs={"server_url": self.base_url + "/generate"},
- )
- for _ in range(concurrency)
+ threads = [
+ threading.Thread(target=self.send_request) for _ in range(concurrency)
] + [
- multiprocessing.Process(
- target=abort_process,
- kwargs={"server_url": self.base_url + "/generate"},
- )
+ threading.Thread(target=self.send_requests_abort)
for _ in range(concurrency)
]
- for worker in processes:
+ for worker in threads:
worker.start()
- for p in processes:
+ for p in threads:
p.join()
+ def test_gsm8k(self):
+ args = SimpleNamespace(
+ num_shots=5,
+ data_path=None,
+ num_questions=200,
+ max_new_tokens=512,
+ parallel=128,
+ host="http://127.0.0.1",
+ port=int(self.base_url.split(":")[-1]),
+ )
+ metrics = run_eval(args)
+ print(f"{metrics=}")
+
+ self.assertGreater(metrics["accuracy"], 0.20)
+
if __name__ == "__main__":
unittest.main()
diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py
index 6f3b344b3cc..e71de339117 100644
--- a/test/srt/test_torch_compile.py
+++ b/test/srt/test_torch_compile.py
@@ -23,7 +23,7 @@ def setUpClass(cls):
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
- other_args=["--enable-torch-compile"],
+ other_args=["--enable-torch-compile", "--cuda-graph-max-bs", "4"],
)
@classmethod
From 5a0d680a14fc9aa29b2640b69baaf5d28d5975b9 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Tue, 21 Jan 2025 20:44:49 +0800
Subject: [PATCH 051/147] feat: add flashinfer as 3rdparty and use rmsnorm as
example (#3033)
---
.github/workflows/pr-test-sgl-kernel.yml | 1 +
.gitignore | 2 +
.gitmodules | 3 +
sgl-kernel/3rdparty/flashinfer | 1 +
sgl-kernel/THIRDPARTYNOTICES.txt | 225 ++++++++++++++++++
sgl-kernel/setup.py | 21 +-
sgl-kernel/src/sgl-kernel/__init__.py | 2 +
sgl-kernel/src/sgl-kernel/csrc/norm.cu | 28 +++
.../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 5 +
sgl-kernel/src/sgl-kernel/ops/__init__.py | 18 ++
sgl-kernel/tests/test_rmsnorm.py | 31 +++
11 files changed, 335 insertions(+), 2 deletions(-)
create mode 160000 sgl-kernel/3rdparty/flashinfer
create mode 100644 sgl-kernel/THIRDPARTYNOTICES.txt
create mode 100644 sgl-kernel/src/sgl-kernel/csrc/norm.cu
create mode 100644 sgl-kernel/tests/test_rmsnorm.py
diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml
index 31360c0a068..0c29322a402 100644
--- a/.github/workflows/pr-test-sgl-kernel.yml
+++ b/.github/workflows/pr-test-sgl-kernel.yml
@@ -41,6 +41,7 @@ jobs:
- name: Install
run: |
pip3 install torch==2.5.1
+ pip3 install pytest
pip3 uninstall sgl-kernel -y || true
cd sgl-kernel
pip3 install .
diff --git a/.gitignore b/.gitignore
index 91966c664b5..75e29fac373 100644
--- a/.gitignore
+++ b/.gitignore
@@ -225,3 +225,5 @@ compile_commands.json
# VSCode
.vscode
+
+1
diff --git a/.gitmodules b/.gitmodules
index c584a21e8bd..ed7603bfd3c 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -4,3 +4,6 @@
[submodule "sgl-kernel/3rdparty/cccl"]
path = sgl-kernel/3rdparty/cccl
url = https://github.com/NVIDIA/cccl.git
+[submodule "sgl-kernel/3rdparty/flashinfer"]
+ path = sgl-kernel/3rdparty/flashinfer
+ url = https://github.com/flashinfer-ai/flashinfer.git
diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer
new file mode 160000
index 00000000000..a0e99a3a820
--- /dev/null
+++ b/sgl-kernel/3rdparty/flashinfer
@@ -0,0 +1 @@
+Subproject commit a0e99a3a820109763d9a757138a5cdf7bbcd1f85
diff --git a/sgl-kernel/THIRDPARTYNOTICES.txt b/sgl-kernel/THIRDPARTYNOTICES.txt
new file mode 100644
index 00000000000..c930aa5dd3d
--- /dev/null
+++ b/sgl-kernel/THIRDPARTYNOTICES.txt
@@ -0,0 +1,225 @@
+Notice for flashinfer-ai/flashinfer
+-------------------------------
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+-------------------------------------------------------------------------------------------------
+Some of the code in this project are adapted from other open-source projects with different
+licenses. This product also bundles some third-party components under other open source licenses.
+This section summarizes those components and their licenses.
+See licenses/ for text of these licenses.
+
+BSD 3-Clause License
+--------------------
+
+include/flashinfer/attention/hopper/epilogue.cuh
+include/flashinfer/attention/hopper/mainloop.cuh
+include/flashinfer/attention/hopper/kernel_traits.cuh
+include/flashinfer/attention/hopper/named_barrier.cuh
+include/flashinfer/attention/hopper/tile_scheduler.cuh
+include/flashinfer/attention/hopper/utils.cuh
+
+BSD 3-Clause "New" License
+--------------------------
+
+3rdparty/cutlass
+include/flashinfer/attention/hopper/block_sparse_gather.cuh
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 9f986711338..a8d9517bb25 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -1,5 +1,6 @@
from pathlib import Path
+import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
@@ -24,10 +25,13 @@ def update_wheel_platform_tag():
cutlass = root / "3rdparty" / "cutlass"
+flashinfer = root / "3rdparty" / "flashinfer"
include_dirs = [
cutlass.resolve() / "include",
cutlass.resolve() / "tools" / "util" / "include",
root / "src" / "sgl-kernel" / "csrc",
+ flashinfer.resolve() / "include",
+ flashinfer.resolve() / "csrc",
]
nvcc_flags = [
"-DNDEBUG",
@@ -39,9 +43,21 @@ def update_wheel_platform_tag():
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-gencode=arch=compute_90a,code=sm_90a",
- "-U__CUDA_NO_HALF_OPERATORS__",
- "-U__CUDA_NO_HALF2_OPERATORS__",
+ "-std=c++17",
+ "-use_fast_math",
+ "-DFLASHINFER_ENABLE_F16",
+ "-DFLASHINFER_ENABLE_BF16",
]
+for flag in [
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+]:
+ try:
+ torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
+ except ValueError:
+ pass
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python", "cuda"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
@@ -56,6 +72,7 @@ def update_wheel_platform_tag():
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
+ "src/sgl-kernel/csrc/norm.cu",
],
include_dirs=include_dirs,
extra_compile_args={
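
Background on the flag handling earlier in this setup.py diff (illustrative, not part of the patch): torch.utils.cpp_extension injects -D__CUDA_NO_HALF_OPERATORS__-style defines into every extension build, which suppresses the native __half/__nv_bfloat16 operators that the FlashInfer headers expect. The removal loop strips them from COMMON_NVCC_FLAGS before BuildExtension assembles the nvcc command line. A minimal sketch to inspect the effect locally (assumes only a working PyTorch install):

import torch.utils.cpp_extension as cpp_ext

# Defines PyTorch adds by default; they disable the built-in __half /
# __nv_bfloat16 operator overloads that FlashInfer's kernels rely on.
blocked = {
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
}
# Non-empty before the removal loop in setup.py runs, empty afterwards.
print(sorted(blocked & set(cpp_ext.COMMON_NVCC_FLAGS)))
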
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 480bec71f36..3352abeb550 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -6,6 +6,7 @@
int8_scaled_mm,
moe_align_block_size,
register_graph_buffers,
+ rmsnorm,
rotary_embedding,
sampling_scaling_penalties,
)
@@ -20,4 +21,5 @@
"get_graph_buffer_ipc_meta",
"register_graph_buffers",
"rotary_embedding",
+ "rmsnorm",
]
diff --git a/sgl-kernel/src/sgl-kernel/csrc/norm.cu b/sgl-kernel/src/sgl-kernel/csrc/norm.cu
new file mode 100644
index 00000000000..ad102a50d3f
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/norm.cu
@@ -0,0 +1,28 @@
+#include
+#include
+
+#include "pytorch_extension_utils.h"
+
+using namespace flashinfer;
+
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream) {
+ CHECK_INPUT(input);
+ CHECK_INPUT(weight);
+ auto device = input.device();
+ CHECK_EQ(weight.device(), device);
+ CHECK_DIM(2, input); // input: (batch_size, hidden_size)
+ CHECK_DIM(1, weight); // weight: (hidden_size)
+ CHECK_EQ(input.size(1), weight.size(0));
+ unsigned int batch_size = input.size(0);
+ unsigned int hidden_size = input.size(1);
+ CHECK_EQ(output.size(0), batch_size);
+ CHECK_EQ(output.size(1), hidden_size);
+
+ cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
+ DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+ cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
+ static_cast<c_type*>(output.data_ptr()), batch_size, hidden_size, eps, stream);
+ TORCH_CHECK(status == cudaSuccess, "RMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
+ return true;
+ });
+}
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index f2ae95d7f79..ed359bfbb0a 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -30,6 +30,9 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
+// rms norm
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// trt_reduce
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -45,4 +48,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
// rotary embedding
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
+ // rms norm
+ m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
}
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index b8abd57d39d..e9eadb759cf 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -1,3 +1,6 @@
+from typing import Optional
+
+import torch
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
from sgl_kernel.ops._kernels import dispose as _dispose
from sgl_kernel.ops._kernels import (
@@ -7,6 +10,7 @@
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
+from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
from sgl_kernel.ops._kernels import (
sampling_scaling_penalties as _sampling_scaling_penalties,
@@ -76,3 +80,17 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
+
+
+def rmsnorm(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+ out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ if out is None:
+ out = torch.empty_like(input)
+ stream = torch.cuda.current_stream().cuda_stream
+ stream_int = int(stream)
+ _rmsnorm(out, input, weight, eps, stream_int)
+ return out
diff --git a/sgl-kernel/tests/test_rmsnorm.py b/sgl-kernel/tests/test_rmsnorm.py
new file mode 100644
index 00000000000..dda225de9e3
--- /dev/null
+++ b/sgl-kernel/tests/test_rmsnorm.py
@@ -0,0 +1,31 @@
+import pytest
+import torch
+from sgl_kernel import rmsnorm
+
+
+def llama_rms_norm(x, w, eps=1e-6):
+ orig_dtype = x.dtype
+ x = x.float()
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
+ x = x * torch.rsqrt(variance + eps)
+ x = x * w.float()
+ x = x.to(orig_dtype)
+ return x
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("specify_out", [True, False])
+def test_norm(batch_size, hidden_size, dtype, specify_out):
+ x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+ w = torch.randn(hidden_size).to(0).to(dtype)
+
+ y_ref = llama_rms_norm(x, w)
+ if specify_out:
+ y = torch.empty_like(x)
+ rmsnorm(x, w, out=y)
+ else:
+ y = rmsnorm(x, w)
+
+ torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
From 0ac019f17189e2ba3a3bab047cf441e060d339a1 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Tue, 21 Jan 2025 22:21:54 +0800
Subject: [PATCH 052/147] Support sm90 Int8 gemm (#3035)
---
.../src/sgl-kernel/csrc/int8_gemm_kernel.cu | 210 +++++++++++++++++-
sgl-kernel/tests/test_int8_gemm.py | 2 +-
2 files changed, 210 insertions(+), 2 deletions(-)
diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu
index cce32c2d894..8e3f7275702 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu
@@ -3,13 +3,23 @@
#include
#include
#include
+#include
#include
+#include
+#include
+#include
+#include
+#include
+#include
+
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
#include "utils.hpp"
+using namespace cute;
+
template
void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
@@ -166,6 +176,186 @@ void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const t
}
}
+template
+void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
+ const torch::Tensor& scales_a, const torch::Tensor& scales_b,
+ const c10::optional& bias) {
+ using ArchTag = cutlass::arch::Sm90;
+
+ using ElementAccumulator = int32_t;
+ using ElementCompute = float;
+ using ElementInputA = int8_t;
+ using ElementInputB = int8_t;
+
+ static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
+ static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
+ static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementOutput>::value;
+ static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
+
+ using OperatorClass = cutlass::arch::OpClassTensorOp;
+
+ using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
+ using TileSchedulerType = cutlass::gemm::PersistentScheduler;
+
+ using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute,
+ Stride, Int<0>, Int<0>>>;
+
+ using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute,
+ Stride, Int<1>, Int<0>>>;
+
+ using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput,
+ Stride, Int<1>, Int<0>>>;
+
+ using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
+
+ // Scale
+ using Compute0 = cutlass::epilogue::fusion::Sm90Compute;
+
+ using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT;
+
+ using Compute1 = cutlass::epilogue::fusion::Sm90Compute;
+
+ using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT;
+
+ // With bias
+ using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute;
+ using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT;
+
+ using EpilogueEVT = typename cutlass::platform::conditional::type;
+
+ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+ ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+ ElementAccumulator, ElementCompute, ElementOutput, cutlass::layout::RowMajor, AlignmentC, ElementOutput,
+ cutlass::layout::RowMajor, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp;
+
+ using Stages = cutlass::gemm::collective::StageCountAutoCarveout(
+ sizeof(typename CollectiveEpilogue::SharedStorage))>;
+
+ using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+ ArchTag, OperatorClass, ElementInputA, cutlass::layout::RowMajor, AlignmentA, ElementInputB,
+ cutlass::layout::ColumnMajor, AlignmentB, ElementAccumulator, TileShape, ClusterShape, Stages,
+ MainloopScheduleType>::CollectiveOp;
+
+ using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape
+ CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;
+
+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter;
+
+ Gemm gemm_op;
+
+ int m = mat_a.size(0);
+ int k = mat_a.size(1);
+ int n = mat_b.size(1);
+
+ auto a_ptr = static_cast(mat_a.data_ptr());
+ auto b_ptr = static_cast(mat_b.data_ptr());
+ auto o_ptr = static_cast(out.data_ptr());
+
+ auto a_s_ptr = static_cast(scales_a.data_ptr());
+ auto b_s_ptr = static_cast(scales_b.data_ptr());
+
+ using StrideA = typename Gemm::GemmKernel::StrideA;
+ using StrideB = typename Gemm::GemmKernel::StrideB;
+ using StrideC = typename Gemm::GemmKernel::StrideC;
+ using StrideD = typename Gemm::GemmKernel::StrideD;
+
+ StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
+ StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
+ StrideC stride_c;
+ StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));
+
+ typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm,
+ {m, n, k, 1},
+ {a_ptr, stride_a, b_ptr, stride_b},
+ {{}, // epilogue.thread
+ nullptr,
+ stride_c,
+ o_ptr,
+ stride_d}};
+
+ if constexpr (WithBias) {
+ ElementOutput* bias_ptr = static_cast(bias->data_ptr());
+ args.epilogue.thread = {
+ {a_s_ptr},
+ {{b_s_ptr}, {}, {}},
+ {bias_ptr},
+ {},
+ };
+ } else {
+ args.epilogue.thread = {
+ {a_s_ptr},
+ {{b_s_ptr}, {}, {}},
+ {},
+ };
+ }
+
+ auto workspace = torch::empty(gemm_op.get_workspace_size(args),
+ torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
+
+ auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());
+
+ auto can_implement = gemm_op.can_implement(args);
+ TORCH_CHECK(can_implement == cutlass::Status::kSuccess,
+ "gemm cannot implement, error: ", cutlassGetStatusString(can_implement));
+
+ auto status = gemm_op(args, workspace.data_ptr(), stream);
+ TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
+}
+
+template
+void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
+ const torch::Tensor& scales_a, const torch::Tensor& scales_b,
+ const c10::optional& bias) {
+ if (bias) {
+ cutlass_int8_scaled_mm_sm90(
+ out, mat_a, mat_b, scales_a, scales_b, bias);
+ } else {
+ cutlass_int8_scaled_mm_sm90(
+ out, mat_a, mat_b, scales_a, scales_b, bias);
+ }
+}
+
+template
+void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b,
+ const torch::Tensor& scales_a, const torch::Tensor& scales_b,
+ const c10::optional& bias) {
+ int m = mat_a.size(0);
+ int n = mat_b.size(1);
+ if (m <= 32) {
+ if (n < 8192) {
+ return sm90_dispatch_bias, Shape<_1, _8, _1>,
+ cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+ } else {
+ return sm90_dispatch_bias, Shape<_1, _8, _1>,
+ cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+ }
+ } else if (m <= 64) {
+ if (n < 8192) {
+ return sm90_dispatch_bias, Shape<_1, _4, _1>,
+ cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+ } else {
+ return sm90_dispatch_bias, Shape<_1, _1, _1>,
+ cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+ }
+ } else if (m <= 128) {
+ if (n <= 4096) {
+ return sm90_dispatch_bias, Shape<_2, _1, _1>,
+ cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+ } else {
+ return sm90_dispatch_bias, Shape<_2, _1, _1>,
+ cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias);
+ }
+ } else {
+ return sm90_dispatch_bias, Shape<_2, _1, _1>,
+ cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b,
+ bias);
+ }
+}
+
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional& bias) {
@@ -204,7 +394,24 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75");
sm75_dispatch_shape>(
out, mat_a, mat_b, scales_a, scales_b, bias);
- } else if (sm_version >= 80 && sm_version <= 90) {
+ } else if (sm_version >= 80 && sm_version < 90) {
+ if (out_dtype == torch::kBFloat16) {
+ sm80_dispatch_shape>(
+ out, mat_a, mat_b, scales_a, scales_b, bias);
+ } else {
+ sm80_dispatch_shape>(
+ out, mat_a, mat_b, scales_a, scales_b, bias);
+ }
+ } else if (sm_version == 90) {
+#if defined CUDA_VERSION && CUDA_VERSION >= 12000
+ // cutlass 3.x
+ if (out_dtype == torch::kBFloat16) {
+ sm90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias);
+ } else {
+ sm90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias);
+ }
+#else
+ // fallback to cutlass 2.x
if (out_dtype == torch::kBFloat16) {
sm80_dispatch_shape>(
out, mat_a, mat_b, scales_a, scales_b, bias);
@@ -212,6 +419,7 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
sm80_dispatch_shape>(
out, mat_a, mat_b, scales_a, scales_b, bias);
}
+#endif
} else {
TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability.");
}
diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py
index 34d17d1c76a..c33a3effcaf 100644
--- a/sgl-kernel/tests/test_int8_gemm.py
+++ b/sgl-kernel/tests/test_int8_gemm.py
@@ -25,7 +25,7 @@ def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
if with_bias:
- bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10
+ bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10
else:
bias = None
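
For readers skimming the epilogue-visitor types above: the SM90 path computes the same per-row/per-column scaled product as the existing SM8x path, and the m/n-based dispatch only picks CTA and cluster shapes for it. A small reference sketch of that math in plain PyTorch (illustrative only; it mirrors the test's reference, not the kernel internals):

import torch

def int8_scaled_mm_ref(a_int8, b_int8, scale_a, scale_b, out_dtype, bias=None):
    # a_int8: (M, K) int8, b_int8: (K, N) int8
    # scale_a: (M,) per-row scales, scale_b: (N,) per-column scales
    acc = a_int8.float() @ b_int8.float()            # int32 accumulation, modeled in fp32
    out = acc * scale_a[:, None] * scale_b[None, :]  # XScale (col broadcast) * WScale (row broadcast)
    if bias is not None:                             # fused bias add in the WithBias epilogue
        out = out + bias.float()
    return out.to(out_dtype)
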
From a42213dbd4d952e9484ce0415ea53939d74a51db Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Wed, 22 Jan 2025 00:56:42 +0800
Subject: [PATCH 053/147] fix pr-test-sgl-kernel (#3036)
---
.github/workflows/pr-test-sgl-kernel.yml | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml
index 0c29322a402..3d980265831 100644
--- a/.github/workflows/pr-test-sgl-kernel.yml
+++ b/.github/workflows/pr-test-sgl-kernel.yml
@@ -35,15 +35,13 @@ jobs:
runs-on: 1-gpu-runner
steps:
- uses: actions/checkout@v4
- with:
- submodules: 'recursive'
- name: Install
run: |
- pip3 install torch==2.5.1
- pip3 install pytest
+ pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm
pip3 uninstall sgl-kernel -y || true
cd sgl-kernel
+ git submodule deinit --all --force && git submodule sync --recursive && git submodule update --init --force --recursive
pip3 install .
pip3 list | grep sgl-kernel
From 3d8f1c9bcf50b4d41fcf4ad2dfc430230276b4ab Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 21 Jan 2025 19:46:09 -0800
Subject: [PATCH 054/147] Use int64 as indices for set_kv_buffer (#3039)
---
python/sglang/bench_one_batch.py | 8 ++---
python/sglang/srt/layers/logits_processor.py | 2 +-
python/sglang/srt/layers/sampler.py | 7 ++--
python/sglang/srt/managers/schedule_batch.py | 14 ++++----
.../srt/model_executor/cuda_graph_runner.py | 32 +++++++++----------
.../srt/model_executor/forward_batch_info.py | 4 +--
6 files changed, 30 insertions(+), 37 deletions(-)
diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py
index e01919399b5..bc7a9c7a1a7 100644
--- a/python/sglang/bench_one_batch.py
+++ b/python/sglang/bench_one_batch.py
@@ -99,10 +99,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument("--correctness-test", action="store_true")
parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
parser.add_argument(
- "--profile",
- action="store_true",
- help="Use Torch Profiler. The endpoint must be launched with "
- "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+ "--profile", action="store_true", help="Use Torch Profiler."
)
parser.add_argument(
"--profile-filename-prefix",
@@ -381,6 +378,7 @@ def latency_test_run_once(
parent_dir = os.path.dirname(os.path.abspath(profile_filename))
os.makedirs(parent_dir, exist_ok=True)
profiler.export_chrome_trace(profile_filename)
+ rank_print(f"torch profiler chrome trace saved to {profile_filename}")
# Record decode timing from 2nd output
if output_len > 1:
@@ -451,7 +449,7 @@ def latency_test(
il,
ol,
server_args.device,
- bench_args.profile,
+ bench_args.profile if tp_rank == 0 else None,
bench_args.profile_filename_prefix,
)
if ret is not None:
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index e5794f052c3..08ee5a3509b 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -296,7 +296,7 @@ def fused_softcap_kernel(
n_elements,
BLOCK_SIZE: tl.constexpr,
):
- pid = tl.program_id(0)
+ pid = tl.program_id(0).to(tl.int64)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
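
The .to(tl.int64) cast above matters once the logits buffer crosses the int32 range: block_start is computed in the program id's dtype, so a 32-bit pid overflows for very large batch-times-vocab tensors. A quick back-of-the-envelope check (the BLOCK_SIZE value is illustrative; the kernel receives it as a constexpr):

INT32_MAX = 2**31 - 1
BLOCK_SIZE = 1024
first_overflowing_pid = 2**31 // BLOCK_SIZE
print(first_overflowing_pid)                            # 2097152
print(first_overflowing_pid * BLOCK_SIZE > INT32_MAX)   # True: block_start would wrap as int32
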
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index ebaa1aa0e7e..f3c376ed1eb 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -1,12 +1,11 @@
import logging
-from typing import Dict, List
+from typing import List
import torch
from torch import nn
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
@@ -109,8 +108,6 @@ def forward(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)
- batch_next_token_ids = batch_next_token_ids.to(torch.int32)
-
# Attach logprobs to logits_output (in-place modification)
if return_logprob:
if any(x > 0 for x in top_logprobs_nums):
@@ -124,7 +121,7 @@ def forward(
batch_next_token_ids,
]
- return batch_next_token_ids
+ return batch_next_token_ids.to(torch.int32)
def _apply_custom_logit_processor(
self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index d9af8151534..6c44b17ffd8 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -550,13 +550,13 @@ class ScheduleBatch:
next_batch_sampling_info: SamplingBatchInfo = None
# Batched arguments to model runner
- input_ids: torch.Tensor = None
- input_embeds: torch.Tensor = None
- req_pool_indices: torch.Tensor = None
- seq_lens: torch.Tensor = None
+ input_ids: torch.Tensor = None # shape: [b], int32
+ input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32
+ req_pool_indices: torch.Tensor = None # shape: [b], int32
+ seq_lens: torch.Tensor = None # shape: [b], int64
# The output locations of the KV cache
- out_cache_loc: torch.Tensor = None
- output_ids: torch.Tensor = None
+ out_cache_loc: torch.Tensor = None # shape: [b], int32
+ output_ids: torch.Tensor = None # shape: [b], int32
# The sum of all sequence lengths
seq_lens_sum: int = None
@@ -1026,7 +1026,7 @@ def prepare_for_idle(self):
self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
- self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device)
+ self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens_sum = 0
self.extend_num_tokens = 0
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 762dac140fb..169b6434368 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -24,7 +24,7 @@
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.distributed.parallel_state import graph_capture
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -63,7 +63,7 @@ def patch_model(
model: torch.nn.Module,
enable_compile: bool,
batch_size: int,
- tp_group: "GroupCoordinator",
+ tp_group: GroupCoordinator,
):
"""Patch the model to make it compatible with with torch.compile"""
backup_ca_comm = None
@@ -149,9 +149,18 @@ def __init__(self, model_runner: "ModelRunner"):
and bs <= model_runner.server_args.cuda_graph_max_bs
]
+ self.compile_bs = (
+ [
+ bs
+ for bs in self.capture_bs
+ if bs <= self.model_runner.server_args.torch_compile_max_bs
+ ]
+ if self.use_torch_compile
+ else []
+ )
+
self.capture_forward_mode = ForwardMode.DECODE
self.num_tokens_per_bs = 1
-
if model_runner.spec_algorithm.is_eagle():
if self.model_runner.is_draft_worker:
self.num_tokens_per_bs = (
@@ -163,16 +172,6 @@ def __init__(self, model_runner: "ModelRunner"):
self.model_runner.server_args.speculative_num_draft_tokens
)
- self.compile_bs = (
- [
- bs
- for bs in self.capture_bs
- if bs <= self.model_runner.server_args.torch_compile_max_bs
- ]
- if self.use_torch_compile
- else []
- )
-
# Attention backend
self.max_bs = max(self.capture_bs)
self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -180,7 +179,6 @@ def __init__(self, model_runner: "ModelRunner"):
self.seq_len_fill_value = (
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
-
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
self.encoder_len_fill_value = 0
@@ -189,14 +187,14 @@ def __init__(self, model_runner: "ModelRunner"):
# Common inputs
with torch.device("cuda"):
- self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
+ self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.seq_lens = torch.full(
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
)
- self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+ self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
- self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
+ self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
# Speculative_inference
if model_runner.spec_algorithm.is_eagle():
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 8ef5c57b891..8bd1052754c 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -38,7 +38,7 @@
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import maybe_torch_compile
+from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.layers.attention import AttentionBackend
@@ -415,6 +415,6 @@ def compute_position_torch(
return positions.to(torch.int64), extend_start_loc
-@maybe_torch_compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
From 6fc37bd8ee3535673835712fa76a973bda0cb450 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Wed, 22 Jan 2025 16:49:08 +0800
Subject: [PATCH 055/147] Fix sgl-kernel compile for sm80 (#3046)
---
sgl-kernel/setup.py | 21 ++++++++++++++++++++-
1 file changed, 20 insertions(+), 1 deletion(-)
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index a8d9517bb25..1aea485ff8f 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -24,6 +24,22 @@ def update_wheel_platform_tag():
old_wheel.rename(new_wheel)
+def get_cuda_version():
+ if torch.version.cuda:
+ return tuple(map(int, torch.version.cuda.split(".")))
+ return (0, 0)
+
+
+def get_device_sm():
+ if torch.cuda.is_available():
+ major, minor = torch.cuda.get_device_capability()
+ return major * 10 + minor
+ return 0
+
+
+cuda_version = get_cuda_version()
+sm_version = get_device_sm()
+
cutlass = root / "3rdparty" / "cutlass"
flashinfer = root / "3rdparty" / "flashinfer"
include_dirs = [
@@ -42,12 +58,15 @@ def update_wheel_platform_tag():
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
- "-gencode=arch=compute_90a,code=sm_90a",
"-std=c++17",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-DFLASHINFER_ENABLE_BF16",
]
+
+if cuda_version >= (12, 0) and sm_version >= 90:
+ nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+
for flag in [
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
From 9f8f2c7f749523070a8b259f843cd42acceb9963 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Wed, 22 Jan 2025 18:58:44 +0800
Subject: [PATCH 056/147] update norm cu (#3048)
---
sgl-kernel/setup.py | 2 +-
sgl-kernel/src/sgl-kernel/csrc/norm.cu | 28 --------------------------
2 files changed, 1 insertion(+), 29 deletions(-)
delete mode 100644 sgl-kernel/src/sgl-kernel/csrc/norm.cu
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 1aea485ff8f..1197611d6a2 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -91,7 +91,7 @@ def get_device_sm():
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
- "src/sgl-kernel/csrc/norm.cu",
+ "3rdparty/flashinfer/csrc/norm.cu",
],
include_dirs=include_dirs,
extra_compile_args={
diff --git a/sgl-kernel/src/sgl-kernel/csrc/norm.cu b/sgl-kernel/src/sgl-kernel/csrc/norm.cu
deleted file mode 100644
index ad102a50d3f..00000000000
--- a/sgl-kernel/src/sgl-kernel/csrc/norm.cu
+++ /dev/null
@@ -1,28 +0,0 @@
-#include
-#include
-
-#include "pytorch_extension_utils.h"
-
-using namespace flashinfer;
-
-void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream) {
- CHECK_INPUT(input);
- CHECK_INPUT(weight);
- auto device = input.device();
- CHECK_EQ(weight.device(), device);
- CHECK_DIM(2, input); // input: (batch_size, hidden_size)
- CHECK_DIM(1, weight); // weight: (hidden_size)
- CHECK_EQ(input.size(1), weight.size(0));
- unsigned int batch_size = input.size(0);
- unsigned int hidden_size = input.size(1);
- CHECK_EQ(output.size(0), batch_size);
- CHECK_EQ(output.size(1), hidden_size);
-
- cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
- DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
- cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(weight.data_ptr()),
- static_cast<c_type*>(output.data_ptr()), batch_size, hidden_size, eps, stream);
- TORCH_CHECK(status == cudaSuccess, "RMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
- return true;
- });
-}
From bcda0c9ee6a6e687e53ac933f3541dd5c5a1fe9b Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Wed, 22 Jan 2025 20:33:13 +0800
Subject: [PATCH 057/147] sync the upstream updates of flashinfer (#3051)
---
.github/workflows/pr-test-sgl-kernel.yml | 1 +
sgl-kernel/3rdparty/flashinfer | 2 +-
sgl-kernel/setup.py | 6 ++++++
3 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml
index 3d980265831..794a73f3661 100644
--- a/.github/workflows/pr-test-sgl-kernel.yml
+++ b/.github/workflows/pr-test-sgl-kernel.yml
@@ -40,6 +40,7 @@ jobs:
run: |
pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm
pip3 uninstall sgl-kernel -y || true
+ find . -name index.lock -delete
cd sgl-kernel
git submodule deinit --all --force && git submodule sync --recursive && git submodule update --init --force --recursive
pip3 install .
diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer
index a0e99a3a820..4e8eb1879f9 160000
--- a/sgl-kernel/3rdparty/flashinfer
+++ b/sgl-kernel/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit a0e99a3a820109763d9a757138a5cdf7bbcd1f85
+Subproject commit 4e8eb1879f9c3ba6d75511e5893183bf8f289a62
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 1197611d6a2..b9324c35543 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -47,6 +47,7 @@ def get_device_sm():
cutlass.resolve() / "tools" / "util" / "include",
root / "src" / "sgl-kernel" / "csrc",
flashinfer.resolve() / "include",
+ flashinfer.resolve() / "include" / "gemm",
flashinfer.resolve() / "csrc",
]
nvcc_flags = [
@@ -91,7 +92,12 @@ def get_device_sm():
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
+ "3rdparty/flashinfer/csrc/activation.cu",
+ "3rdparty/flashinfer/csrc/bmm_fp8.cu",
+ "3rdparty/flashinfer/csrc/group_gemm.cu",
+ "3rdparty/flashinfer/csrc/group_gemm_sm90.cu",
"3rdparty/flashinfer/csrc/norm.cu",
+ "3rdparty/flashinfer/csrc/sampling.cu",
],
include_dirs=include_dirs,
extra_compile_args={
From 7353fb9b97705c89d205aa3477b446759fcb86b7 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Wed, 22 Jan 2025 21:32:48 +0800
Subject: [PATCH 058/147] feat: integrate norm kernels into sgl-kernel (#3052)
---
sgl-kernel/src/sgl-kernel/__init__.py | 16 ++-
.../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 16 +++
sgl-kernel/src/sgl-kernel/ops/__init__.py | 45 +++++-
sgl-kernel/tests/test_norm.py | 129 ++++++++++++++++++
sgl-kernel/tests/test_rmsnorm.py | 31 -----
5 files changed, 195 insertions(+), 42 deletions(-)
create mode 100644 sgl-kernel/tests/test_norm.py
delete mode 100644 sgl-kernel/tests/test_rmsnorm.py
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 3352abeb550..bdbc0ce846c 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -1,6 +1,9 @@
from sgl_kernel.ops import (
custom_dispose,
custom_reduce,
+ fused_add_rmsnorm,
+ gemma_fused_add_rmsnorm,
+ gemma_rmsnorm,
get_graph_buffer_ipc_meta,
init_custom_reduce,
int8_scaled_mm,
@@ -12,14 +15,17 @@
)
__all__ = [
- "moe_align_block_size",
- "init_custom_reduce",
"custom_dispose",
"custom_reduce",
- "int8_scaled_mm",
- "sampling_scaling_penalties",
+ "fused_add_rmsnorm",
+ "gemma_fused_add_rmsnorm",
+ "gemma_rmsnorm",
"get_graph_buffer_ipc_meta",
+ "init_custom_reduce",
+ "int8_scaled_mm",
+ "moe_align_block_size",
"register_graph_buffers",
- "rotary_embedding",
"rmsnorm",
+ "rotary_embedding",
+ "sampling_scaling_penalties",
]
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index ed359bfbb0a..8f9d1ae5333 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -33,6 +33,16 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Ten
// rms norm
void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+// fused rms norm
+void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+
+// gemma rms norm
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
+
+// fused gemma rms norm
+void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
+ int64_t cuda_stream);
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// trt_reduce
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -50,4 +60,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
// rms norm
m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)");
+ // fused rms norm
+ m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)");
+ // gemma rms norm
+ m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
+ // fused gemma rms norm
+ m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
}
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index e9eadb759cf..bbfd76878a7 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -3,6 +3,9 @@
import torch
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
from sgl_kernel.ops._kernels import dispose as _dispose
+from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
+from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
+from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
from sgl_kernel.ops._kernels import (
get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta,
)
@@ -17,6 +20,10 @@
)
+def get_cuda_stream(device: torch.device) -> int:
+ return torch.cuda.current_stream(device).cuda_stream
+
+
def init_custom_reduce(
rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
):
@@ -88,9 +95,35 @@ def rmsnorm(
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- if out is None:
- out = torch.empty_like(input)
- stream = torch.cuda.current_stream().cuda_stream
- stream_int = int(stream)
- _rmsnorm(out, input, weight, eps, stream_int)
- return out
+ with input.device as device:
+ if out is None:
+ out = torch.empty_like(input)
+ _rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+ return out
+
+
+def fused_add_rmsnorm(
+ input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+ with input.device as device:
+ _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+
+
+def gemma_rmsnorm(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ eps: float = 1e-6,
+ out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ with input.device as device:
+ if out is None:
+ out = torch.empty_like(input)
+ _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+ return out
+
+
+def gemma_fused_add_rmsnorm(
+ input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> None:
+ with input.device as device:
+ _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
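
Unlike rmsnorm/gemma_rmsnorm, the two fused variants above return None and update both tensors in place, which is easy to miss from the signatures alone. A minimal usage sketch (assumes a CUDA device and an installed sgl_kernel build):

import torch
import sgl_kernel

hidden = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(hidden)
weight = torch.randn(4096, dtype=torch.float16, device="cuda")

# In place: hidden becomes rmsnorm(hidden + residual) * weight,
# residual becomes hidden + residual (same as the fused_add_rms_norm reference in the tests).
sgl_kernel.fused_add_rmsnorm(hidden, residual, weight, eps=1e-6)
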
diff --git a/sgl-kernel/tests/test_norm.py b/sgl-kernel/tests/test_norm.py
new file mode 100644
index 00000000000..32f8c25d9f7
--- /dev/null
+++ b/sgl-kernel/tests/test_norm.py
@@ -0,0 +1,129 @@
+# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py
+
+import pytest
+import sgl_kernel
+import torch
+
+
+def llama_rms_norm(x, w, eps=1e-6):
+ orig_dtype = x.dtype
+ x = x.float()
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
+ x = x * torch.rsqrt(variance + eps)
+ x = x * w.float()
+ x = x.to(orig_dtype)
+ return x
+
+
+def gemma_rms_norm(x, w, eps=1e-6):
+ orig_dtype = x.dtype
+ x = x.float()
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
+ x = x * torch.rsqrt(variance + eps)
+ x = x * (1.0 + w.float())
+ x = x.to(orig_dtype)
+ return x
+
+
+def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
+ orig_dtype = x.dtype
+ x = x + residual
+ residual = x
+ x = x.float()
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
+ x = x * torch.rsqrt(variance + eps)
+ x = x * (1.0 + w.float())
+ x = x.to(orig_dtype)
+ return x, residual
+
+
+def fused_add_rms_norm(x, residual, weight, eps):
+ orig_dtype = x.dtype
+ x = x.to(torch.float32)
+ x = x + residual.to(torch.float32)
+ residual = x.to(orig_dtype)
+
+ variance = x.pow(2).mean(dim=-1, keepdim=True)
+ x = x * torch.rsqrt(variance + eps)
+ x = (x * weight.float()).to(orig_dtype)
+ return x, residual
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("specify_out", [True, False])
+def test_norm(batch_size, hidden_size, dtype, specify_out):
+ x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+ w = torch.randn(hidden_size).to(0).to(dtype)
+
+ y_ref = llama_rms_norm(x, w)
+ if specify_out:
+ y = torch.empty_like(x)
+ sgl_kernel.rmsnorm(x, w, out=y)
+ else:
+ y = sgl_kernel.rmsnorm(x, w)
+
+ torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
+@pytest.mark.parametrize("dtype", [torch.float16])
+def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
+ eps = 1e-6
+
+ x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
+ residual = torch.randn_like(x)
+ weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
+
+ x_native, residual_native = fused_add_rms_norm(
+ x.clone(), residual.clone(), weight, eps
+ )
+
+ x_fused = x.clone()
+ residual_fused = residual.clone()
+ sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
+
+ torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
+ torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("specify_out", [True, False])
+def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
+ x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+ w = torch.randn(hidden_size).to(0).to(dtype)
+
+ y_ref = gemma_rms_norm(x, w)
+ if specify_out:
+ y = torch.empty_like(x)
+ sgl_kernel.gemma_rmsnorm(x, w, out=y)
+ else:
+ y = sgl_kernel.gemma_rmsnorm(x, w)
+
+ torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
+@pytest.mark.parametrize("dtype", [torch.float16])
+def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
+ eps = 1e-6
+
+ x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
+ residual = torch.randn_like(x)
+ weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
+
+ x_native, residual_native = gemma_fused_add_rms_norm(
+ x.clone(), residual.clone(), weight, eps
+ )
+
+ x_fused = x.clone()
+ residual_fused = residual.clone()
+ sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
+
+ torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
+ torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
diff --git a/sgl-kernel/tests/test_rmsnorm.py b/sgl-kernel/tests/test_rmsnorm.py
deleted file mode 100644
index dda225de9e3..00000000000
--- a/sgl-kernel/tests/test_rmsnorm.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import pytest
-import torch
-from sgl_kernel import rmsnorm
-
-
-def llama_rms_norm(x, w, eps=1e-6):
- orig_dtype = x.dtype
- x = x.float()
- variance = x.pow(2).mean(dim=-1, keepdim=True)
- x = x * torch.rsqrt(variance + eps)
- x = x * w.float()
- x = x.to(orig_dtype)
- return x
-
-
-@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
-@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
-@pytest.mark.parametrize("dtype", [torch.float16])
-@pytest.mark.parametrize("specify_out", [True, False])
-def test_norm(batch_size, hidden_size, dtype, specify_out):
- x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
- w = torch.randn(hidden_size).to(0).to(dtype)
-
- y_ref = llama_rms_norm(x, w)
- if specify_out:
- y = torch.empty_like(x)
- rmsnorm(x, w, out=y)
- else:
- y = rmsnorm(x, w)
-
- torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
From 9d9b482a392598fc342ee449835af5535ccc772f Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Wed, 22 Jan 2025 23:25:45 +0800
Subject: [PATCH 059/147] feat: integrate activation kernels into sgl-kernel
(#3053)
---
sgl-kernel/src/sgl-kernel/__init__.py | 6 ++
.../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 15 +++++
sgl-kernel/src/sgl-kernel/ops/__init__.py | 61 +++++++++++++++++++
sgl-kernel/tests/test_activation.py | 38 ++++++++++++
4 files changed, 120 insertions(+)
create mode 100644 sgl-kernel/tests/test_activation.py
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index bdbc0ce846c..0bcd77aad37 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -2,6 +2,8 @@
custom_dispose,
custom_reduce,
fused_add_rmsnorm,
+ gelu_and_mul,
+ gelu_tanh_and_mul,
gemma_fused_add_rmsnorm,
gemma_rmsnorm,
get_graph_buffer_ipc_meta,
@@ -12,12 +14,15 @@
rmsnorm,
rotary_embedding,
sampling_scaling_penalties,
+ silu_and_mul,
)
__all__ = [
"custom_dispose",
"custom_reduce",
"fused_add_rmsnorm",
+ "gelu_and_mul",
+ "gelu_tanh_and_mul",
"gemma_fused_add_rmsnorm",
"gemma_rmsnorm",
"get_graph_buffer_ipc_meta",
@@ -28,4 +33,5 @@
"rmsnorm",
"rotary_embedding",
"sampling_scaling_penalties",
+ "silu_and_mul",
]
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index 8f9d1ae5333..d9aaa41b88b 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -43,6 +43,15 @@ void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, do
void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps,
int64_t cuda_stream);
+// silu and mul
+void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+
+// gelu tanh and mul
+void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+
+// gelu and mul
+void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// trt_reduce
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -66,4 +75,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)");
// fused gemma rms norm
m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)");
+ // silu and mul
+ m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)");
+ // gelu tanh and mul
+ m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
+ // gelu and mul
+ m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
}
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index bbfd76878a7..5bfde5df2d0 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -4,6 +4,8 @@
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
from sgl_kernel.ops._kernels import dispose as _dispose
from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
+from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
+from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul
from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm
from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm
from sgl_kernel.ops._kernels import (
@@ -18,6 +20,7 @@
from sgl_kernel.ops._kernels import (
sampling_scaling_penalties as _sampling_scaling_penalties,
)
+from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
def get_cuda_stream(device: torch.device) -> int:
@@ -127,3 +130,61 @@ def gemma_fused_add_rmsnorm(
) -> None:
with input.device as device:
_gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+
+
+def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
+ assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
+ assert (
+ input.shape[:-1] == output.shape[:-1]
+ ), f"{input.shape[:-1]} != {output.shape[:-1]}"
+ assert (
+ input.shape[-1] == 2 * output.shape[-1]
+ ), f"{input.shape[-1]} != {2 * output.shape[-1]}"
+
+
+def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
+ if input.shape[-1] * input.dtype.itemsize % 16 != 0:
+ raise ValueError("The pointers must be multiple of 16 bytes.")
+ if out is not None:
+ _check_shape(input, out)
+ else:
+ out = torch.empty(
+ input.shape[:-1] + (input.shape[-1] // 2,),
+ device=input.device,
+ dtype=input.dtype,
+ )
+ with input.device as device:
+ _silu_and_mul(out, input, get_cuda_stream(device))
+ return out
+
+
+def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
+ if input.shape[-1] * input.dtype.itemsize % 16 != 0:
+ raise ValueError("The pointers must be multiple of 16 bytes.")
+ if out is not None:
+ _check_shape(input, out)
+ else:
+ out = torch.empty(
+ input.shape[:-1] + (input.shape[-1] // 2,),
+ device=input.device,
+ dtype=input.dtype,
+ )
+ with input.device as device:
+ _gelu_tanh_and_mul(out, input, get_cuda_stream(device))
+ return out
+
+
+def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
+ if input.shape[-1] * input.dtype.itemsize % 16 != 0:
+ raise ValueError("The pointers must be multiple of 16 bytes.")
+ if out is not None:
+ _check_shape(input, out)
+ else:
+ out = torch.empty(
+ input.shape[:-1] + (input.shape[-1] // 2,),
+ device=input.device,
+ dtype=input.dtype,
+ )
+ with input.device as device:
+ _gelu_and_mul(out, input, get_cuda_stream(device))
+ return out
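
The 16-byte guard shared by the three wrappers above boils down to a simple rule: the last dimension times the element size must be divisible by 16, e.g. a multiple of 8 for float16. A small check mirroring the guard (illustrative only):

import torch

def passes_alignment_guard(shape, dtype):
    # Same predicate as the wrappers: last dim * itemsize must be 16-byte aligned
    return shape[-1] * dtype.itemsize % 16 == 0

print(passes_alignment_guard((4, 4096), torch.float16))  # True  (4096 * 2 = 8192 bytes)
print(passes_alignment_guard((4, 4092), torch.float16))  # False (8184 bytes -> ValueError in the wrapper)
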
diff --git a/sgl-kernel/tests/test_activation.py b/sgl-kernel/tests/test_activation.py
new file mode 100644
index 00000000000..f71f36b513d
--- /dev/null
+++ b/sgl-kernel/tests/test_activation.py
@@ -0,0 +1,38 @@
+# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py
+
+import pytest
+import sgl_kernel
+import torch
+
+
+@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
+def test_fused_silu_mul(dim, batch_size, seq_len):
+ x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
+ y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim])
+ y = sgl_kernel.silu_and_mul(x)
+ torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
+def test_fused_gelu_tanh_mul(dim, batch_size, seq_len):
+ x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
+ y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh")
+ y = sgl_kernel.gelu_tanh_and_mul(x)
+ torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384])
+@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
+@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512])
+def test_fused_gelu_mul(dim, batch_size, seq_len):
+ x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16)
+ y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none")
+ y = sgl_kernel.gelu_and_mul(x)
+ torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
+
+
+test_fused_silu_mul(128, 1, 1)
From b2bd8f444c61c5ffaa6e84bb0f094eb14f605fcc Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Wed, 22 Jan 2025 23:45:18 +0800
Subject: [PATCH 060/147] minor: update header and use pytest (#3054)
---
sgl-kernel/Makefile | 2 +-
sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu | 2 +-
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu | 2 +-
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu | 2 +-
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu | 2 +-
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh | 2 +-
sgl-kernel/src/sgl-kernel/csrc/{utils.hpp => utils.h} | 0
7 files changed, 6 insertions(+), 6 deletions(-)
rename sgl-kernel/src/sgl-kernel/csrc/{utils.hpp => utils.h} (100%)
diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile
index c7641bb5fee..9261b896934 100644
--- a/sgl-kernel/Makefile
+++ b/sgl-kernel/Makefile
@@ -19,7 +19,7 @@ clean:
@rm -rf build dist *.egg-info
test:
- @find tests -name "test_*.py" | xargs -n 1 python3
+ @find tests -name "test_*.py" | xargs -n 1 python3 && pytest tests/test_norm.py && pytest tests/test_activation.py
format:
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu
index 8e3f7275702..c77851c32b6 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu
@@ -16,7 +16,7 @@
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
#include "cutlass_extensions/gemm/gemm_universal_base_compat.h"
#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h"
-#include "utils.hpp"
+#include "utils.h"
using namespace cute;
diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
index c7faf9d3775..83861aee071 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
@@ -6,7 +6,7 @@
#include
-#include "utils.hpp"
+#include "utils.h"
#ifdef USE_ROCM
#include
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
index a61d4b86059..2f53bb1a99f 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
@@ -4,7 +4,7 @@
#include
-#include "utils.hpp"
+#include "utils.h"
#include "vectorization.cuh"
template
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index d9aaa41b88b..985cfa17326 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -1,6 +1,6 @@
#include
-#include "utils.hpp"
+#include "utils.h"
// trt_reduce
using fptr_t = int64_t;
diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
index 9d6f9722eb5..22ba0e414fc 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
+++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh
@@ -21,7 +21,7 @@
#include
#include
-#include "utils.hpp"
+#include "utils.h"
namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp b/sgl-kernel/src/sgl-kernel/csrc/utils.h
similarity index 100%
rename from sgl-kernel/src/sgl-kernel/csrc/utils.hpp
rename to sgl-kernel/src/sgl-kernel/csrc/utils.h
From bf669606eb84e12dc1ecf15b23c1eedab204d660 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 23 Jan 2025 00:39:38 +0800
Subject: [PATCH 061/147] feat: integrate bmm_fp8 kernel into sgl-kernel
(#3056)
---
sgl-kernel/setup.py | 12 +++-
sgl-kernel/src/sgl-kernel/__init__.py | 2 +
.../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 6 ++
sgl-kernel/src/sgl-kernel/ops/__init__.py | 61 +++++++++++++++----
sgl-kernel/src/sgl-kernel/ops/utils.py | 19 ++++++
sgl-kernel/tests/test_bmm_fp8.py | 43 +++++++++++++
6 files changed, 131 insertions(+), 12 deletions(-)
create mode 100644 sgl-kernel/src/sgl-kernel/ops/utils.py
create mode 100644 sgl-kernel/tests/test_bmm_fp8.py
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index b9324c35543..81cd96e99ad 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -62,12 +62,22 @@ def get_device_sm():
"-std=c++17",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
- "-DFLASHINFER_ENABLE_BF16",
]
if cuda_version >= (12, 0) and sm_version >= 90:
nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+if sm_version >= 90:
+ nvcc_flags.extend(
+ [
+ "-DFLASHINFER_ENABLE_FP8",
+ "-DFLASHINFER_ENABLE_FP8_E4M3",
+ "-DFLASHINFER_ENABLE_FP8_E5M2",
+ ]
+ )
+if sm_version >= 80:
+ nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
+
for flag in [
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 0bcd77aad37..86c4f34d353 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -1,4 +1,5 @@
from sgl_kernel.ops import (
+ bmm_fp8,
custom_dispose,
custom_reduce,
fused_add_rmsnorm,
@@ -18,6 +19,7 @@
)
__all__ = [
+ "bmm_fp8",
"custom_dispose",
"custom_reduce",
"fused_add_rmsnorm",
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index 985cfa17326..12df0747171 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -52,6 +52,10 @@ void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
// gelu and mul
void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
+// bmm fp8
+void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
+ at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// trt_reduce
m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -81,4 +85,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
// gelu and mul
m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
+ // bmm fp8
+ m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
}
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index 5bfde5df2d0..cea3436b631 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -2,6 +2,7 @@
import torch
from sgl_kernel.ops._kernels import all_reduce as _all_reduce
+from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8
from sgl_kernel.ops._kernels import dispose as _dispose
from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
@@ -21,10 +22,7 @@
sampling_scaling_penalties as _sampling_scaling_penalties,
)
from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
-
-
-def get_cuda_stream(device: torch.device) -> int:
- return torch.cuda.current_stream(device).cuda_stream
+from sgl_kernel.ops.utils import _get_cache_buf, _get_cuda_stream
def init_custom_reduce(
@@ -101,7 +99,7 @@ def rmsnorm(
with input.device as device:
if out is None:
out = torch.empty_like(input)
- _rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+ _rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
return out
@@ -109,7 +107,7 @@ def fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
with input.device as device:
- _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+ _fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
def gemma_rmsnorm(
@@ -121,7 +119,7 @@ def gemma_rmsnorm(
with input.device as device:
if out is None:
out = torch.empty_like(input)
- _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+ _gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
return out
@@ -129,7 +127,7 @@ def gemma_fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
with input.device as device:
- _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+ _gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))
def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
@@ -154,7 +152,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
dtype=input.dtype,
)
with input.device as device:
- _silu_and_mul(out, input, get_cuda_stream(device))
+ _silu_and_mul(out, input, _get_cuda_stream(device))
return out
@@ -170,7 +168,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
dtype=input.dtype,
)
with input.device as device:
- _gelu_tanh_and_mul(out, input, get_cuda_stream(device))
+ _gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
return out
@@ -186,5 +184,46 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
dtype=input.dtype,
)
with input.device as device:
- _gelu_and_mul(out, input, get_cuda_stream(device))
+ _gelu_and_mul(out, input, _get_cuda_stream(device))
return out
+
+
+def _bmm_fp8_internal(
+ workspace_buffer: torch.Tensor,
+ A: torch.Tensor,
+ B: torch.Tensor,
+ D: torch.Tensor,
+ A_scale: torch.Tensor,
+ B_scale: torch.Tensor,
+) -> None:
+ with A.device as device:
+ cublas_handle = torch.cuda.current_blas_handle()
+ _bmm_fp8(
+ A,
+ B,
+ D,
+ A_scale,
+ B_scale,
+ workspace_buffer,
+ cublas_handle,
+ _get_cuda_stream(device),
+ )
+
+
+def bmm_fp8(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ A_scale: torch.Tensor,
+ B_scale: torch.Tensor,
+ dtype: torch.dtype,
+ out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ if out is None:
+ out = torch.empty(
+ (A.shape[0], A.shape[1], B.shape[2]),
+ device=A.device,
+ dtype=dtype,
+ )
+ workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
+ _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
+ return out
diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/src/sgl-kernel/ops/utils.py
new file mode 100644
index 00000000000..af5fccbb786
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/ops/utils.py
@@ -0,0 +1,19 @@
+from typing import Dict, Tuple
+
+import torch
+
+
+def _get_cuda_stream(device: torch.device) -> int:
+ return torch.cuda.current_stream(device).cuda_stream
+
+
+_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
+
+
+def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
+ key = (name, device)
+ buf = _cache_buf.get(key)
+ if buf is None:
+ buf = torch.empty(bytes, dtype=torch.uint8, device=device)
+ _cache_buf[key] = buf
+ return buf
diff --git a/sgl-kernel/tests/test_bmm_fp8.py b/sgl-kernel/tests/test_bmm_fp8.py
new file mode 100644
index 00000000000..e0be92896f6
--- /dev/null
+++ b/sgl-kernel/tests/test_bmm_fp8.py
@@ -0,0 +1,43 @@
+# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py
+
+import pytest
+import torch
+import torch.nn.functional as F
+from sgl_kernel import bmm_fp8
+
+
+def to_float8(x, dtype=torch.float8_e4m3fn):
+ finfo = torch.finfo(dtype)
+ min_val, max_val = x.aminmax()
+ amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+ scale = finfo.max / amax
+ x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+ return x_scl_sat.to(dtype), scale.float().reciprocal()
+
+
+@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
+def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype):
+ if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2:
+ pytest.skip("Invalid combination: both input and mat2 are e5m2")
+
+ input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
+ input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)
+
+ # mat2 row major -> column major
+ mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(
+ -2, -1
+ )
+ mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype)
+
+ res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype)
+ bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res)
+
+ reference = torch.bmm(input, mat2)
+ cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
+ assert cos_sim > 0.99
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
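Editor's note: for context, a minimal usage sketch of the new `bmm_fp8` wrapper added in this patch (a sketch only, assuming a CUDA GPU with FP8 support such as SM90; `quantize_fp8` is an illustrative helper mirroring `to_float8` in the test above, not part of the patch):

```python
import torch
from sgl_kernel import bmm_fp8


def quantize_fp8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # Scale into the FP8 range and return the tensor plus its inverse scale.
    finfo = torch.finfo(dtype)
    amax = x.abs().amax().clamp(min=1e-12)
    scale = finfo.max / amax
    x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
    return x_fp8, scale.float().reciprocal()


a = torch.randn(4, 16, 32, device="cuda", dtype=torch.bfloat16)
# As in the test above, make the second operand column-major before quantizing.
b = torch.randn(4, 48, 32, device="cuda", dtype=torch.bfloat16).transpose(-2, -1)

a_fp8, a_inv_s = quantize_fp8(a)
b_fp8, b_inv_s = quantize_fp8(b)

out = bmm_fp8(a_fp8, b_fp8, a_inv_s, b_inv_s, torch.bfloat16)
print(out.shape)  # torch.Size([4, 16, 48])
```

The wrapper allocates the output in the requested dtype and routes the call through the cached 32 MB workspace buffer shown in `ops/__init__.py` above.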
From 0d2148efaa9c490231f89bd5f587e2241fcff26c Mon Sep 17 00:00:00 2001
From: nstream-ai-devx <155576234+sudo-root-ns@users.noreply.github.com>
Date: Wed, 22 Jan 2025 23:45:32 +0530
Subject: [PATCH 062/147] fix rotary_embedding rope_scaling for phi (#3055)
---
python/sglang/srt/layers/rotary_embedding.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 43478f39d2c..ad265830f8f 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -1018,7 +1018,12 @@ def get_rope(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
else:
- scaling_type = rope_scaling["rope_type"]
+ if "rope_type" in rope_scaling:
+ scaling_type = rope_scaling["rope_type"]
+ elif "type" in rope_scaling:
+ scaling_type = rope_scaling["type"]
+ else:
+ raise ValueError("Unknown RoPE scaling type")
if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
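Editor's note: a small sketch of the fallback introduced above. Older Hugging Face configs (e.g. Phi) use the `type` key while newer ones use `rope_type`, and both should resolve to the same scaling type. The helper name below is illustrative, not part of the patch:

```python
def resolve_rope_scaling_type(rope_scaling: dict) -> str:
    # Prefer the newer key, fall back to the legacy one.
    if "rope_type" in rope_scaling:
        return rope_scaling["rope_type"]
    if "type" in rope_scaling:
        return rope_scaling["type"]
    raise ValueError("Unknown RoPE scaling type")


print(resolve_rope_scaling_type({"rope_type": "llama3", "factor": 8.0}))  # llama3
print(resolve_rope_scaling_type({"type": "linear", "factor": 2.0}))       # linear (Phi-style config)
```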
From 806a3002c10b3992b86921e0af17b116794c78e1 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 23 Jan 2025 02:47:36 +0800
Subject: [PATCH 063/147] add notice about flashinfer in sgl-kernel (#3057)
---
sgl-kernel/src/sgl-kernel/ops/__init__.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index cea3436b631..d90f121d4f3 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -90,6 +90,8 @@ def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
+# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
+# Kudos to @yzh119
def rmsnorm(
input: torch.Tensor,
weight: torch.Tensor,
From ddc2001fb00f67d0d657ebaf056d65c4900e8e57 Mon Sep 17 00:00:00 2001
From: Hui Liu <96135754+hliuca@users.noreply.github.com>
Date: Wed, 22 Jan 2025 13:57:22 -0800
Subject: [PATCH 064/147] disable custom allreduce on HIP (#3058)
---
python/sglang/srt/distributed/parallel_state.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index c6d1a830781..d97c348ef7b 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -41,6 +41,7 @@
from sglang.srt.utils import (
direct_register_custom_op,
is_cuda_alike,
+ is_hip,
supports_custom_op,
)
@@ -952,6 +953,9 @@ def graph_capture():
def set_custom_all_reduce(enable: bool):
global _ENABLE_CUSTOM_ALL_REDUCE
_ENABLE_CUSTOM_ALL_REDUCE = enable
+ if enable and is_hip():
+ logger.warning("HIP doesn't support custom_all_reduce, so disable it.")
+ _ENABLE_CUSTOM_ALL_REDUCE = False
def init_distributed_environment(
From b3393e941fd1d9b97ae317ec852b8c6a705dbe40 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Wed, 22 Jan 2025 14:17:26 -0800
Subject: [PATCH 065/147] [Doc] Update doc of profiling with PyTorch Profiler
(#3038)
---
docs/references/benchmark_and_profiling.md | 21 ++++++++++++++++++---
1 file changed, 18 insertions(+), 3 deletions(-)
diff --git a/docs/references/benchmark_and_profiling.md b/docs/references/benchmark_and_profiling.md
index 87ac5177424..0600b192b4f 100644
--- a/docs/references/benchmark_and_profiling.md
+++ b/docs/references/benchmark_and_profiling.md
@@ -64,16 +64,31 @@ with nvtx.annotate("description", color="color"):
```bash
# set trace path
export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
+
# start server
python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct
-python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile
+# send profiling request from client
+python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile
```
-
-Traces can be visualized using https://ui.perfetto.dev/.
+Please make sure that `SGLANG_TORCH_PROFILER_DIR` is set on both the server and the client side, otherwise the trace file cannot be generated correctly. A reliable way is to set `SGLANG_TORCH_PROFILER_DIR` in the shell's rc file (e.g. `~/.bashrc` for bash shells).
- To profile offline
```bash
export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8
```
+
+- View Traces
+
+Trace files can be loaded and visualized from:
+1. https://ui.perfetto.dev/ (any browser)
+2. chrome://tracing (Chrome browser only)
+
+If the browser cannot open a trace file because it is too large,
+the client can generate a smaller trace file (<100MB) by limiting the number of prompts and the lengths of the prompt outputs.
+For example, when profiling a server,
+```bash
+python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile
+```
+sets the number of prompts to 2 with the `--num-prompts` argument and limits the length of the output sequences to 100 with the `--sharegpt-output-len` argument, which produces a trace file small enough for the browser to open smoothly.
From b8ab989ff4666f4cee2a1d77aa3941d794ffabd7 Mon Sep 17 00:00:00 2001
From: lukec <118525388+sleepcoo@users.noreply.github.com>
Date: Thu, 23 Jan 2025 06:19:33 +0800
Subject: [PATCH 066/147] Fix the FP8 E4M3 parsing offline scales failure bug
(#3045)
---
.../sglang/srt/model_loader/weight_utils.py | 78 ++++++++++++++++++-
1 file changed, 77 insertions(+), 1 deletion(-)
diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index 77c3fcbee74..f2f67ecab1d 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -27,6 +27,7 @@
import numpy as np
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
+from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
@@ -650,6 +651,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
return name
+# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
+class KVCacheQuantSchema(BaseModel):
+ dtype: str
+ # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+ # layer indices to their per-tensor KV cache scaling factor.
+ # TODO: Consider pulling this and its validation methods out into its
+ # own schema class (tricky as its members are variable)
+ scaling_factor: Dict[int, Dict[int, float]]
+
+ @model_validator(mode="after")
+ def check_is_fp8(self) -> "KVCacheQuantSchema":
+ assert self.dtype == "float8_e4m3fn", (
+ "Loaded scaling factors intended for KV cache dtype = "
+ f"{self.dtype} rather than float8_e4m3fn!"
+ )
+ return self
+
+ @model_validator(mode="after")
+ def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+ context = info.context
+ if context:
+ tp_size = context["tp_size"]
+ num_hidden_layers = context["num_hidden_layers"]
+ assert len(self.scaling_factor) == tp_size, (
+ f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+ f"but LLM engine is currently running with TP size {tp_size}."
+ )
+ for tp_rank, layer_maps in self.scaling_factor.items():
+ assert len(layer_maps) == num_hidden_layers, (
+ f"KV cache scales map for TP rank {tp_rank} is malformed. "
+ f"Expected {num_hidden_layers} layers, got "
+ f"{len(layer_maps)}."
+ )
+ for i in range(tp_size):
+ assert (
+ i in self.scaling_factor
+ ), f"KV cache scales map for TP rank {i} not found."
+ return self
+
+ @model_validator(mode="after")
+ def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+ context = info.context
+ if context:
+ tp_rank = context["tp_rank"]
+ num_hidden_layers = context["num_hidden_layers"]
+ layer_scales_map = self.scaling_factor[tp_rank]
+ for i in range(num_hidden_layers):
+ assert i in layer_scales_map, (
+ f"Could not find KV cache scales for layer {i} in "
+ f"TP rank {tp_rank}."
+ )
+ return self
+
+
+class QuantParamSchema(BaseModel):
+ # TODO: Generalize and extend with more fields
+ # (e.g. weights/activations params) once functionality is enabled
+ model_config = ConfigDict(protected_namespaces=())
+ model_type: Optional[str]
+ kv_cache: KVCacheQuantSchema
+
+ @model_validator(mode="after")
+ def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+ context = info.context
+ if context:
+ model_type = context.get("model_type", None)
+ if model_type is not None:
+ assert model_type == self.model_type, (
+ f"Model type is {model_type} but loaded "
+ f"scaling factors belonging to different "
+ f"model type {self.model_type}!"
+ )
+ return self
+
+
def kv_cache_scales_loader(
filename: str,
tp_rank: int,
@@ -681,7 +757,7 @@ def kv_cache_scales_loader(
except json.JSONDecodeError:
logger.error("Error decoding JSON in file '%s'.", filename)
except Exception:
- logger.exception("An error occurred while reading '%s'.", filename)
+ logger.error("An error occurred while reading '%s'.", filename)
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
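Editor's note: a minimal sketch of validating an offline KV-cache scales file with the schema added above (assuming the classes are importable from `sglang.srt.model_loader.weight_utils` and pydantic v2 is installed; the scale values are made up):

```python
from sglang.srt.model_loader.weight_utils import QuantParamSchema

scales_json = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        # TP rank -> {layer index -> per-tensor scaling factor}
        "scaling_factor": {0: {0: 0.023, 1: 0.031}},
    },
}

schema = QuantParamSchema.model_validate(
    scales_json,
    context={"model_type": "llama", "tp_rank": 0, "tp_size": 1, "num_hidden_layers": 2},
)
print(schema.kv_cache.scaling_factor[0])  # {0: 0.023, 1: 0.031}
```

The validation context carries the engine's TP size, TP rank, and layer count, so the `model_validator` hooks above can reject scale files that do not match the running configuration.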
From 022614d26e25cbd963d3bd2706582198943a44ee Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 22 Jan 2025 15:05:51 -0800
Subject: [PATCH 067/147] Add some flags to allow sync token ids across TP
ranks (#3060)
---
python/sglang/srt/layers/sampler.py | 24 +++++++++++++++++++++++-
1 file changed, 23 insertions(+), 1 deletion(-)
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index f3c376ed1eb..24f951f2b5d 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -2,12 +2,18 @@
from typing import List
import torch
+import torch.distributed as dist
from torch import nn
+from sglang.srt.distributed import get_tensor_model_parallel_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
+from sglang.srt.utils import (
+ crash_on_warnings,
+ get_bool_env_var,
+ is_flashinfer_available,
+)
if is_flashinfer_available():
from flashinfer.sampling import (
@@ -20,6 +26,8 @@
logger = logging.getLogger(__name__)
+SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+
class Sampler(nn.Module):
def __init__(self):
@@ -121,6 +129,20 @@ def forward(
batch_next_token_ids,
]
+ if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars:
+ # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default.
+ # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators:
+ # the last all-reduce, the last lm_head matmul, and all sampling kernels.
+ # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic.
+ # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized.
+ # When using xgrammar, this becomes more likely so we also do the sync when grammar is used.
+
+ torch.distributed.all_reduce(
+ batch_next_token_ids,
+ op=dist.ReduceOp.MIN,
+ group=get_tensor_model_parallel_group().device_group,
+ )
+
return batch_next_token_ids.to(torch.int32)
def _apply_custom_logit_processor(
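Editor's note: a minimal, single-process illustration of the synchronization used above. An element-wise MIN all-reduce over the sampled token IDs forces all TP ranks onto the same tokens. This is a sketch only; it spins up a world-size-1 gloo group (arbitrary local port) just to show the collective call:

```python
import torch
import torch.distributed as dist

# World-size-1 group purely to demonstrate the collective call.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
)

batch_next_token_ids = torch.tensor([17, 4096, 3], dtype=torch.int64)
# With more than one rank, every rank ends up with the element-wise minimum,
# so ranks that sampled different tokens converge to the same choice.
dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN)
print(batch_next_token_ids)

dist.destroy_process_group()
```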
From c0bf9bf15c2e4969e63c7fc13c51ae99d14e1570 Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Wed, 22 Jan 2025 17:47:54 -0800
Subject: [PATCH 068/147] [devcontainer] add non-root user (#2989)
---
.devcontainer/Dockerfile | 35 +++++++++++++++++++++++++++++++++
.devcontainer/devcontainer.json | 3 ++-
docker/Dockerfile.dev | 6 ------
3 files changed, 37 insertions(+), 7 deletions(-)
create mode 100644 .devcontainer/Dockerfile
diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
new file mode 100644
index 00000000000..0c061cd1871
--- /dev/null
+++ b/.devcontainer/Dockerfile
@@ -0,0 +1,35 @@
+FROM lmsysorg/sglang:dev
+
+# Create non-root user with specified UID and GID
+# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908.
+ARG HOST_UID=1003
+ARG HOST_GID=1003
+RUN groupadd -g $HOST_GID devuser && \
+ useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser
+
+# Give devuser sudo access
+RUN apt-get update && apt-get install -y sudo && \
+ echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \
+ rm -rf /var/lib/apt/lists/* && \
+ apt-get clean
+
+# Set up oh-my-zsh for devuser
+RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \
+ cp /root/.zshrc /home/devuser/.zshrc && \
+ cp /root/.vimrc /home/devuser/.vimrc && \
+ cp /root/.tmux.conf /home/devuser/.tmux.conf && \
+ sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \
+ chown -R devuser:devuser /home/devuser/
+
+# Set workspace directory and ownership
+WORKDIR /sgl-workspace/sglang
+RUN chown -R devuser:devuser /sgl-workspace
+
+# Switch to devuser
+USER devuser
+
+# Install uv
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh
+
+# Install rust
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 66f7aecbf82..5767aa2631a 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -1,8 +1,9 @@
{
"name": "sglang",
"build": {
- "dockerfile": "../docker/Dockerfile.dev"
+ "dockerfile": "Dockerfile"
},
+ "remoteUser": "devuser",
"customizations": {
"vscode": {
"extensions": [
diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev
index 9d05ee5997e..5ff1fa7a51a 100644
--- a/docker/Dockerfile.dev
+++ b/docker/Dockerfile.dev
@@ -67,12 +67,6 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.31.1/cmake-3.31.1
&& cp -r cmake-3.31.1-linux-x86_64/share/* /usr/local/share/ \
&& rm -rf cmake-3.31.1-linux-x86_64 cmake-3.31.1-linux-x86_64.tar.gz
-# Install uv
-RUN curl -LsSf https://astral.sh/uv/install.sh | sh
-
-# Install rust
-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
-
# Add yank script
COPY --chown=root:root <<-"EOF" /usr/local/bin/yank
#!/bin/bash
From 5de50653cd52195c7effe4f58b284ab05ce04809 Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Wed, 22 Jan 2025 17:56:21 -0800
Subject: [PATCH 069/147] [router] make error actionable (#3063)
---
sgl-router/src/router.rs | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs
index 5bbffc74ccf..a189ff9eb88 100644
--- a/sgl-router/src/router.rs
+++ b/sgl-router/src/router.rs
@@ -238,12 +238,12 @@ impl Router {
loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
error!(
- "Timeout {}s waiting for workers to become healthy",
- timeout_secs
+ "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
+ timeout_secs, worker_urls
);
return Err(format!(
- "Timeout {}s waiting for workers to become healthy",
- timeout_secs
+ "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
+ timeout_secs, worker_urls
));
}
@@ -644,11 +644,11 @@ impl Router {
loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
error!(
- "Timeout {}s waiting for worker {} to become healthy",
+ "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
timeout_secs, worker_url
);
return Err(format!(
- "Timeout {}s waiting for worker {} to become healthy",
+ "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value",
timeout_secs, worker_url
));
}
From 8b84e69f25929c8de0286c6e0e0c2ce4686b561c Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 22 Jan 2025 18:51:40 -0800
Subject: [PATCH 070/147] Fix tp token sync for dp attention (#3062)
---
python/sglang/srt/layers/sampler.py | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 24f951f2b5d..3173d533d16 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -6,6 +6,7 @@
from torch import nn
from sglang.srt.distributed import get_tensor_model_parallel_group
+from sglang.srt.layers.dp_attention import get_attention_tp_group
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -33,6 +34,10 @@ class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
+ self.tp_sync_group = get_tensor_model_parallel_group().device_group
+
+ if global_server_args_dict["enable_dp_attention"]:
+ self.tp_sync_group = get_attention_tp_group().device_group
def forward(
self,
@@ -140,7 +145,7 @@ def forward(
torch.distributed.all_reduce(
batch_next_token_ids,
op=dist.ReduceOp.MIN,
- group=get_tensor_model_parallel_group().device_group,
+ group=self.tp_sync_group,
)
return batch_next_token_ids.to(torch.int32)
From 862bcff833c8ae480fea0fdab6e53e619c650cb5 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Wed, 22 Jan 2025 21:33:17 -0800
Subject: [PATCH 071/147] Support loading of larger models with on-the-fly
quantization (#3061)
---
python/sglang/srt/configs/load_config.py | 1 +
python/sglang/srt/layers/torchao_utils.py | 18 +++--
.../sglang/srt/model_executor/model_runner.py | 9 ++-
python/sglang/srt/model_loader/loader.py | 75 +++++++++++++++++++
.../sglang/srt/models/torch_native_llama.py | 21 +++++-
python/sglang/srt/server_args.py | 6 +-
6 files changed, 116 insertions(+), 14 deletions(-)
diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py
index 2b2b341faeb..6cb35ab47c6 100644
--- a/python/sglang/srt/configs/load_config.py
+++ b/python/sglang/srt/configs/load_config.py
@@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum):
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
+ LAYERED = "layered"
@dataclass
diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index c5bca25df37..e08abd5ae1d 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -5,6 +5,7 @@
import logging
import os
import pwd
+from typing import Callable, Optional
import torch
@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
return True
+def proj_filter(
+ module: torch.nn.Module,
+ fqn: str,
+):
+ """Filter function for quantizing projection layers."""
+ return "proj" in fqn
+
+
def apply_torchao_config_to_model(
- model: torch.nn.Module, torchao_config: str, filter_fn=None
+ model: torch.nn.Module,
+ torchao_config: str,
+ filter_fn: Optional[Callable] = proj_filter,
):
"""Quantize a modelwith torchao quantization specified by torchao_config
@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
)
from torchao.quantization.observer import PerRow, PerTensor
- if filter_fn is None:
-
- def filter_fn(module, fqn):
- return "proj" in fqn
-
if torchao_config == "" or torchao_config is None:
return model
elif "int8wo" in torchao_config:
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index d5cdcf2beb0..e7dc6bd66c5 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -185,9 +185,12 @@ def __init__(
self.load_model()
# Apply torchao quantization
- apply_torchao_config_to_model(
- self.model, global_server_args_dict["torchao_config"]
- )
+ torchao_applied = getattr(self.model, "torchao_applied", False)
+ # In layered loading, torchao may have been applied
+ if not torchao_applied:
+ apply_torchao_config_to_model(
+ self.model, global_server_args_dict["torchao_config"]
+ )
# Apply torch TP if the model supports it
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 677d716d43b..9e6b09488e6 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -374,6 +374,78 @@ def load_model(
return model.eval()
+class LayeredModelLoader(DefaultModelLoader):
+ """Model loader that loads weights layer by layer so that one can quantize a
+ layer before loading another to make the peak memory envelope smaller."""
+
+ def __init__(self, load_config: LoadConfig):
+ # Back to the default load format
+ load_config.load_format = LoadFormat.AUTO
+ super().__init__(load_config)
+
+ def load_model(
+ self,
+ *,
+ model_config: ModelConfig,
+ device_config: DeviceConfig,
+ ) -> nn.Module:
+ from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+ torchao_config = global_server_args_dict.get("torchao_config")
+ target_device = torch.device(device_config.device)
+
+ with set_default_torch_dtype(model_config.dtype):
+ # Create model on meta device
+ with torch.device("meta"):
+ model = _initialize_model(
+ model_config,
+ self.load_config,
+ )
+
+ # Check model's layered load support
+ if not hasattr(model, "load_weights_to_module"):
+ raise ValueError(
+ "LayeredModelLoader requires the model to have a "
+ "`load_weights_to_module` method. "
+ f"{model_config.model_path} does not support it."
+ )
+
+ # Get all weights from disk
+ weights = self._get_all_weights(model_config, model)
+
+ # Helper function to recursively fill the weights of a module
+ def fill_module(module, fqn: List[str], weights):
+ """
+ fqn: list of strings representing the fully qualified name of `module`.
+ """
+ # Layer by layer
+ for name, submod in module.named_children():
+ fill_module(submod, fqn + [name], weights)
+
+ # First materialize on target device
+ module.to_empty(device=target_device, recurse=False)
+ fqn_path = ".".join(fqn)
+ # Fill weights
+ model.load_weights_to_module(
+ fqn_path,
+ weights,
+ )
+ # Quantize weights if applicable
+ if torchao_config and "proj" in fqn_path:
+ # Note: `None` here is needed to indicate no filter, see
+ # `apply_torchao_config_to_model` for details.
+ apply_torchao_config_to_model(module, torchao_config, None)
+
+ # Start calling on root module
+ fill_module(model, [], weights)
+
+ if torchao_config:
+ model.torchao_applied = True
+
+ return model.eval()
+
+
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)
+ if load_config.load_format == LoadFormat.LAYERED:
+ return LayeredModelLoader(load_config)
+
return DefaultModelLoader(load_config)
diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py
index 024a6f317fa..7b3e5bc5ddd 100644
--- a/python/sglang/srt/models/torch_native_llama.py
+++ b/python/sglang/srt/models/torch_native_llama.py
@@ -460,7 +460,12 @@ def get_num_params(self):
params_dict = dict(self.named_parameters())
return len(params_dict)
- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ def load_weights_to_module(
+ self,
+ fqn: str,
+ weights: Iterable[Tuple[str, torch.Tensor]],
+ ):
+ """Load weights onto submodule pointed by path `fqn`."""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -469,7 +474,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
- params_dict = dict(self.named_parameters())
+ module = self.get_submodule(fqn)
+ params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -486,7 +492,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
+ if name.endswith(".bias") or name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
@@ -494,12 +500,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
break
else:
# Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
+ if name.endswith(".bias") or name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
+ def load_weights(
+ self,
+ weights: Iterable[Tuple[str, torch.Tensor]],
+ ):
+ """Load weights onto the full model."""
+ self.load_weights_to_module("", weights)
+
class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
pass
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 4a7a28751db..330c3813288 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -317,6 +317,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
"dummy",
"gguf",
"bitsandbytes",
+ "layered",
],
help="The format of the model weights to load. "
'"auto" will try to load the weights in the safetensors format '
@@ -330,7 +331,10 @@ def add_cli_args(parser: argparse.ArgumentParser):
"which is mainly for profiling."
'"gguf" will load the weights in the gguf format. '
'"bitsandbytes" will load the weights using bitsandbytes '
- "quantization.",
+ "quantization."
+ '"layered" loads weights layer by layer so that one can quantize a '
+ "layer before loading another to make the peak memory envelope "
+ "smaller.",
)
parser.add_argument(
"--trust-remote-code",
From ea535dc5745e3a5e7197ec1a58a26d60e4ab3d05 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 22 Jan 2025 21:33:35 -0800
Subject: [PATCH 072/147] Revert "disable custom allreduce on HIP" (#3067)
---
python/sglang/srt/distributed/parallel_state.py | 4 ----
1 file changed, 4 deletions(-)
diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py
index d97c348ef7b..c6d1a830781 100644
--- a/python/sglang/srt/distributed/parallel_state.py
+++ b/python/sglang/srt/distributed/parallel_state.py
@@ -41,7 +41,6 @@
from sglang.srt.utils import (
direct_register_custom_op,
is_cuda_alike,
- is_hip,
supports_custom_op,
)
@@ -953,9 +952,6 @@ def graph_capture():
def set_custom_all_reduce(enable: bool):
global _ENABLE_CUSTOM_ALL_REDUCE
_ENABLE_CUSTOM_ALL_REDUCE = enable
- if enable and is_hip():
- logger.warning("HIP doesn't support custom_all_reduce, so disable it.")
- _ENABLE_CUSTOM_ALL_REDUCE = False
def init_distributed_environment(
From a547aad61f9f7182724caa1e4ea883848f1c632d Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 23 Jan 2025 13:47:53 +0800
Subject: [PATCH 073/147] docs: add developer guide for sgl-kernel (#3068)
---
sgl-kernel/developer_guide.md | 46 +++++++++++++++++++++++++++++++++++
1 file changed, 46 insertions(+)
create mode 100644 sgl-kernel/developer_guide.md
diff --git a/sgl-kernel/developer_guide.md b/sgl-kernel/developer_guide.md
new file mode 100644
index 00000000000..8c9bf6195b0
--- /dev/null
+++ b/sgl-kernel/developer_guide.md
@@ -0,0 +1,46 @@
+# Developer Guide for sgl-kernel
+
+## Development Environment Setup
+
+Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer/development_guide_using_docker.md#setup-docker-container).
+
+Create and enter development container:
+```bash
+docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh
+docker exec -it sglang_zhyncs /bin/zsh
+```
+
+## Project Structure
+
+### Dependencies
+
+Third-party libraries:
+
+- [CCCL](https://github.com/NVIDIA/cccl)
+- [CUTLASS](https://github.com/NVIDIA/cutlass)
+- [FlashInfer](https://github.com/flashinfer-ai/flashinfer)
+
+### Kernel Development
+
+Steps to add a new kernel:
+
+1. Implement in `sgl-kernel/src/sgl-kernel/csrc`
+2. Expose interface in `sgl-kernel/csrc/sgl_kernel_ops.cu` with pybind11
+3. Create Python wrapper in `sgl-kernel/src/sgl-kernel/ops/__init__.py`
+4. Expose Python interface in `sgl-kernel/src/sgl-kernel/__init__.py`
+
+### Build & Install
+
+Development build:
+
+```bash
+make build
+pip3 install dist/*whl --force-reinstall --no-deps
+# Or use: make install (runs pip install -e .)
+```
+
+### Testing & Benchmarking
+
+1. Add pytest tests in `sgl-kernel/tests/`
+2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in `sgl-kernel/benchmark/`
+3. Run test suite
From 44e12ce463f44a29f87fc8af0f1e6c784b2c82ac Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 23 Jan 2025 14:08:25 +0800
Subject: [PATCH 074/147] docs: update developer guide for sgl-kernel (#3069)
---
sgl-kernel/developer_guide.md | 17 +++++++++++------
1 file changed, 11 insertions(+), 6 deletions(-)
diff --git a/sgl-kernel/developer_guide.md b/sgl-kernel/developer_guide.md
index 8c9bf6195b0..8afb6b0e460 100644
--- a/sgl-kernel/developer_guide.md
+++ b/sgl-kernel/developer_guide.md
@@ -24,10 +24,11 @@ Third-party libraries:
Steps to add a new kernel:
-1. Implement in `sgl-kernel/src/sgl-kernel/csrc`
-2. Expose interface in `sgl-kernel/csrc/sgl_kernel_ops.cu` with pybind11
-3. Create Python wrapper in `sgl-kernel/src/sgl-kernel/ops/__init__.py`
-4. Expose Python interface in `sgl-kernel/src/sgl-kernel/__init__.py`
+1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
+2. Expose interface in [csrc/sgl_kernel_ops.cu](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu) with pybind11
+3. Create Python wrapper in [src/sgl-kernel/ops/__init__.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
+4. Expose Python interface in [src/sgl-kernel/__init__.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
+5. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
### Build & Install
@@ -41,6 +42,10 @@ pip3 install dist/*whl --force-reinstall --no-deps
### Testing & Benchmarking
-1. Add pytest tests in `sgl-kernel/tests/`
-2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in `sgl-kernel/benchmark/`
+1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests)
+2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
3. Run test suite
+
+### Release a new version
+
+Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml)
From 3e032c07cc45b7fe3fc16041e602e4e5ed13ef79 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 23 Jan 2025 14:19:38 +0800
Subject: [PATCH 075/147] use v0.6.4.post1 for sgl-kernel ci (#3071)
---
.github/workflows/pr-test-sgl-kernel.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml
index 794a73f3661..55eb636d64f 100644
--- a/.github/workflows/pr-test-sgl-kernel.yml
+++ b/.github/workflows/pr-test-sgl-kernel.yml
@@ -38,7 +38,7 @@ jobs:
- name: Install
run: |
- pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm
+ pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1
pip3 uninstall sgl-kernel -y || true
find . -name index.lock -delete
cd sgl-kernel
From ac2dc35d0e529a278450bceb4d234aae3a1c93d8 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Thu, 23 Jan 2025 15:29:20 +0800
Subject: [PATCH 076/147] support lightning_attention_decode in sgl-kernel for
MiniMax-Text-01 (#3030)
---
.../benchmark_lightning_attention_decode.py | 77 ++++-
.../bench_lightning_attention_decode.py | 299 ++++++++++++++++++
sgl-kernel/setup.py | 1 +
sgl-kernel/src/sgl-kernel/__init__.py | 2 +
.../csrc/lightning_attention_decode_kernel.cu | 119 +++++++
.../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 7 +
sgl-kernel/src/sgl-kernel/ops/__init__.py | 7 +
.../tests/test_lightning_attention_decode.py | 84 +++++
8 files changed, 588 insertions(+), 8 deletions(-)
create mode 100644 sgl-kernel/benchmark/bench_lightning_attention_decode.py
create mode 100644 sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu
create mode 100644 sgl-kernel/tests/test_lightning_attention_decode.py
diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py
index a2d1e10f662..57fbcfddf2c 100644
--- a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py
+++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py
@@ -9,6 +9,7 @@
import triton
import triton.language as tl
from einops import rearrange
+from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode
@triton.jit
@@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params):
model_params["num_attention_heads"],
d,
d,
- dtype=dtype,
device=device,
)
with torch.no_grad():
@@ -350,7 +350,13 @@ def test_lightning_attention_implementations(model_params):
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ past_kv = past_kv.contiguous()
+ slope_rate = slope_rate.contiguous()
+ # Test Triton implementation
triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
triton_output = triton_output.transpose(1, 2).contiguous()
triton_output = triton_output.view(batch_size, seq_len, -1)
@@ -358,22 +364,50 @@ def test_lightning_attention_implementations(model_params):
triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
triton_output = model_attn.out_proj(triton_output)
+ # Test SGL implementation
+ sgl_output = torch.empty_like(v)
+ sgl_new_kv = torch.empty_like(past_kv)
+ sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
+
+ sgl_output = sgl_output.transpose(1, 2).contiguous()
+ sgl_output = sgl_output.view(batch_size, seq_len, -1)
+ sgl_output = model_attn.norm(sgl_output)
+ sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
+ sgl_output = model_attn.out_proj(sgl_output)
+
+ # Verify Triton implementation results
torch.testing.assert_close(
model_output,
triton_output,
rtol=1e-3,
atol=1e-2,
- msg="Lightning attention implementations produce different output results",
+ msg="Triton lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
triton_new_kv,
rtol=1e-3,
atol=1e-2,
- msg="Lightning attention implementations produce different kv results",
+ msg="Triton lightning attention implementation produces different kv results",
)
- print("✅ Two implementations match")
+ # Verify SGL implementation results
+ torch.testing.assert_close(
+ model_output,
+ sgl_output,
+ rtol=1e-3,
+ atol=1e-2,
+ msg="SGL lightning attention implementation produces different output results",
+ )
+ torch.testing.assert_close(
+ new_kv,
+ sgl_new_kv,
+ rtol=1e-3,
+ atol=1e-2,
+ msg="SGL lightning attention implementation produces different kv results",
+ )
+
+ print("✅ All implementations match")
def _build_slope_tensor(n_attention_heads: int):
@@ -408,12 +442,13 @@ def get_benchmark():
x_names=["batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
- line_vals=["Original", "Triton"],
+ line_vals=["Original", "Triton", "SGL"],
line_names=[
"Original PyTorch Implementation",
"Triton Implementation",
+ "SGL Implementation",
],
- styles=[("blue", "-"), ("green", "-")],
+ styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="lightning-attention-decode-performance",
args={},
@@ -446,7 +481,6 @@ def benchmark(batch_size, seq_len, provider):
params["num_attention_heads"],
d,
d,
- dtype=dtype,
device=device,
)
@@ -461,7 +495,7 @@ def benchmark(batch_size, seq_len, provider):
),
quantiles=quantiles,
)
- else:
+ elif provider == "Triton":
def run_triton():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
@@ -483,6 +517,33 @@ def run_triton():
run_triton,
quantiles=quantiles,
)
+ else: # SGL
+
+ def run_sgl():
+ qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
+ new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
+ qkv = qkv.view(*new_shape)
+ q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
+ q = q.transpose(1, 2).contiguous()
+ k = k.transpose(1, 2).contiguous()
+ v = v.transpose(1, 2).contiguous()
+
+ output = torch.empty_like(v)
+ new_kv = torch.empty_like(past_kv)
+ sgl_lightning_attention_decode(
+ q, k, v, past_kv, slope_rate, output, new_kv
+ )
+
+ output = output.transpose(1, 2).contiguous()
+ output = output.view(batch_size, seq_len, -1)
+ output = model_attn.norm(output)
+ output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
+ return model_attn.out_proj(output)
+
+ ms, min_ms, max_ms = triton.testing.do_bench(
+ run_sgl,
+ quantiles=quantiles,
+ )
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
diff --git a/sgl-kernel/benchmark/bench_lightning_attention_decode.py b/sgl-kernel/benchmark/bench_lightning_attention_decode.py
new file mode 100644
index 00000000000..24872e61a4d
--- /dev/null
+++ b/sgl-kernel/benchmark/bench_lightning_attention_decode.py
@@ -0,0 +1,299 @@
+import itertools
+import math
+
+import torch
+import triton
+import triton.language as tl
+from sgl_kernel import lightning_attention_decode
+
+
+def next_power_of_2(n):
+ return 2 ** (int(math.ceil(math.log(n, 2))))
+
+
+@triton.jit
+def _decode_kernel(
+ Q,
+ K,
+ V,
+ KV,
+ Out,
+ S,
+ b: tl.constexpr,
+ h: tl.constexpr,
+ n: tl.constexpr,
+ d: tl.constexpr,
+ d_original: tl.constexpr,
+ e: tl.constexpr,
+ e_original: tl.constexpr,
+):
+ off_bh = tl.program_id(0)
+ off_h = off_bh % h
+
+ qk_offset = off_bh * n * d
+ v_offset = off_bh * n * e
+ o_offset = off_bh * n * e
+ kv_offset = off_bh * d * e
+
+ s = tl.load(S + off_h)
+ ratio = tl.exp(-s)
+
+ d_idx = tl.arange(0, d)
+ e_idx = tl.arange(0, e)
+
+ # Create masks for original dimensions
+ d_mask = d_idx < d_original
+ e_mask = e_idx < e_original
+
+ # Load with masking
+ q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
+ k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
+ v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
+
+ # Load KV with 2D masking
+ kv = tl.load(
+ KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
+ mask=(d_mask[:, None] & e_mask[None, :]),
+ other=0.0,
+ )
+
+ # Compute outer product using element-wise operations
+ k_v_prod = k[:, None] * v[None, :]
+ kv = ratio * kv + k_v_prod
+
+ # Store KV with 2D masking
+ tl.store(
+ KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
+ kv.to(KV.dtype.element_ty),
+ mask=(d_mask[:, None] & e_mask[None, :]),
+ )
+
+ # Compute matrix-vector multiplication using element-wise operations and reduction
+ o = tl.sum(q[:, None] * kv, axis=0)
+
+ # Store output with masking
+ tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
+
+
+def triton_lightning_attn_decode(q, k, v, kv, s):
+ """Triton implementation of Lightning Attention decode operation"""
+ b, h, n, d = q.shape
+ e = v.shape[-1]
+ assert n == 1, "Sequence length must be 1 in decode mode"
+
+ # Get padded dimensions (power of 2)
+ d_padded = next_power_of_2(d)
+ e_padded = next_power_of_2(e)
+
+ # Create output tensor (padded)
+ o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
+
+ # Create padded tensors without actually padding the data
+ q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
+ k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
+ v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
+ kv_padded = torch.empty(
+ b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
+ )
+
+ # Copy data to padded tensors
+ q_padded[..., :d] = q
+ k_padded[..., :d] = k
+ v_padded[..., :e] = v
+ kv_padded[..., :d, :e] = kv
+
+ # Launch kernel
+ grid = (b * h, 1)
+ _decode_kernel[grid](
+ q_padded,
+ k_padded,
+ v_padded,
+ kv_padded,
+ o_padded,
+ s,
+ b=b,
+ h=h,
+ n=n,
+ d=d_padded,
+ d_original=d,
+ e=e_padded,
+ e_original=e,
+ )
+
+ # Get unpadded outputs
+ o = o_padded[..., :e]
+ kv_out = kv_padded[..., :d, :e]
+
+ return o, kv_out
+
+
+def lightning_attention_decode_naive(q, k, v, past_kv, slope):
+ """Naive implementation of lightning attention decode"""
+ original_dtype = q.dtype
+ ratio = torch.exp(-slope) # [h, 1, 1]
+
+ kv = past_kv
+ b, h, n, d = q.shape
+
+ output = []
+ for i in range(n):
+ kv = ratio * kv.to(torch.float32) + torch.einsum(
+ "... n d, ... n e -> ... d e",
+ k[:, :, i : i + 1],
+ v[:, :, i : i + 1],
+ )
+ qkv = torch.einsum(
+ "... n e, ... e d -> ... n d",
+ q[:, :, i : i + 1].to(torch.float32),
+ kv.to(torch.float32),
+ )
+ output.append(qkv)
+ output = torch.concat(output, dim=-2)
+
+ return output.to(original_dtype), kv
+
+
+def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
+ return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
+
+
+def calculate_diff(batch_size):
+ dtype = torch.bfloat16
+ device = torch.device("cuda")
+ num_heads = 64
+ head_dim = 96
+ seq_len = 1
+
+ q = torch.randn(
+ batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
+ )
+ k = torch.randn(
+ batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
+ )
+ v = torch.randn(
+ batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
+ )
+ past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
+ slope = torch.randn(num_heads, 1, 1, device=device)
+
+ output_naive, new_kv_naive = lightning_attention_decode_naive(
+ q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
+ )
+
+ output_kernel = torch.empty_like(output_naive)
+ new_kv_kernel = torch.empty_like(new_kv_naive)
+ lightning_attention_decode_kernel(
+ q.clone(),
+ k.clone(),
+ v.clone(),
+ past_kv.clone(),
+ slope.clone(),
+ output_kernel,
+ new_kv_kernel,
+ )
+
+ output_triton, new_kv_triton = triton_lightning_attn_decode(
+ q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
+ )
+
+ if (
+ torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2)
+ and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2)
+ and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2)
+ and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2)
+ ):
+ print("✅ All implementations match")
+ else:
+ print("❌ Implementations differ")
+
+
+batch_size_range = [i for i in range(1, 65)]  # 1 to 64
+configs = [(bs,) for bs in batch_size_range]
+
+
+@triton.testing.perf_report(
+ triton.testing.Benchmark(
+ x_names=["batch_size"],
+ x_vals=[list(_) for _ in configs],
+ line_arg="provider",
+ line_vals=["naive", "kernel", "triton"],
+ line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
+ styles=[("blue", "-"), ("red", "-"), ("green", "-")],
+ ylabel="us",
+ plot_name="lightning-attention-decode-performance",
+ args={},
+ )
+)
+def benchmark(batch_size, provider):
+ dtype = torch.bfloat16
+ device = torch.device("cuda")
+ num_heads = 64
+ head_dim = 96
+ seq_len = 1
+
+ q = torch.randn(
+ batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
+ )
+ k = torch.randn(
+ batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
+ )
+ v = torch.randn(
+ batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
+ )
+ past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
+ slope = torch.randn(num_heads, 1, 1, device=device)
+
+ quantiles = [0.5, 0.2, 0.8]
+
+ if provider == "naive":
+ ms, min_ms, max_ms = triton.testing.do_bench(
+ lambda: lightning_attention_decode_naive(
+ q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
+ ),
+ quantiles=quantiles,
+ )
+ elif provider == "kernel":
+ output = torch.empty(
+ batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
+ )
+ new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
+ ms, min_ms, max_ms = triton.testing.do_bench(
+ lambda: lightning_attention_decode_kernel(
+ q.clone(),
+ k.clone(),
+ v.clone(),
+ past_kv.clone(),
+ slope.clone(),
+ output,
+ new_kv,
+ ),
+ quantiles=quantiles,
+ )
+ elif provider == "triton":
+ ms, min_ms, max_ms = triton.testing.do_bench(
+ lambda: triton_lightning_attn_decode(
+ q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
+ ),
+ quantiles=quantiles,
+ )
+
+ return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--save_path",
+ type=str,
+ default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
+ help="Path to save lightning attention decode benchmark results",
+ )
+ args = parser.parse_args()
+
+ # Run correctness test
+ calculate_diff(batch_size=4)
+
+ # Run performance benchmark
+ benchmark.run(print_data=True)
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 81cd96e99ad..9a2324b60d8 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -100,6 +100,7 @@ def get_device_sm():
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
+ "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
"3rdparty/flashinfer/csrc/activation.cu",
diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py
index 86c4f34d353..9eaa64e5083 100644
--- a/sgl-kernel/src/sgl-kernel/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/__init__.py
@@ -10,6 +10,7 @@
get_graph_buffer_ipc_meta,
init_custom_reduce,
int8_scaled_mm,
+ lightning_attention_decode,
moe_align_block_size,
register_graph_buffers,
rmsnorm,
@@ -35,5 +36,6 @@
"rmsnorm",
"rotary_embedding",
"sampling_scaling_penalties",
+ "lightning_attention_decode",
"silu_and_mul",
]
diff --git a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu
new file mode 100644
index 00000000000..eb79373b22c
--- /dev/null
+++ b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu
@@ -0,0 +1,119 @@
+#include
+#include
+#include
+#include
+#include
+
+#include "utils.h"
+
+#define THREADS_PER_BLOCK 128
+
+template <typename T>
+__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d]
+ const T* __restrict__ k, // [b, h, 1, d]
+ const T* __restrict__ v, // [b, h, 1, e]
+ const float* __restrict__ past_kv, // [b, h, d, e]
+ const float* __restrict__ slope, // [h, 1, 1]
+ T* __restrict__ output, // [b, h, 1, e]
+ float* __restrict__ new_kv, // [b, h, d, e]
+ const int batch_size, const int num_heads, const int qk_dim,
+ const int v_dim) {
+ extern __shared__ char smem[];
+  T* q_shared = reinterpret_cast<T*>(smem);
+  T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
+  T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
+  float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
+  T* output_shared =
+      reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
+
+ const int32_t tid = threadIdx.x;
+ const int32_t current_head = blockIdx.x;
+ const int32_t b = current_head / num_heads;
+ const int32_t h = current_head % num_heads;
+
+ if (b >= batch_size) return;
+
+ const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim;
+ const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
+ const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
+
+ for (int d = tid; d < qk_dim; d += blockDim.x) {
+ q_shared[d] = q[qk_offset + d];
+ k_shared[d] = k[qk_offset + d];
+ }
+ for (int e = tid; e < v_dim; e += blockDim.x) {
+ v_shared[e] = v[v_offset + e];
+ }
+
+ __syncthreads();
+
+ const float ratio = expf(-1.0f * slope[h]);
+
+ for (int d = tid; d < qk_dim; d += blockDim.x) {
+ T k_val = k_shared[d];
+ for (int e = 0; e < v_dim; ++e) {
+ int past_kv_idx = kv_offset + d * v_dim + e;
+ T v_val = v_shared[e];
+ float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
+ int shared_idx = d * (v_dim + 1) + e;
+ new_kv_shared[shared_idx] = new_val;
+ }
+ }
+
+ __syncthreads();
+
+ for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
+ int d = idx / v_dim;
+ int e = idx % v_dim;
+ int shared_idx = d * (v_dim + 1) + e;
+ int global_idx = kv_offset + idx;
+ new_kv[global_idx] = new_kv_shared[shared_idx];
+ }
+
+ __syncthreads();
+
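+  // output[e] = sum over d of q[d] * new_kv[d][e], read from the padded shared-memory copy.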
+ for (int e = tid; e < v_dim; e += blockDim.x) {
+ float sum = 0.0f;
+ for (int d = 0; d < qk_dim; ++d) {
+ int shared_idx = d * (v_dim + 1) + e;
+ sum += q_shared[d] * new_kv_shared[shared_idx];
+ }
+    output_shared[e] = static_cast<T>(sum);
+ }
+
+ __syncthreads();
+
+ if (tid == 0) {
+ for (int e = 0; e < v_dim; ++e) {
+ output[v_offset + e] = output_shared[e];
+ }
+ }
+}
+
+void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
+ const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
+ torch::Tensor new_kv) {
+ TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
+ TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
+ TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
+ TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous");
+
+ auto batch_size = q.size(0);
+ auto num_heads = q.size(1);
+ auto qk_dim = q.size(3);
+ auto v_dim = v.size(3);
+
+ dim3 block(THREADS_PER_BLOCK);
+ dim3 grid(batch_size * num_heads);
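+  // one thread block (THREADS_PER_BLOCK threads) per (batch, head) pair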
+
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ AT_DISPATCH_FLOATING_TYPES_AND2(
+ at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
+ size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
+        lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
+            q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), v.data_ptr<scalar_t>(), past_kv.data_ptr<float>(),
+            slope.data_ptr<float>(), output.data_ptr<scalar_t>(), new_kv.data_ptr<float>(), batch_size, num_heads,
+            qk_dim, v_dim);
+ }));
+}
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
index 12df0747171..cd5df07895a 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -26,6 +26,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
+// lightning_attention_decode
+void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
+ const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
+ torch::Tensor new_kv);
+
// rotary embedding
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
@@ -69,6 +74,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
// int8_scaled_mm
m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
+ // lightning_attention_decode
+ m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Ddecode (CUDA)");
// rotary embedding
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
// rms norm
diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py
index d90f121d4f3..0aead260bc4 100644
--- a/sgl-kernel/src/sgl-kernel/ops/__init__.py
+++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -14,6 +14,9 @@
)
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
+from sgl_kernel.ops._kernels import (
+ lightning_attention_decode as _lightning_attention_decode,
+)
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
@@ -86,6 +89,10 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
)
+def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
+ _lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
+
+
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
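The wrapper above follows an out-parameter style: callers preallocate output and new_kv and the kernel writes into them rather than returning new tensors. A minimal usage sketch (shapes and dtypes follow the layout comments in the CUDA source and the test added below; the concrete sizes here are illustrative assumptions):

    import torch
    from sgl_kernel import lightning_attention_decode

    b, h, d, e = 2, 8, 64, 64  # illustrative sizes
    q = torch.randn(b, h, 1, d, device="cuda", dtype=torch.float16)
    k = torch.randn(b, h, 1, d, device="cuda", dtype=torch.float16)
    v = torch.randn(b, h, 1, e, device="cuda", dtype=torch.float16)
    past_kv = torch.randn(b, h, d, e, device="cuda", dtype=torch.float32)
    slope = torch.rand(h, 1, 1, device="cuda", dtype=torch.float32)

    output = torch.empty(b, h, 1, e, device="cuda", dtype=q.dtype)  # written by the kernel
    new_kv = torch.empty(b, h, d, e, device="cuda", dtype=torch.float32)
    lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)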
diff --git a/sgl-kernel/tests/test_lightning_attention_decode.py b/sgl-kernel/tests/test_lightning_attention_decode.py
new file mode 100644
index 00000000000..74af78e27b5
--- /dev/null
+++ b/sgl-kernel/tests/test_lightning_attention_decode.py
@@ -0,0 +1,84 @@
+import pytest
+import torch
+from sgl_kernel import lightning_attention_decode
+
+
+def naive_lightning_attention_decode(q, k, v, past_kv, slope):
+ """Naive implementation of lightning attention decode"""
+ original_dtype = q.dtype
+ ratio = torch.exp(-slope) # [h, 1, 1]
+
+ kv = past_kv
+ b, h, n, d = q.shape
+
+ output = []
+ for i in range(n):
+ kv = ratio * kv.to(torch.float32) + torch.einsum(
+ "... n d, ... n e -> ... d e",
+ k[:, :, i : i + 1],
+ v[:, :, i : i + 1],
+ )
+ qkv = torch.einsum(
+ "... n e, ... e d -> ... n d",
+ q[:, :, i : i + 1].to(torch.float32),
+ kv.to(torch.float32),
+ )
+ output.append(qkv)
+ output = torch.concat(output, dim=-2)
+
+ return output.to(original_dtype), kv
+
+
+configs = [
+ # (batch_size, num_heads, dim, embed_dim)
+ (1, 8, 64, 64),
+ (2, 8, 64, 64),
+ (1, 32, 32, 64),
+ (2, 32, 32, 64),
+ (4, 32, 64, 64),
+ (4, 32, 64, 64),
+ (16, 64, 96, 96),
+ (64, 64, 96, 96),
+]
+
+dtypes = [torch.float32, torch.float16, torch.bfloat16]
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("dtype", dtypes)
+@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs)
+def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim):
+ device = torch.device("cuda")
+
+ q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
+ k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
+ v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype)
+ past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device)
+ slope = torch.randn(num_heads, 1, 1, device=device)
+
+ ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope)
+
+ output = torch.empty_like(ref_output)
+ new_kv = torch.empty_like(ref_new_kv)
+ lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
+
+ rtol = 1e-2
+ atol = 1e-2
+
+ torch.testing.assert_close(
+ output,
+ ref_output,
+ rtol=rtol,
+ atol=atol,
+ msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, "
+ f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
+ )
+
+ torch.testing.assert_close(
+ new_kv,
+ ref_new_kv,
+ rtol=rtol,
+ atol=atol,
+ msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
+ f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
+ )
From 553f5a3ffe28d186524cb182849b1ef0d7020a49 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Thu, 23 Jan 2025 01:23:37 -0800
Subject: [PATCH 077/147] Remove torch dependency in sgl-kernel (#3074)
---
sgl-kernel/pyproject.toml | 4 +---
sgl-kernel/setup.py | 1 -
2 files changed, 1 insertion(+), 4 deletions(-)
diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index ab9d68b44c8..11e9880a5af 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -14,9 +14,7 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Environment :: GPU :: NVIDIA CUDA"
]
-dependencies = [
- "torch",
-]
+dependencies = []
[project.urls]
"Homepage" = "https://github.com/sgl-project/sglang"
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 9a2324b60d8..c51fd704504 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -127,7 +127,6 @@ def get_device_sm():
package_dir={"": "src"},
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
- install_requires=["torch"],
)
update_wheel_platform_tag()
From 1f6cf0d4b9bc3b03243157a854407fdcf1db6c11 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 23 Jan 2025 19:16:35 +0800
Subject: [PATCH 078/147] fix build error for sgl-kernel (#3078)
---
.github/workflows/release-pypi-kernel.yml | 4 ++--
sgl-kernel/pyproject.toml | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/release-pypi-kernel.yml b/.github/workflows/release-pypi-kernel.yml
index 466f2bdc70d..1b925b77218 100644
--- a/.github/workflows/release-pypi-kernel.yml
+++ b/.github/workflows/release-pypi-kernel.yml
@@ -18,8 +18,8 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
- cuda-version: ['12.1']
+ python-version: ['3.9', '3.10', '3.11', '3.12']
+ cuda-version: ['12.4']
steps:
- uses: actions/checkout@v4
diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index 11e9880a5af..eecf68b37cf 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -7,7 +7,7 @@ name = "sgl-kernel"
version = "0.0.2.post15"
description = "Kernel Library for SGLang"
readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.9"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
From 3d0bfa3e17bb1468ccb93fcc731c7e2e99d12af1 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 23 Jan 2025 19:45:25 +0800
Subject: [PATCH 079/147] update version setup for sgl-kernel (#3079)
---
.github/workflows/release-pypi-kernel.yml | 2 +-
sgl-kernel/developer_guide.md | 6 +++---
sgl-kernel/setup.py | 10 ++--------
sgl-kernel/version.py | 1 +
4 files changed, 7 insertions(+), 12 deletions(-)
create mode 100644 sgl-kernel/version.py
diff --git a/.github/workflows/release-pypi-kernel.yml b/.github/workflows/release-pypi-kernel.yml
index 1b925b77218..c07069c5d12 100644
--- a/.github/workflows/release-pypi-kernel.yml
+++ b/.github/workflows/release-pypi-kernel.yml
@@ -5,7 +5,7 @@ on:
branches:
- main
paths:
- - sgl-kernel/pyproject.toml
+ - sgl-kernel/version.py
workflow_dispatch:
concurrency:
diff --git a/sgl-kernel/developer_guide.md b/sgl-kernel/developer_guide.md
index 8afb6b0e460..f41ce071e0b 100644
--- a/sgl-kernel/developer_guide.md
+++ b/sgl-kernel/developer_guide.md
@@ -26,8 +26,8 @@ Steps to add a new kernel:
1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc)
2. Expose interface in [csrc/sgl_kernel_ops.cu](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu) with pybind11
-3. Create Python wrapper in [src/sgl-kernel/ops/__init__.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
-4. Expose Python interface in [src/sgl-kernel/__init__.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
+3. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py)
+4. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py)
5. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source
### Build & Install
@@ -48,4 +48,4 @@ pip3 install dist/*whl --force-reinstall --no-deps
### Release new version
-Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml)
+Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/version.py)
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index c51fd704504..71952655cd4 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -3,17 +3,11 @@
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+from version import __version__
root = Path(__file__).parent.resolve()
-def get_version():
- with open(root / "pyproject.toml") as f:
- for line in f:
- if line.startswith("version"):
- return line.split("=")[1].strip().strip('"')
-
-
def update_wheel_platform_tag():
wheel_dir = Path("dist")
if wheel_dir.exists() and wheel_dir.is_dir():
@@ -122,7 +116,7 @@ def get_device_sm():
setup(
name="sgl-kernel",
- version=get_version(),
+ version=__version__,
packages=find_packages(),
package_dir={"": "src"},
ext_modules=ext_modules,
diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py
new file mode 100644
index 00000000000..4bb48c132a2
--- /dev/null
+++ b/sgl-kernel/version.py
@@ -0,0 +1 @@
+__version__ = "0.0.2.post15"
From 07a22cbba34e5012d4ee9606c51ff9bac8124d0e Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 23 Jan 2025 20:46:49 +0800
Subject: [PATCH 080/147] use env variable to control the build conf on the CPU
build node (#3080)
---
sgl-kernel/build.sh | 3 ++
sgl-kernel/setup.py | 67 +++++++++++++++++++++++++++++++--------------
2 files changed, 49 insertions(+), 21 deletions(-)
diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh
index 0d816957951..c899224818e 100755
--- a/sgl-kernel/build.sh
+++ b/sgl-kernel/build.sh
@@ -11,6 +11,9 @@ docker run --rm \
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
export CUDA_VERSION=${CUDA_VERSION} && \
+ export SGL_KERNEL_ENABLE_BF16=1 && \
+ export SGL_KERNEL_ENABLE_FP8=1 && \
+ export SGL_KERNEL_ENABLE_SM90A=1 && \
mkdir -p /usr/lib/x86_64-linux-gnu/ && \
ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
cd /sgl-kernel && \
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 71952655cd4..184cc08c437 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -1,14 +1,14 @@
+import os
from pathlib import Path
import torch
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
-from version import __version__
root = Path(__file__).parent.resolve()
-def update_wheel_platform_tag():
+def _update_wheel_platform_tag():
wheel_dir = Path("dist")
if wheel_dir.exists() and wheel_dir.is_dir():
old_wheel = next(wheel_dir.glob("*.whl"))
@@ -18,21 +18,25 @@ def update_wheel_platform_tag():
old_wheel.rename(new_wheel)
-def get_cuda_version():
+def _get_cuda_version():
if torch.version.cuda:
return tuple(map(int, torch.version.cuda.split(".")))
return (0, 0)
-def get_device_sm():
+def _get_device_sm():
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
return major * 10 + minor
return 0
-cuda_version = get_cuda_version()
-sm_version = get_device_sm()
+def _get_version():
+ with open(root / "pyproject.toml") as f:
+ for line in f:
+ if line.startswith("version"):
+ return line.split("=")[1].strip().strip('"')
+
cutlass = root / "3rdparty" / "cutlass"
flashinfer = root / "3rdparty" / "flashinfer"
@@ -58,19 +62,39 @@ def get_device_sm():
"-DFLASHINFER_ENABLE_F16",
]
-if cuda_version >= (12, 0) and sm_version >= 90:
- nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
-
-if sm_version >= 90:
- nvcc_flags.extend(
- [
- "-DFLASHINFER_ENABLE_FP8",
- "-DFLASHINFER_ENABLE_FP8_E4M3",
- "-DFLASHINFER_ENABLE_FP8_E5M2",
- ]
- )
-if sm_version >= 80:
- nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
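+# Feature toggles for CPU-only build nodes (exported in build.sh) where the GPU cannot be probed.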
+enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
+enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
+enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
+cuda_version = _get_cuda_version()
+sm_version = _get_device_sm()
+
+if torch.cuda.is_available():
+ if cuda_version >= (12, 0) and sm_version >= 90:
+ nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+ if sm_version >= 90:
+ nvcc_flags.extend(
+ [
+ "-DFLASHINFER_ENABLE_FP8",
+ "-DFLASHINFER_ENABLE_FP8_E4M3",
+ "-DFLASHINFER_ENABLE_FP8_E5M2",
+ ]
+ )
+ if sm_version >= 80:
+ nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
+else:
+ # compilation environment without GPU
+ if enable_sm90a:
+ nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+ if enable_fp8:
+ nvcc_flags.extend(
+ [
+ "-DFLASHINFER_ENABLE_FP8",
+ "-DFLASHINFER_ENABLE_FP8_E4M3",
+ "-DFLASHINFER_ENABLE_FP8_E5M2",
+ ]
+ )
+ if enable_bf16:
+ nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
for flag in [
"-D__CUDA_NO_HALF_OPERATORS__",
@@ -82,6 +106,7 @@ def get_device_sm():
torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag)
except ValueError:
pass
+
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python", "cuda"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
@@ -116,11 +141,11 @@ def get_device_sm():
setup(
name="sgl-kernel",
- version=__version__,
+ version=_get_version(),
packages=find_packages(),
package_dir={"": "src"},
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension},
)
-update_wheel_platform_tag()
+_update_wheel_platform_tag()
From 0da0989ad4f468ce35f4c8220241901a75ed1b26 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 23 Jan 2025 21:13:55 +0800
Subject: [PATCH 081/147] sync flashinfer and update sgl-kernel tests (#3081)
---
.github/workflows/pr-test-sgl-kernel.yml | 2 +-
sgl-kernel/3rdparty/flashinfer | 2 +-
sgl-kernel/Makefile | 2 +-
sgl-kernel/tests/test_activation.py | 3 ++-
sgl-kernel/tests/test_lightning_attention_decode.py | 4 ++++
sgl-kernel/tests/test_norm.py | 4 ++++
6 files changed, 13 insertions(+), 4 deletions(-)
diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml
index 55eb636d64f..aea60969719 100644
--- a/.github/workflows/pr-test-sgl-kernel.yml
+++ b/.github/workflows/pr-test-sgl-kernel.yml
@@ -47,7 +47,7 @@ jobs:
pip3 list | grep sgl-kernel
- name: Run test
- timeout-minutes: 10
+ timeout-minutes: 30
run: |
cd sgl-kernel
find tests -name "test_*.py" | xargs -n 1 python3
diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer
index 4e8eb1879f9..93e1a2634e2 160000
--- a/sgl-kernel/3rdparty/flashinfer
+++ b/sgl-kernel/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit 4e8eb1879f9c3ba6d75511e5893183bf8f289a62
+Subproject commit 93e1a2634e22355b0856246b032b285ad1d1da6b
diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile
index 9261b896934..c7641bb5fee 100644
--- a/sgl-kernel/Makefile
+++ b/sgl-kernel/Makefile
@@ -19,7 +19,7 @@ clean:
@rm -rf build dist *.egg-info
test:
- @find tests -name "test_*.py" | xargs -n 1 python3 && pytest tests/test_norm.py && pytest tests/test_activation.py
+ @find tests -name "test_*.py" | xargs -n 1 python3
format:
@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
diff --git a/sgl-kernel/tests/test_activation.py b/sgl-kernel/tests/test_activation.py
index f71f36b513d..43593441e3b 100644
--- a/sgl-kernel/tests/test_activation.py
+++ b/sgl-kernel/tests/test_activation.py
@@ -35,4 +35,5 @@ def test_fused_gelu_mul(dim, batch_size, seq_len):
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
-test_fused_silu_mul(128, 1, 1)
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/sgl-kernel/tests/test_lightning_attention_decode.py b/sgl-kernel/tests/test_lightning_attention_decode.py
index 74af78e27b5..f2cace00157 100644
--- a/sgl-kernel/tests/test_lightning_attention_decode.py
+++ b/sgl-kernel/tests/test_lightning_attention_decode.py
@@ -82,3 +82,7 @@ def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim
msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/sgl-kernel/tests/test_norm.py b/sgl-kernel/tests/test_norm.py
index 32f8c25d9f7..7b38dba72bf 100644
--- a/sgl-kernel/tests/test_norm.py
+++ b/sgl-kernel/tests/test_norm.py
@@ -127,3 +127,7 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
From f1b68618281d680add95b9c30635ef644f1f6f25 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Thu, 23 Jan 2025 22:19:04 +0800
Subject: [PATCH 082/147] use flashinfer vec_dtypes in sgl_kernel (#3083)
---
.../csrc/sampling_scaling_penalties.cu | 47 ++++++++--------
.../src/sgl-kernel/csrc/vectorization.cuh | 29 ----------
.../tests/test_sampling_scaling_penalties.py | 55 +++++++++----------
3 files changed, 51 insertions(+), 80 deletions(-)
delete mode 100644 sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
diff --git a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
index 2f53bb1a99f..2a9de4d9f71 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
@@ -1,11 +1,12 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
+#include <cuda_fp16.h>
#include <cuda_runtime.h>
+#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"
-#include "vectorization.cuh"
template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
@@ -13,31 +14,31 @@ __global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const
const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const int32_t stride = blockDim.x * gridDim.x;
-  auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
-  auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
-  auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);
+  // 16-byte vectorized loads/stores via flashinfer::vec_t
+  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
+  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
- const int32_t num_vec_elems = numel >> 2;
+ const int32_t num_vec_elems = numel / vec_size;
-#pragma unroll 4
+#pragma unroll 1
for (int32_t i = tid; i < num_vec_elems; i += stride) {
-    vec4_t<scalar_t> logits_vec = vectorized_logits[i];
-    vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
-    vec4_t<scalar_t> out_vec;
+ vec_t logits_vec, penalties_vec, out_vec;
+ logits_vec.cast_load(logits + i * vec_size);
+ penalties_vec.cast_load(scaling_penalties + i * vec_size);
- out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
- out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
- out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
- out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;
+#pragma unroll
+ for (uint32_t j = 0; j < vec_size; ++j) {
+ out_vec[j] = logits_vec[j] > scalar_t(0.0f) ? logits_vec[j] / penalties_vec[j] : logits_vec[j] * penalties_vec[j];
+ }
- vectorized_output[i] = out_vec;
+ out_vec.cast_store(output + i * vec_size);
}
- const int32_t start_idx = num_vec_elems * 4;
+ // process the remaining elements
+ const int32_t start_idx = num_vec_elems * vec_size;
for (int32_t i = start_idx + tid; i < numel; i += stride) {
scalar_t logit = logits[i];
scalar_t penalty = scaling_penalties[i];
- output[i] = logit > 0 ? logit / penalty : logit * penalty;
+ output[i] = logit > scalar_t(0.0f) ? logit / penalty : logit * penalty;
}
}
@@ -48,12 +49,14 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- AT_DISPATCH_FLOATING_TYPES_AND2(
- at::ScalarType::Half, at::ScalarType::BFloat16, logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
- const int blocks = (numel + threads * 4 - 1) / (threads * 4);
-        sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
-            logits.data_ptr<scalar_t>(), scaling_penalties.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), numel);
- }));
+ DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(logits.scalar_type(), scalar_t, [&] {
+ uint32_t vec_size = 16 / sizeof(scalar_t);
+ const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size);
+    sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
+        static_cast<scalar_t*>(logits.data_ptr()), static_cast