diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000000..0c061cd1871 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,35 @@ +From lmsysorg/sglang:dev + +# Create non-root user with specified UID and GID +# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908. +ARG HOST_UID=1003 +ARG HOST_GID=1003 +RUN groupadd -g $HOST_GID devuser && \ + useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser + +# Give devuser sudo access +RUN apt-get update && apt-get install -y sudo && \ + echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean + +# Set up oh-my-zsh for devuser +RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \ + cp /root/.zshrc /home/devuser/.zshrc && \ + cp /root/.vimrc /home/devuser/.vimrc && \ + cp /root/.tmux.conf /home/devuser/.tmux.conf && \ + sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \ + chown -R devuser:devuser /home/devuser/ + +# Set workspace directory and ownership +WORKDIR /sgl-workspace/sglang +RUN chown -R devuser:devuser /sgl-workspace + +# Switch to devuser +USER devuser + +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +# Install rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index aee28589864..5767aa2631a 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,8 +1,9 @@ { "name": "sglang", "build": { - "dockerfile": "../docker/Dockerfile.dev" + "dockerfile": "Dockerfile" }, + "remoteUser": "devuser", "customizations": { "vscode": { "extensions": [ @@ -15,6 +16,9 @@ ] } }, - "workspaceFolder": "/sgl-workspace/sglang", - "forwardPorts": [] + "forwardPorts": [], + "runArgs": [ + "--gpus", + "all" + ] } diff --git a/.github/workflows/execute-notebook.yml b/.github/workflows/execute-notebook.yml index e03edd6ce79..49d649797ed 100644 --- a/.github/workflows/execute-notebook.yml +++ b/.github/workflows/execute-notebook.yml @@ -42,7 +42,7 @@ jobs: python -m ipykernel install --user --name python3 --display-name "Python 3" - name: Execute notebooks - timeout-minutes: 30 + timeout-minutes: 40 run: | cd docs make clean diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 928d0efa5b3..277ddef774e 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -40,7 +40,7 @@ jobs: cd sgl-router/ cargo test - e2e-rust: + e2e-python: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 2-gpu-runner steps: @@ -65,7 +65,7 @@ jobs: python3 run_suite.py finish: - needs: [unit-test-rust, e2e-rust] + needs: [unit-test-rust, e2e-python] runs-on: ubuntu-latest steps: - name: Finish diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index 4115677dcb0..df059c1f402 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -16,24 +16,77 @@ concurrency: cancel-in-progress: true jobs: - unit-test: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' - runs-on: 1-gpu-runner + lint: + runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v3 - - name: Install dependencies + - name: Check clang-format + uses: 
DoozyX/clang-format-lint-action@v0.18.1 + with: + source: sgl-kernel + extensions: h,c,cpp,hpp,cu,cuh,cc + clangFormatVersion: 16 + style: file + + build-wheels: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: sgl-kernel-build-node + strategy: + matrix: + python-version: ['3.9'] + cuda-version: ['12.4'] + + steps: + - name: Cleanup run: | - bash scripts/ci_install_dependency.sh + sudo rm -rf $GITHUB_WORKSPACE/* || true + + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }} + run: | cd sgl-kernel - git submodule update --init --recursive - pip3 install -e . --force-reinstall + chmod +x ./build.sh + ./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}" + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }} + path: sgl-kernel/dist/* + + unit-test: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + needs: build-wheels + runs-on: 1-gpu-runner + steps: + - uses: actions/checkout@v4 + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-* + + - name: Install + run: | + pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1 + pip3 uninstall sgl-kernel -y || true + pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps pip3 list | grep sgl-kernel - name: Run test - timeout-minutes: 10 + timeout-minutes: 30 run: | cd sgl-kernel find tests -name "test_*.py" | xargs -n 1 python3 @@ -43,7 +96,7 @@ jobs: pip3 uninstall sgl-kernel -y finish: - needs: [unit-test] + needs: [unit-test, lint] runs-on: ubuntu-latest steps: - name: Finish diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 51117127ada..6ed6046ee6a 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -29,7 +29,7 @@ concurrency: jobs: unit-test-frontend: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -48,11 +48,12 @@ jobs: python3 run_suite.py --suite per-commit unit-test-backend-1-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner strategy: + fail-fast: false matrix: - range: [0-6, 6-16, 16-23, 23-30, 30-38, 38-100] + range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100] steps: - name: Checkout code uses: actions/checkout@v3 @@ -75,7 +76,7 @@ jobs: unit-test-backend-2-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 2-gpu-runner steps: - name: Checkout code @@ -112,7 +113,7 @@ jobs: python3 test_moe_ep.py performance-test-1-gpu-part-1: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: 
(github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -128,7 +129,7 @@ jobs: timeout-minutes: 10 run: | cd test/srt - python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_default + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1 - name: Benchmark online latency timeout-minutes: 10 @@ -148,8 +149,15 @@ jobs: cd test/srt python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size + - name: Benchmark online latency (EAGLE) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_eagle + + performance-test-1-gpu-part-2: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -180,7 +188,7 @@ jobs: python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8 performance-test-2-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 2-gpu-runner steps: - name: Checkout code @@ -196,7 +204,13 @@ jobs: timeout-minutes: 10 run: | cd test/srt - python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1 + + - name: Benchmark single latency + torch.compile (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_torch_compile_tp2_bs1 - name: Benchmark offline throughput (TP=2) timeout-minutes: 10 @@ -210,8 +224,9 @@ jobs: cd test/srt python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache + accuracy-test-1-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -235,7 +250,7 @@ jobs: accuracy-test-2-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 2-gpu-runner steps: - name: Checkout code diff --git a/.github/workflows/release-pypi-kernel.yml b/.github/workflows/release-pypi-kernel.yml index 362088c47fd..af34c8423ce 100644 --- a/.github/workflows/release-pypi-kernel.yml +++ b/.github/workflows/release-pypi-kernel.yml @@ -5,7 +5,7 @@ on: branches: - main paths: - - sgl-kernel/pyproject.toml + - sgl-kernel/version.py workflow_dispatch: concurrency: @@ -14,11 +14,12 @@ concurrency: jobs: build-wheels: + if: github.repository == 'sgl-project/sglang' runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] - cuda-version: ['12.1'] + python-version: ['3.9'] + cuda-version: ['12.4'] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/release-pypi-router.yml b/.github/workflows/release-pypi-router.yml index df20c211cb3..547522e8aa6 100644 --- 
a/.github/workflows/release-pypi-router.yml +++ b/.github/workflows/release-pypi-router.yml @@ -7,7 +7,7 @@ on: branches: - main paths: - - sglang-router/pyproject.toml + - sgl-router/pyproject.toml workflow_dispatch: jobs: @@ -26,9 +26,9 @@ jobs: with: path: sglang-repo - - name: Move sglang-router folder to root and delete sglang-repo + - name: Move sgl-router folder to root and delete sglang-repo run: | - mv sglang-repo/sglang-router/* . + mv sglang-repo/sgl-router/* . rm -rf sglang-repo ls -alt @@ -69,9 +69,9 @@ jobs: with: path: sglang-repo - - name: Move sglang-router folder to root, copy the license file, and delete sglang-repo + - name: Move sgl-router folder to root, copy the license file, and delete sglang-repo run: | - mv sglang-repo/sglang-router/* . + mv sglang-repo/sgl-router/* . mv sglang-repo/LICENSE . rm -rf sglang-repo ls -alt @@ -84,6 +84,7 @@ jobs: - name: Build SDist run: | pip install build + python -m pip install -U packaging python -m build --sdist - uses: actions/upload-artifact@v4 diff --git a/.github/workflows/release-whl-kernel.yml b/.github/workflows/release-whl-kernel.yml new file mode 100644 index 00000000000..70c451778fa --- /dev/null +++ b/.github/workflows/release-whl-kernel.yml @@ -0,0 +1,92 @@ +name: Release SGLang Kernel Wheel (cu118) + +on: + workflow_dispatch: + inputs: + tag_name: + type: string + push: + branches: + - main + paths: + - sgl-kernel/version.py + +jobs: + build-wheels: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9'] + cuda-version: ['11.8'] + + steps: + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }} + run: | + cd sgl-kernel + chmod +x ./build.sh + ./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}" + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }} + path: sgl-kernel/dist/* + + release: + needs: build-wheels + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-* + + - name: Set tag name + id: set_tag_name + run: | + if [ -z "${{ inputs.tag_name }}" ]; then + TAG_NAME="v$(cat sgl-kernel/version.py | cut -d'"' -f2)" + echo "tag_name=$TAG_NAME" >> $GITHUB_OUTPUT + else + echo "tag_name=${{ inputs.tag_name }}" >> $GITHUB_OUTPUT + fi + + - name: Release + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ steps.set_tag_name.outputs.tag_name }} + repository: sgl-project/whl + token: ${{ secrets.WHL_TOKEN }} + files: | + sgl-kernel/dist/* + + - name: Clone wheel index + run: git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl + env: + WHL_TOKEN: ${{ secrets.WHL_TOKEN }} + + - name: Update wheel index + run: python3 scripts/update_kernel_whl_index.py + + - name: Push wheel index + run: | + cd sgl-whl + git config --local user.name "github-actions[bot]" + git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" + git add -A + git commit -m "update whl index" + git push diff --git a/.gitignore b/.gitignore index 41432ff2968..267e4085338 100644 --- a/.gitignore +++ b/.gitignore @@ -222,6 
+222,11 @@ work_dirs/ compile_commands.json *.iml -.vscode/ + +# VSCode +.vscode + +1 + *.nsys-rep *.ncu-rep diff --git a/.gitmodules b/.gitmodules index c588176e7c0..97f3421449d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,12 @@ [submodule "sgl-kernel/3rdparty/cutlass"] path = sgl-kernel/3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git -[submodule "sgl-kernel/3rdparty/cub"] - path = sgl-kernel/3rdparty/cub - url = https://github.com/NVIDIA/cub.git +[submodule "sgl-kernel/3rdparty/cccl"] + path = sgl-kernel/3rdparty/cccl + url = https://github.com/NVIDIA/cccl.git +[submodule "sgl-kernel/3rdparty/flashinfer"] + path = sgl-kernel/3rdparty/flashinfer + url = https://github.com/flashinfer-ai/flashinfer.git +[submodule "sgl-kernel/3rdparty/turbomind"] + path = sgl-kernel/3rdparty/turbomind + url = https://github.com/InternLM/turbomind diff --git a/README.md b/README.md index f99e0c5d3fe..0b08c919949 100644 --- a/README.md +++ b/README.md @@ -62,16 +62,16 @@ python -m sglang.launch_server \ | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) | ## News -- [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)). -- [2024/10] 🔥 The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)). -- [2024/09] SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). -- [2024/07] Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)). +- [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html)) +- [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)). +- [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). +- [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
More +- [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)). - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)). -- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)). - [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)). - [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)). diff --git a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py index 0a6049a1200..e2c4d8d3506 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py @@ -1,12 +1,13 @@ import argparse import itertools -import time import torch import triton import triton.language as tl from sgl_kernel import moe_align_block_size +USE_RANDOM_PERM = False + def ceil_div(a, b): return (a + b - 1) // b @@ -141,8 +142,13 @@ def moe_align_block_size_triton( def calculate_diff(batch_size, seq_len): num_experts = 256 block_size = 128 - topk_ids = torch.randint( - 0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda" + topk = 8 + + topk_ids = torch.stack( + [ + torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] + for _ in range(batch_size * seq_len) + ] ) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) @@ -169,7 +175,7 @@ def calculate_diff(batch_size, seq_len): expert_ids_triton = torch.empty_like(expert_ids_cuda) num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) - # 运行两个实现 + # compare the performance of cuda and triton implementation moe_align_block_size( topk_ids, num_experts, @@ -206,6 +212,15 @@ def calculate_diff(batch_size, seq_len): configs = list(itertools.product(batch_size_range, seq_length_range)) +def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: + topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda") + for i in range(num_tokens): + topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[ + :topk + ] + return topk_ids + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size", "seq_len"], @@ -223,9 +238,17 @@ def benchmark(batch_size, seq_len, provider): num_experts = 256 block_size = 128 topk = 8 - topk_ids = torch.randint( - 0, num_experts, (batch_size * seq_len, topk), dtype=torch.int32, device="cuda" - ) + + if USE_RANDOM_PERM: + topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk) + else: + topk_ids = torch.randint( + 0, + num_experts, + (batch_size * seq_len, topk), + dtype=torch.int32, + device="cuda", + ) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) sorted_ids = torch.empty( diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index 72715fb5072..249401d0910 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -18,6 +18,9 @@ get_default_config, 
get_moe_configs, ) +from sglang.srt.utils import is_hip + +_is_hip_ = is_hip() class BenchmarkConfig(TypedDict): @@ -102,8 +105,8 @@ def benchmark_config( (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32 ) - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) + w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn) input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) @@ -165,17 +168,15 @@ def run(): return avg -def get_configs_compute_bound() -> List[Dict[str, int]]: - # Reduced search space for faster tuning. - # TODO(woosuk): Increase the search space and use a performance model to - # prune the search space. +def get_rocm_configs_compute_bound() -> List[Dict[str, int]]: configs: List[BenchmarkConfig] = [] - for num_stages in [2, 3, 4, 5]: - for block_m in [16, 32, 64, 128, 256]: - for block_k in [64, 128, 256]: - for block_n in [32, 64, 128, 256]: + waves_per_eu_range = 0 + for num_stages in [2]: + for block_m in [32, 64, 128, 256]: + for block_k in [32, 64, 128, 256]: + for block_n in [16, 32, 64, 128, 256]: for num_warps in [4, 8]: - for group_size in [1, 16, 32, 64]: + for group_size in [1, 4, 8, 16, 32]: configs.append( { "BLOCK_SIZE_M": block_m, @@ -184,11 +185,39 @@ def get_configs_compute_bound() -> List[Dict[str, int]]: "GROUP_SIZE_M": group_size, "num_warps": num_warps, "num_stages": num_stages, + "waves_per_eu": waves_per_eu_range, } ) return configs +def get_configs_compute_bound() -> List[Dict[str, int]]: + # Reduced search space for faster tuning. + # TODO(woosuk): Increase the search space and use a performance model to + # prune the search space. + configs: List[BenchmarkConfig] = [] + if _is_hip_: + configs = get_rocm_configs_compute_bound() + else: + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128, 256]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + @ray.remote(num_gpus=1) class BenchmarkWorker: @@ -297,6 +326,9 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: "GROUP_SIZE_M": config["GROUP_SIZE_M"], "num_warps": config["num_warps"], "num_stages": config["num_stages"], + **( + {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {} + ), } diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py index a2d1e10f662..57fbcfddf2c 100644 --- a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py +++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py @@ -9,6 +9,7 @@ import triton import triton.language as tl from einops import rearrange +from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode @triton.jit @@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params): model_params["num_attention_heads"], d, d, - dtype=dtype, device=device, ) with torch.no_grad(): @@ -350,7 +350,13 @@ def test_lightning_attention_implementations(model_params): q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) + q = q.contiguous() + 
k = k.contiguous() + v = v.contiguous() + past_kv = past_kv.contiguous() + slope_rate = slope_rate.contiguous() + # Test Triton implementation triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate) triton_output = triton_output.transpose(1, 2).contiguous() triton_output = triton_output.view(batch_size, seq_len, -1) @@ -358,22 +364,50 @@ def test_lightning_attention_implementations(model_params): triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output triton_output = model_attn.out_proj(triton_output) + # Test SGL implementation + sgl_output = torch.empty_like(v) + sgl_new_kv = torch.empty_like(past_kv) + sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv) + + sgl_output = sgl_output.transpose(1, 2).contiguous() + sgl_output = sgl_output.view(batch_size, seq_len, -1) + sgl_output = model_attn.norm(sgl_output) + sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output + sgl_output = model_attn.out_proj(sgl_output) + + # Verify Triton implementation results torch.testing.assert_close( model_output, triton_output, rtol=1e-3, atol=1e-2, - msg="Lightning attention implementations produce different output results", + msg="Triton lightning attention implementation produces different output results", ) torch.testing.assert_close( new_kv, triton_new_kv, rtol=1e-3, atol=1e-2, - msg="Lightning attention implementations produce different kv results", + msg="Triton lightning attention implementation produces different kv results", ) - print("✅ Two implementations match") + # Verify SGL implementation results + torch.testing.assert_close( + model_output, + sgl_output, + rtol=1e-3, + atol=1e-2, + msg="SGL lightning attention implementation produces different output results", + ) + torch.testing.assert_close( + new_kv, + sgl_new_kv, + rtol=1e-3, + atol=1e-2, + msg="SGL lightning attention implementation produces different kv results", + ) + + print("✅ All implementations match") def _build_slope_tensor(n_attention_heads: int): @@ -408,12 +442,13 @@ def get_benchmark(): x_names=["batch_size", "seq_len"], x_vals=[list(_) for _ in configs], line_arg="provider", - line_vals=["Original", "Triton"], + line_vals=["Original", "Triton", "SGL"], line_names=[ "Original PyTorch Implementation", "Triton Implementation", + "SGL Implementation", ], - styles=[("blue", "-"), ("green", "-")], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], ylabel="us", plot_name="lightning-attention-decode-performance", args={}, @@ -446,7 +481,6 @@ def benchmark(batch_size, seq_len, provider): params["num_attention_heads"], d, d, - dtype=dtype, device=device, ) @@ -461,7 +495,7 @@ def benchmark(batch_size, seq_len, provider): ), quantiles=quantiles, ) - else: + elif provider == "Triton": def run_triton(): qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) @@ -483,6 +517,33 @@ def run_triton(): run_triton, quantiles=quantiles, ) + else: # SGL + + def run_sgl(): + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + + output = torch.empty_like(v) + new_kv = torch.empty_like(past_kv) + sgl_lightning_attention_decode( + q, k, v, past_kv, slope_rate, output, new_kv + ) + + output = output.transpose(1, 2).contiguous() + output = output.view(batch_size, 
seq_len, -1) + output = model_attn.norm(output) + output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output + return model_attn.out_proj(output) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_sgl, + quantiles=quantiles, + ) return 1000 * ms, 1000 * max_ms, 1000 * min_ms diff --git a/docker/Dockerfile b/docker/Dockerfile index 1901d4c27a1..1fe702d4014 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -32,6 +32,7 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118; \ + python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ else \ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ fi \ @@ -43,6 +44,7 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ + python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ else \ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ fi; \ @@ -53,6 +55,7 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ + python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ else \ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ fi; \ diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index 70860d8ef88..5ff1fa7a51a 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -18,6 +18,9 @@ RUN apt-get update && apt-get install -y \ silversearcher-ag \ cloc \ unzip \ + pkg-config \ + libssl-dev \ + bear \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean @@ -39,7 +42,8 @@ RUN python3 -m pip install --no-cache-dir \ pytest \ black \ isort \ - icdiff + icdiff \ + pre-commit # Install diff-so-fancy RUN curl -LSso /usr/local/bin/diff-so-fancy https://github.com/so-fancy/diff-so-fancy/releases/download/v1.4.4/diff-so-fancy \ diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 5a6e9770b72..f04254e54c9 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,5 +1,5 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.4.1.post6 -t v0.4.1.post6-rocm620 -f Dockerfile.rocm . +# docker build --build-arg SGL_BRANCH=v0.4.2 -t v0.4.2-rocm620 -f Dockerfile.rocm . 
# default base image ARG BASE_IMAGE="rocmshared/vllm-rocm:20250114-tuned-elementwise-layernorm" diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb index 47a2e227806..05e7108e60e 100644 --- a/docs/backend/function_calling.ipynb +++ b/docs/backend/function_calling.ipynb @@ -4,32 +4,23 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Function Calling\n", + "# Tool and Function Calling\n", "\n", - "This notebook provides a quick-start guide to use function tooling using SGLang chat completions API\n", - "\n", - "## Supported Models\n", - "\n", - "Currently, we added the support for tools calling in the following models:\n", - " - Llama 3.2 models\n", - " - Llama 3.1 models\n", - " - Qwen 2.5 models\n", - " - InternLM Models" + "This guide demonstrates how to use SGLang’s **Tool Calling** functionality." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Usage\n", - "\n", - "### Launch a server\n", - "\n", - "This code block is equivalent to executing\n", - "\n", - "`python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - "--port 30000 --host 0.0.0.0`\n", - "in your terminal and wait for the server to be ready. Once the server is running, you can send test requests using curl or requests. The server implements the OpenAI-compatible APIs." + "## OpenAI Compatible API" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Launching the Server" ] }, { @@ -38,6 +29,8 @@ "metadata": {}, "outputs": [], "source": [ + "from openai import OpenAI\n", + "import json\n", "from sglang.utils import (\n", " execute_shell_command,\n", " wait_for_server,\n", @@ -45,21 +38,30 @@ " print_highlight,\n", ")\n", "\n", - "\n", "server_process = execute_shell_command(\n", - " \"\"\"\n", - " python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\n", - "\"\"\"\n", + " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --port 30333 --host 0.0.0.0\" # llama3\n", ")\n", + "wait_for_server(\"http://localhost:30333\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n", "\n", - "wait_for_server(\"http://localhost:30000\")" + "- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n", + "- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n", + "Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n", + "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Single Round Invocation" + "### Define Tools for Function Call\n", + "Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and property defined Parameters." ] }, { @@ -68,8 +70,7 @@ "metadata": {}, "outputs": [], "source": [ - "from openai import OpenAI\n", - "\n", + "# Define tools\n", "tools = [\n", " {\n", " \"type\": \"function\",\n", @@ -79,22 +80,264 @@ " \"parameters\": {\n", " \"type\": \"object\",\n", " \"properties\": {\n", - " \"location\": {\n", + " \"city\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city to find the weather for, e.g. 
'San Francisco'\",\n", + " },\n", + " \"state\": {\n", " \"type\": \"string\",\n", - " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", + " \"description\": \"the two-letter abbreviation for the state that the city is\"\n", + " \" in, e.g. 'CA' which would mean 'California'\",\n", + " },\n", + " \"unit\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The unit to fetch the temperature in\",\n", + " \"enum\": [\"celsius\", \"fahrenheit\"],\n", " },\n", - " \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n", " },\n", - " \"required\": [\"location\"],\n", + " \"required\": [\"city\", \"state\", \"unit\"],\n", " },\n", " },\n", " }\n", - "]\n", - "messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Messages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_messages():\n", + " return [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What's the weather like in Boston today? Please respond with the format: Today's weather is :{function call result}\",\n", + " }\n", + " ]\n", + "\n", + "\n", + "messages = get_messages()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize the Client" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize OpenAI-like client\n", + "client = OpenAI(api_key=\"None\", base_url=\"http://0.0.0.0:30333/v1\")\n", + "model_name = client.models.list().data[0].id" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Non-Streaming Request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Non-streaming mode test\n", + "response_non_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.8,\n", + " top_p=0.8,\n", + " stream=False, # Non-streaming\n", + " tools=tools,\n", + ")\n", + "print_highlight(\"Non-stream response:\")\n", + "print(response_non_stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Streaming Request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Streaming mode test\n", + "print_highlight(\"Streaming response:\")\n", + "response_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.8,\n", + " top_p=0.8,\n", + " stream=True, # Enable streaming\n", + " tools=tools,\n", + ")\n", + "\n", + "chunks = []\n", + "for chunk in response_stream:\n", + " chunks.append(chunk)\n", + " if chunk.choices[0].delta.tool_calls:\n", + " print(chunk.choices[0].delta.tool_calls[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Handle Tool Calls\n", + "\n", + "When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Non-Streaming Request**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n", + "arguments_non_stream = (\n", + " response_non_stream.choices[0].message.tool_calls[0].function.arguments\n", + ")\n", + "\n", + "print_highlight(f\"Final streamed function call name: {name_non_stream}\")\n", + "print_highlight(f\"Final streamed function call arguments: {arguments_non_stream}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Streaming Request**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Parse and combine function call arguments\n", + "arguments = []\n", + "for chunk in chunks:\n", + " choice = chunk.choices[0]\n", + " delta = choice.delta\n", + " if delta.tool_calls:\n", + " tool_call = delta.tool_calls[0]\n", + " if tool_call.function.name:\n", + " print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n", + "\n", + " if tool_call.function.arguments:\n", + " arguments.append(tool_call.function.arguments)\n", + " print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n", + "\n", + "# Combine all fragments into a single JSON string\n", + "full_arguments = \"\".join(arguments)\n", + "print_highlight(f\"Final streamed function call arguments: {full_arguments}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define a Tool Function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This is a demonstration, define real function according to your usage.\n", + "def get_current_weather(city: str, state: str, unit: \"str\"):\n", + " return (\n", + " f\"The weather in {city}, {state} is 85 degrees {unit}. 
It is \"\n", + " \"partly cloudly, with highs in the 90's.\"\n", + " )\n", + "\n", + "\n", + "available_tools = {\"get_current_weather\": get_current_weather}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Execute the Tool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "call_data = json.loads(full_arguments)\n", + "\n", + "messages.append(\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"\",\n", + " \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n", + " }\n", + ")\n", "\n", - "client = OpenAI(api_key=\"YOUR_API_KEY\", base_url=\"http://0.0.0.0:30000/v1\")\n", - "model_name = client.models.list().data[0].id\n", - "response = client.chat.completions.create(\n", + "# Call the corresponding tool function\n", + "tool_name = messages[-1][\"tool_calls\"][\"name\"]\n", + "tool_to_call = available_tools[tool_name]\n", + "result = tool_to_call(**call_data)\n", + "print_highlight(f\"Function call result: {result}\")\n", + "messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n", + "\n", + "print_highlight(f\"Updated message history: {messages}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Send Results Back to Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "final_response = client.chat.completions.create(\n", " model=model_name,\n", " messages=messages,\n", " temperature=0.8,\n", @@ -102,17 +345,56 @@ " stream=False,\n", " tools=tools,\n", ")\n", + "print_highlight(\"Non-stream response:\")\n", + "print(final_response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Native API and SGLang Runtime (SRT)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "import requests\n", + "\n", + "# generate an answer\n", + "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", + "\n", + "messages = get_messages()\n", + "\n", + "input = tokenizer.apply_chat_template(\n", + " messages,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " tools=tools,\n", + ")\n", "\n", - "print(response)\n", + "gen_url = \"http://localhost:30333/generate\"\n", + "gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n", + "gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n", + "print(gen_response)\n", "\n", - "\"\"\"\n", + "# parse the response\n", + "parse_url = \"http://localhost:30333/function_call\"\n", "\n", - "ChatCompletion(id='d6f620e1767e490d85b5ce45c15151cf', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, \n", - "role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"a\": \"3\", \"b\": \"5\"}', name='add'), type='function')]), \n", - "matched_stop=128008)], created=1735411703, model='meta-llama/Llama-3.2-1B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, \n", - "usage=CompletionUsage(completion_tokens=23, prompt_tokens=198, total_tokens=221, completion_tokens_details=None, prompt_tokens_details=None))\n", + "function_call_input = {\n", + " \"text\": gen_response,\n", + " \"tool_call_parser\": \"llama3\",\n", + " 
\"tools\": tools,\n", + "}\n", "\n", - "\"\"\"" + "function_call_response = requests.post(parse_url, json=function_call_input)\n", + "function_call_response_json = function_call_response.json()\n", + "print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n", + "print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])" ] }, { @@ -128,17 +410,112 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## How to support a new model?\n", + "## Offline Engine API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sglang as sgl\n", + "from sglang.srt.function_call_parser import FunctionCallParser\n", + "from sglang.srt.managers.io_struct import Tool, Function\n", + "\n", + "llm = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", + "tokenizer = llm.tokenizer_manager.tokenizer\n", + "input_ids = tokenizer.apply_chat_template(\n", + " messages, tokenize=True, add_generation_prompt=True, tools=tools\n", + ")\n", + "\n", + "sampling_params = {\n", + " \"max_new_tokens\": 128,\n", + " \"temperature\": 0.3,\n", + " \"top_p\": 0.95,\n", + " \"skip_special_tokens\": False,\n", + "}\n", + "\n", + "# 1) Offline generation\n", + "result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n", + "generated_text = result[\"text\"] # Assume there is only one prompt\n", + "\n", + "print(\"=== Offline Engine Output Text ===\")\n", + "print(generated_text)\n", + "\n", + "\n", + "# 2) Parse using FunctionCallParser\n", + "def convert_dict_to_tool(tool_dict: dict) -> Tool:\n", + " function_dict = tool_dict.get(\"function\", {})\n", + " return Tool(\n", + " type=tool_dict.get(\"type\", \"function\"),\n", + " function=Function(\n", + " name=function_dict.get(\"name\"),\n", + " description=function_dict.get(\"description\"),\n", + " parameters=function_dict.get(\"parameters\"),\n", + " ),\n", + " )\n", + "\n", + "\n", + "tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n", + "\n", + "parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n", + "normal_text, calls = parser.parse_non_stream(generated_text)\n", + "\n", + "print(\"\\n=== Parsing Result ===\")\n", + "print(\"Normal text portion:\", normal_text)\n", + "print(\"Function call portion:\")\n", + "for call in calls:\n", + " # call: ToolCallItem\n", + " print(f\" - tool name: {call.name}\")\n", + " print(f\" parameters: {call.parameters}\")\n", "\n", - "For adding support of more different models:\n", - " 1. Update the `TOOLS_TAG_LIST` in `sglang/srt/utils.py` with the tool tag used by the model.\n", - " 2. Add support in `parse_tool_response` function for converting into tool calls `sglang/srt/utils.py`\n" + "# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm.shutdown()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## How to support a new model?\n", + "1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:\n", + "```\n", + "\tTOOLS_TAG_LIST = [\n", + "\t “<|plugin|>“,\n", + "\t ““,\n", + "\t “<|python_tag|>“,\n", + "\t “[TOOL_CALLS]”\n", + "\t]\n", + "```\n", + "2. 
Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. For example:\n", + "```\n", + " class NewModelDetector(BaseFormatDetector):\n", + "```\n", + "3. Add the new detector to the MultiFormatParser class that manages all the format detectors." ] } ], "metadata": { "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" } }, "nbformat": 4, diff --git a/docs/backend/offline_engine_api.ipynb b/docs/backend/offline_engine_api.ipynb index 7ce89d435d5..58d24ac3ff6 100644 --- a/docs/backend/offline_engine_api.ipynb +++ b/docs/backend/offline_engine_api.ipynb @@ -37,7 +37,7 @@ "outputs": [], "source": [ "# launch the offline engine\n", - "\n", + "from sglang.utils import stream_and_merge, async_stream_and_merge\n", "import sglang as sgl\n", "import asyncio\n", "\n", @@ -86,20 +86,22 @@ "outputs": [], "source": [ "prompts = [\n", - " \"Hello, my name is\",\n", - " \"The capital of France is\",\n", - " \"The future of AI is\",\n", + " \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n", + " \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n", + " \"Explain possible future trends in artificial intelligence. The future of AI is\",\n", "]\n", - "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n", "\n", - "print(\"\\n=== Testing synchronous streaming generation ===\")\n", + "sampling_params = {\n", + " \"temperature\": 0.2,\n", + " \"top_p\": 0.9,\n", + "}\n", "\n", - "for prompt in prompts:\n", - " print(f\"\\nPrompt: {prompt}\")\n", - " print(\"Generated text: \", end=\"\", flush=True)\n", + "print(\"\\n=== Testing synchronous streaming generation with overlap removal ===\\n\")\n", "\n", - " for chunk in llm.generate(prompt, sampling_params, stream=True):\n", - " print(chunk[\"text\"], end=\"\", flush=True)\n", + "for prompt in prompts:\n", + " print(f\"Prompt: {prompt}\")\n", + " merged_output = stream_and_merge(llm, prompt, sampling_params)\n", + " print(\"Generated text:\", merged_output)\n", " print()" ] }, @@ -117,9 +119,9 @@ "outputs": [], "source": [ "prompts = [\n", - " \"Hello, my name is\",\n", - " \"The capital of France is\",\n", - " \"The future of AI is\",\n", + " \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n", + " \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n", + " \"Explain possible future trends in artificial intelligence. The future of AI is\",\n", "]\n", "\n", "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n", @@ -152,13 +154,14 @@ "outputs": [], "source": [ "prompts = [\n", - " \"Hello, my name is\",\n", - " \"The capital of France is\",\n", - " \"The future of AI is\",\n", + " \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n", + " \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n", + " \"Explain possible future trends in artificial intelligence. 
The future of AI is\",\n", "]\n", + "\n", "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n", "\n", - "print(\"\\n=== Testing asynchronous streaming generation ===\")\n", + "print(\"\\n=== Testing asynchronous streaming generation (no repeats) ===\")\n", "\n", "\n", "async def main():\n", @@ -166,10 +169,11 @@ " print(f\"\\nPrompt: {prompt}\")\n", " print(\"Generated text: \", end=\"\", flush=True)\n", "\n", - " generator = await llm.async_generate(prompt, sampling_params, stream=True)\n", - " async for chunk in generator:\n", - " print(chunk[\"text\"], end=\"\", flush=True)\n", - " print()\n", + " # Replace direct calls to async_generate with our custom overlap-aware version\n", + " async for cleaned_chunk in async_stream_and_merge(llm, prompt, sampling_params):\n", + " print(cleaned_chunk, end=\"\", flush=True)\n", + "\n", + " print() # New line after each prompt\n", "\n", "\n", "asyncio.run(main())" diff --git a/docs/backend/openai_api_completions.ipynb b/docs/backend/openai_api_completions.ipynb index 8660da2f98f..58b524108db 100644 --- a/docs/backend/openai_api_completions.ipynb +++ b/docs/backend/openai_api_completions.ipynb @@ -41,10 +41,10 @@ ")\n", "\n", "server_process = execute_shell_command(\n", - " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\"\n", + " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30020 --host 0.0.0.0\"\n", ")\n", "\n", - "wait_for_server(\"http://localhost:30000\")" + "wait_for_server(\"http://localhost:30020\")" ] }, { @@ -68,7 +68,7 @@ "source": [ "import openai\n", "\n", - "client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "response = client.chat.completions.create(\n", " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", @@ -214,125 +214,8 @@ "metadata": {}, "source": [ "## Structured Outputs (JSON, Regex, EBNF)\n", - "You can specify a JSON schema, [regular expression](https://en.wikipedia.org/wiki/Regular_expression) or [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) to constrain the model output. The model output will be guaranteed to follow the given constraints. 
Only one constraint parameter (`json_schema`, `regex`, or `ebnf`) can be specified for a request.\n", "\n", - "SGLang supports two grammar backends:\n", - "\n", - "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n", - "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints.\n", - " - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n", - "\n", - "Initialize the XGrammar backend using `--grammar-backend xgrammar` flag\n", - "```bash\n", - "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - "--port 30000 --host 0.0.0.0 --grammar-backend [xgrammar|outlines] # xgrammar or outlines (default: outlines)\n", - "```\n", - "\n", - "### JSON" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "json_schema = json.dumps(\n", - " {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n", - " \"population\": {\"type\": \"integer\"},\n", - " },\n", - " \"required\": [\"name\", \"population\"],\n", - " }\n", - ")\n", - "\n", - "response = client.chat.completions.create(\n", - " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", - " messages=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": \"Give me the information of the capital of France in the JSON format.\",\n", - " },\n", - " ],\n", - " temperature=0,\n", - " max_tokens=128,\n", - " response_format={\n", - " \"type\": \"json_schema\",\n", - " \"json_schema\": {\"name\": \"foo\", \"schema\": json.loads(json_schema)},\n", - " },\n", - ")\n", - "\n", - "print_highlight(response.choices[0].message.content)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Regular expression (use default \"outlines\" backend)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "response = client.chat.completions.create(\n", - " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", - " messages=[\n", - " {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n", - " ],\n", - " temperature=0,\n", - " max_tokens=128,\n", - " extra_body={\"regex\": \"(Paris|London)\"},\n", - ")\n", - "\n", - "print_highlight(response.choices[0].message.content)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### EBNF (use \"xgrammar\" backend)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# terminate the existing server(that's using default outlines backend) for this demo\n", - "terminate_process(server_process)\n", - "\n", - "# start new server with xgrammar backend\n", - "server_process = execute_shell_command(\n", - " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --grammar-backend xgrammar\"\n", - ")\n", - "wait_for_server(\"http://localhost:30000\")\n", - "\n", - "# EBNF example\n", - "ebnf_grammar = r\"\"\"\n", - " root ::= \"Hello\" | \"Hi\" | \"Hey\"\n", - " \"\"\"\n", - "response = client.chat.completions.create(\n", - " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", - " messages=[\n", - " {\"role\": \"system\", \"content\": \"You are a helpful EBNF test bot.\"},\n", - " {\"role\": \"user\", \"content\": \"Say a 
greeting.\"},\n", - " ],\n", - " temperature=0,\n", - " max_tokens=32,\n", - " extra_body={\"ebnf\": ebnf_grammar},\n", - ")\n", - "\n", - "print_highlight(response.choices[0].message.content)" + "For the OpenAI-compatible structured outputs API, refer to [Structured Outputs](https://docs.sglang.ai/backend/structured_outputs.html#OpenAI-Compatible-API) for more details.\n" ] }, { @@ -362,7 +245,7 @@ "import time\n", "from openai import OpenAI\n", "\n", - "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "requests = [\n", " {\n", @@ -465,7 +348,7 @@ "import time\n", "from openai import OpenAI\n", "\n", - "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "requests = []\n", "for i in range(100):\n", @@ -542,7 +425,7 @@ "from openai import OpenAI\n", "import os\n", "\n", - "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "requests = []\n", "for i in range(500):\n", diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 6d72aa55a3f..7e8f4ca0a54 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -1,13 +1,16 @@ # Server Arguments +## Common launch commands + - To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command. ``` python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2 ``` -- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. +- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. We recommend [SGLang Router](https://docs.sglang.ai/router/router.html) for data parallelism. ``` -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2 +python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2 ``` + - If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`. ``` python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7 @@ -31,3 +34,151 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct # Node 1 python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 1 ``` + +Please consult the documentation below to learn more about the parameters you may provide when launching a server. + + +## Model and tokenizer + +* `model_path`: Path to the model that will be served. +* `tokenizer_path`: Defaults to the `model_path`. +* `tokenizer_mode`: By default `auto`, see [here](https://huggingface.co/docs/transformers/en/main_classes/tokenizer) for the different modes. +* `load_format`: The format the weights are loaded in. Defaults to `*.safetensors`/`*.bin`.
+* `trust_remote_code`: If `True`, will use locally cached config files, otherwise use remote configs from HuggingFace. +* `dtype`: Dtype used for the model, defaults to `bfloat16`. +* `kv_cache_dtype`: Dtype of the kv cache, defaults to the `dtype`. +* `context_length`: The number of tokens our model can process *including the input*. Note that extending the default might lead to strange behavior. +* `device`: The device to put the model on, defaults to `cuda`. +* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.html#Chat-Template). +* `is_embedding`: Set to true to perform [embedding](https://docs.sglang.ai/backend/openai_api_embeddings.html) / [encode](https://docs.sglang.ai/backend/native_api.html#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api.html#Classify-(reward-model)) tasks. +* `revision`: Adjust if a specific version of the model should be used. +* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. +* `json_model_override_args`: Override model config with the provided JSON. +* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model. + +## Serving: HTTP & API + +### HTTP Server configuration + +* `port` and `host`: Set up the host and port for the HTTP server. By default `host: str = "127.0.0.1"` and `port: int = 30000` + +### API configuration + +* `api_key`: Sets an API key for the server and the OpenAI-compatible API. +* `file_storage_pth`: Directory for storing uploaded or generated files from API calls. +* `enable_cache_report`: If set, includes detailed usage of cached tokens in the response usage. + +## Parallelism + +### Tensor parallelism + +* `tp_size`: The number of GPUs the model weights get sharded over. Mainly for saving memory rather than for high throughput, see [this blogpost](https://pytorch.org/tutorials/intermediate/TP_tutorial.html#how-tensor-parallel-works). + +### Data parallelism + +* `dp_size`: Will be deprecated. The number of data-parallel copies of the model. [SGLang router](https://docs.sglang.ai/router/router.html) is recommended instead of the current naive data parallel. +* `load_balance_method`: Will be deprecated. Load balancing strategy for data parallel requests. + +### Expert parallelism + +* `ep_size`: Distribute the experts onto multiple GPUs for MoE models. Remember to shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). + +## Memory and scheduling + +* `mem_fraction_static`: Fraction of the free GPU memory used for static memory like model weights and KV cache. If building KV cache fails, it should be increased. If CUDA runs out of memory, it should be decreased. +* `max_running_requests`: The maximum number of requests to run concurrently. +* `max_total_tokens`: The maximum number of tokens that can be stored into the KV cache. Use mainly for debugging. +* `chunked_prefill_size`: Perform the prefill in chunks of this size. A larger chunk size speeds up the prefill phase but increases the VRAM consumption. If CUDA runs out of memory, it should be decreased. +* `max_prefill_tokens`: Token budget of how many tokens to accept in one prefill batch. The actual number is the max of this parameter and the `context_length`.
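As a quick, hedged illustration of how a few of the options above fit together: the structured-outputs examples later in this PR show that the offline `sgl.Engine` accepts these same fields as keyword arguments, so a minimal sketch (model path and values are placeholders, not part of this PR) could look like the following.

```python
# Minimal sketch (assumption: sgl.Engine forwards these keyword arguments to the
# same ServerArgs fields documented above). Model path and values are placeholders.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model
    dtype="bfloat16",             # model dtype (default)
    context_length=4096,          # tokens the model can process, including the input
    mem_fraction_static=0.7,      # lower this if CUDA runs out of memory
    max_running_requests=32,      # cap on concurrently running requests
)

print(llm.generate(["The capital of France is"], {"max_new_tokens": 8})[0]["text"])
llm.shutdown()
```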
+* `schedule_policy`: The scheduling policy to control the processing order of waiting prefill requests in a single engine. +* `schedule_conservativeness`: Can be used to decrease/increase the conservativeness of the server when taking new requests. Highly conservative behavior leads to starvation, but low conservativeness leads to slowed-down performance. +* `cpu_offload_gb`: Reserve this amount of RAM in GB for offloading of model parameters to the CPU. +* `prefill_only_one_req`: When this flag is turned on, the engine prefills only one request at a time. + +## Other runtime options + +* `stream_interval`: Interval (in tokens) for streaming responses. Smaller values lead to smoother streaming, and larger values lead to better throughput. +* `random_seed`: Can be used to enforce more deterministic behavior. +* `watchdog_timeout`: Adjusts the watchdog thread’s timeout before killing the server if batch generation takes too long. +* `download_dir`: Use to override the default Hugging Face cache directory for model weights. +* `base_gpu_id`: Use to adjust the first GPU used when distributing the model across the available GPUs. +* `allow_auto_truncate`: Automatically truncate requests that exceed the maximum input length. + +## Logging + +* `log_level`: Global log verbosity. +* `log_level_http`: Separate verbosity level for the HTTP server logs (if unset, defaults to `log_level`). +* `log_requests`: Logs the inputs and outputs of all requests for debugging. +* `show_time_cost`: Prints or logs detailed timing info for internal operations (helpful for performance tuning). +* `enable_metrics`: Exports Prometheus-like metrics for request usage and performance. +* `decode_log_interval`: How often (in tokens) to log decode progress. + +## Multi-node distributed serving + +* `dist_init_addr`: The TCP address used for initializing PyTorch’s distributed backend (e.g. `192.168.0.2:25000`). +* `nnodes`: Total number of nodes in the cluster. Refer to how to run the [Llama 405B model](https://docs.sglang.ai/references/llama_405B.html#run-405b-fp16-on-two-nodes). +* `node_rank`: Rank (ID) of this node among the `nnodes` in the distributed setup. + + +## LoRA + +* `lora_paths`: You may provide a list of adapters for your model. Each batch element gets a model response with the corresponding LoRA adapter applied. Currently `cuda_graph` and `radix_attention` are not supported with this option, so you need to disable them manually. We are still working through these [issues](https://github.com/sgl-project/sglang/issues/2929). +* `max_loras_per_batch`: Maximum number of LoRAs in a running batch, including the base model. + +## Kernel backend + +* `attention_backend`: The backend for attention computation and KV cache management. +* `sampling_backend`: The backend for sampling. + +## Constrained Decoding + +* `grammar_backend`: The grammar backend for constrained decoding. Detailed usage can be found in this [document](https://docs.sglang.ai/backend/structured_outputs.html). +* `constrained_json_whitespace_pattern`: Use with the `Outlines` grammar backend to allow JSON with syntactic newlines, tabs or multiple spaces. Details can be found [here](https://dottxt-ai.github.io/outlines/latest/reference/generation/json/#using-pydantic). + +## Speculative decoding + +* `speculative_draft_model_path`: The draft model path for speculative decoding. +* `speculative_algorithm`: The algorithm for speculative decoding. Currently only [Eagle](https://arxiv.org/html/2406.16858v1) is supported.
Note that the radix cache, chunked prefill, and overlap scheduler are disabled when using EAGLE speculative decoding. +* `speculative_num_steps`: How many draft passes we run before verifying. +* `speculative_num_draft_tokens`: The number of tokens proposed in a draft. +* `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1). + + +## Double Sparsity + +* `enable_double_sparsity`: Enables [double sparsity](https://arxiv.org/html/2408.07092v2), which increases throughput. +* `ds_channel_config_path`: The double sparsity config. For a guide on how to generate the config for your model, see [this repo](https://github.com/andy-yang-1/DoubleSparse/tree/main/config). +* `ds_heavy_channel_num`: Number of channel indices to keep for each layer. +* `ds_heavy_token_num`: Number of tokens used for attention during decode. Skip sparse decoding if `min_seq_len` in the batch < this number. +* `ds_heavy_channel_type`: The type of heavy channels. Either `q`, `k` or `qk`. +* `ds_sparse_decode_threshold`: Don't apply sparse decoding if `max_seq_len` in the batch < this threshold. + +## Debug options + +*Note: For the best possible performance, we recommend staying with the defaults and using these options only for debugging.* + +* `disable_radix_cache`: Disable the [Radix](https://lmsys.org/blog/2024-01-17-sglang/) backend for prefix caching. +* `disable_jump_forward`: Disable [jump-forward](https://lmsys.org/blog/2024-02-05-compressed-fsm/#our-method-jump-forward-decoding-with-a-compressed-finite-state-machine) for the outlines grammar backend. +* `disable_cuda_graph`: Disable [cuda graph](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for the model forward pass. +* `disable_cuda_graph_padding`: Disable cuda graph when padding is needed; otherwise, still use cuda graph. +* `disable_outlines_disk_cache`: Disable the disk cache for the outlines grammar backend. +* `disable_custom_all_reduce`: Disable usage of the custom all-reduce kernel. +* `disable_mla`: Disable [Multi-Head Latent Attention](https://arxiv.org/html/2405.04434v5) for DeepSeek models. +* `disable_overlap_schedule`: Disable the [Overhead-Scheduler](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#zero-overhead-batch-scheduler). +* `enable_nan_detection`: Turning this on makes the sampler print a warning if the logits contain `NaN`. +* `enable_p2p_check`: Enables a P2P (peer-to-peer) access check when accessing the GPU, instead of always assuming that P2P access is allowed (the default). +* `triton_attention_reduce_in_fp32`: In triton kernels this will cast the intermediate attention result to `float32`. + +## Optimization + +*Note: Some of these options are still in an experimental stage.* + +* `enable_mixed_chunk`: Enables mixing prefill and decode, see [this discussion](https://github.com/sgl-project/sglang/discussions/1163). +* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for DeepSeek models. Note that you need to choose `dp_size = tp_size` for this. +* `enable_ep_moe`: Enables expert parallelism, see the description of `ep_size`. +* `enable_torch_compile`: Compile the model with `torch.compile`. This is an experimental feature. +* `torch_compile_max_bs`: The maximum batch size when using `torch_compile`. +* `cuda_graph_max_bs`: Adjust the maximum batch size when using cuda graph. By default this is chosen for you based on GPU specifics. +* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
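The experimental optimization flags above can be combined in the same hedged way; the sketch below (illustrative only, not part of this PR; model path and values are placeholders) enables `torch.compile` and caps the batch sizes used for compilation and CUDA graph capture.

```python
# Illustrative sketch (not from this PR): combining the experimental optimization
# flags documented above through the offline Engine API. Values are placeholders.
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model
    enable_torch_compile=True,   # experimental: compile the model with torch.compile
    torch_compile_max_bs=8,      # maximum batch size handled by torch.compile
    cuda_graph_max_bs=8,         # maximum batch size captured by CUDA graphs
)

print(llm.generate(["Hello, my name is"], {"max_new_tokens": 8})[0]["text"])
llm.shutdown()
```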
+* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. +* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8. diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb new file mode 100644 index 00000000000..d69436eed17 --- /dev/null +++ b/docs/backend/speculative_decoding.ipynb @@ -0,0 +1,181 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Speculative Decoding\n", + "\n", + "SGLang now provides an EAGLE-based speculative decoding option. The implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n", + "\n", + "To run the following tests or benchmarks, you also need to install [**cutex**](https://pypi.org/project/cutex/): \n", + "> ```bash\n", + "> pip install cutex\n", + "> ```\n", + "\n", + "### Performance Highlights\n", + "\n", + "- Official EAGLE code ([SafeAILab/EAGLE](https://github.com/SafeAILab/EAGLE)): ~200 tokens/s\n", + "- Standard SGLang Decoding: ~156 tokens/s\n", + "- EAGLE Decoding in SGLang: ~297 tokens/s\n", + "- EAGLE Decoding in SGLang (w/ `torch.compile`): ~316 tokens/s\n", + "\n", + "All benchmarks below were run on a single H100." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## EAGLE Decoding\n", + "\n", + "To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# EAGLE decoding\n", + "from sglang.utils import (\n", + " execute_shell_command,\n", + " wait_for_server,\n", + " terminate_process,\n", + " print_highlight,\n", + ")\n", + "\n", + "server_process = execute_shell_command(\n", + " \"\"\"\n", + "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", + " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", + " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 --port=30020\n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(\"http://localhost:30020\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", + "\n", + "response = client.chat.completions.create(\n", + " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n", + " ],\n", + " temperature=0,\n", + " max_tokens=64,\n", + ")\n", + "\n", + "print_highlight(f\"Response: {response}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### EAGLE Decoding with `torch.compile`\n", + "\n", + "You can also enable `torch.compile` for further optimizations and optionally set `--cuda-graph-max-bs`:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "server_process = execute_shell_command(\n", + " \"\"\"\n", + "python3 -m sglang.launch_server 
--model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", + " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", + " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 \\\n", + " --enable-torch-compile --cuda-graph-max-bs 2 --port=30020\n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(\"http://localhost:30020\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Benchmark Script\n", + "\n", + "The following code example shows how to measure the decoding speed when generating tokens:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import requests\n", + "\n", + "tic = time.time()\n", + "response = requests.post(\n", + " \"http://localhost:30020/generate\",\n", + " json={\n", + " \"text\": \"[INST] Give me a simple FastAPI server. Show the python code. [/INST]\",\n", + " \"sampling_params\": {\n", + " \"temperature\": 0,\n", + " \"max_new_tokens\": 256,\n", + " },\n", + " },\n", + ")\n", + "latency = time.time() - tic\n", + "ret = response.json()\n", + "completion_text = ret[\"text\"]\n", + "speed = ret[\"meta_info\"][\"completion_tokens\"] / latency\n", + "\n", + "print_highlight(completion_text)\n", + "print_highlight(f\"speed: {speed:.2f} token/s\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/backend/structured_outputs.ipynb b/docs/backend/structured_outputs.ipynb index 55ca0b627f9..e413743ccfd 100644 --- a/docs/backend/structured_outputs.ipynb +++ b/docs/backend/structured_outputs.ipynb @@ -16,11 +16,13 @@ "SGLang supports two grammar backends:\n", "\n", "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n", - "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints and currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md).\n", + "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n", "\n", - "We suggest using XGrammar whenever possible for its better performance. For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n", + "We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n", "\n", - "To use Xgrammar, simply add `--grammar-backend` xgrammar when launching the server. If no backend is specified, Outlines will be used as the default." + "To use Xgrammar, simply add `--grammar-backend` xgrammar when launching the server. 
If no backend is specified, Outlines will be used as the default.\n", + "\n", + "For better output quality, **It's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify, 'Please generate the output in the following JSON format: ...'.\n" ] }, { @@ -92,7 +94,7 @@ " messages=[\n", " {\n", " \"role\": \"user\",\n", - " \"content\": \"Give me the information of the capital of France in the JSON format.\",\n", + " \"content\": \"Please generate the information of the capital of France in the JSON format.\",\n", " },\n", " ],\n", " temperature=0,\n", @@ -196,20 +198,6 @@ "print_highlight(response.choices[0].message.content)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "terminate_process(server_process)\n", - "server_process = execute_shell_command(\n", - " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\"\n", - ")\n", - "\n", - "wait_for_server(\"http://localhost:30000\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -236,15 +224,6 @@ "print_highlight(response.choices[0].message.content)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "terminate_process(server_process)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -252,21 +231,6 @@ "## Native API and SGLang Runtime (SRT)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "server_process = execute_shell_command(\n", - " \"\"\"\n", - "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010 --grammar-backend xgrammar\n", - "\"\"\"\n", - ")\n", - "\n", - "wait_for_server(\"http://localhost:30010\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -300,7 +264,7 @@ "\n", "# Make API request\n", "response = requests.post(\n", - " \"http://localhost:30010/generate\",\n", + " \"http://localhost:30000/generate\",\n", " json={\n", " \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n", " \"sampling_params\": {\n", @@ -345,7 +309,7 @@ "\n", "# JSON\n", "response = requests.post(\n", - " \"http://localhost:30010/generate\",\n", + " \"http://localhost:30000/generate\",\n", " json={\n", " \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n", " \"sampling_params\": {\n", @@ -375,7 +339,7 @@ "import requests\n", "\n", "response = requests.post(\n", - " \"http://localhost:30010/generate\",\n", + " \"http://localhost:30000/generate\",\n", " json={\n", " \"text\": \"Give me the information of the capital of France.\",\n", " \"sampling_params\": {\n", @@ -398,22 +362,6 @@ "print_highlight(response.json())" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "terminate_process(server_process)\n", - "server_process = execute_shell_command(\n", - " \"\"\"\n", - "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010\n", - "\"\"\"\n", - ")\n", - "\n", - "wait_for_server(\"http://localhost:30010\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -428,7 +376,7 @@ "outputs": [], "source": [ "response = requests.post(\n", - " \"http://localhost:30010/generate\",\n", + " \"http://localhost:30000/generate\",\n", " json={\n", " \"text\": \"Paris is the capital of\",\n", " \"sampling_params\": {\n", @@ -465,7 
+413,7 @@ "source": [ "import sglang as sgl\n", "\n", - "llm_xgrammar = sgl.Engine(\n", + "llm = sgl.Engine(\n", " model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", grammar_backend=\"xgrammar\"\n", ")" ] @@ -513,7 +461,7 @@ " \"json_schema\": json.dumps(CapitalInfo.model_json_schema()),\n", "}\n", "\n", - "outputs = llm_xgrammar.generate(prompts, sampling_params)\n", + "outputs = llm.generate(prompts, sampling_params)\n", "for prompt, output in zip(prompts, outputs):\n", " print_highlight(\"===============================\")\n", " print_highlight(f\"Prompt: {prompt}\") # validate the output by the pydantic model\n", @@ -553,7 +501,7 @@ "\n", "sampling_params = {\"temperature\": 0.1, \"top_p\": 0.95, \"json_schema\": json_schema}\n", "\n", - "outputs = llm_xgrammar.generate(prompts, sampling_params)\n", + "outputs = llm.generate(prompts, sampling_params)\n", "for prompt, output in zip(prompts, outputs):\n", " print_highlight(\"===============================\")\n", " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" @@ -590,22 +538,12 @@ " ),\n", "}\n", "\n", - "outputs = llm_xgrammar.generate(prompts, sampling_params)\n", + "outputs = llm.generate(prompts, sampling_params)\n", "for prompt, output in zip(prompts, outputs):\n", " print_highlight(\"===============================\")\n", " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "llm_xgrammar.shutdown()\n", - "llm_outlines = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -626,7 +564,7 @@ "\n", "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"regex\": \"(France|England)\"}\n", "\n", - "outputs = llm_outlines.generate(prompts, sampling_params)\n", + "outputs = llm.generate(prompts, sampling_params)\n", "for prompt, output in zip(prompts, outputs):\n", " print_highlight(\"===============================\")\n", " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" @@ -638,7 +576,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm_outlines.shutdown()" + "llm.shutdown()" ] } ], diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index edc03d66183..779c413977c 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.1.post6-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2-rocm620 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.1.post6-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2-rocm620 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/index.rst b/docs/index.rst index 51796d4a107..aaa46384490 100644 
--- a/docs/index.rst +++ b/docs/index.rst @@ -29,6 +29,8 @@ The core features include: backend/native_api.ipynb backend/offline_engine_api.ipynb backend/structured_outputs.ipynb + backend/speculative_decoding.ipynb + backend/function_calling.ipynb backend/server_arguments.md diff --git a/docs/references/benchmark_and_profiling.md b/docs/references/benchmark_and_profiling.md index 87ac5177424..0600b192b4f 100644 --- a/docs/references/benchmark_and_profiling.md +++ b/docs/references/benchmark_and_profiling.md @@ -64,16 +64,31 @@ with nvtx.annotate("description", color="color"): ```bash # set trace path export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log + # start server python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct -python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile +# send profiling request from client +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile ``` - -Traces can be visualized using https://ui.perfetto.dev/. +Please make sure that `SGLANG_TORCH_PROFILER_DIR` is set on both the server and client side, otherwise the trace file cannot be generated correctly. A reliable way is to set `SGLANG_TORCH_PROFILER_DIR` in the `.*rc` file of your shell (e.g. `~/.bashrc` for bash shells). - To profile offline ```bash export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8 ``` + +- View Traces + +Trace files can be loaded and visualized from: +1. https://ui.perfetto.dev/ (any browser) +2. chrome://tracing (Chrome browser only) + +If the browser cannot open a trace file due to its large size, +the client can generate a small trace file (<100MB) by limiting the number of prompts and the lengths of the prompt outputs. +For example, when profiling a server, +```bash +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile +``` +the command sets the number of prompts to 2 with the `--num-prompts` argument and limits the length of the output sequences to 100 with the `--sharegpt-output-len` argument, which produces a trace file small enough for the browser to open smoothly. diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 913395357e1..2bdceb90478 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -34,6 +34,10 @@ Overall, with these optimizations, we have achieved up to a 7x acceleration in o **Usage**: This optimization is aimed at improving throughput and should be used for scenarios with high QPS (Queries Per Second). Data Parallelism Attention optimization can be enabeld by `--enable-dp-attention` for DeepSeek Series Models.

+ *Figure: Data Parallelism Attention Performance Comparison*
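To make the usage above concrete, here is a hedged sketch (not part of this PR; the model path, port, and the assumption of 8 GPUs are illustrative) that launches a DeepSeek-series model with DP attention enabled, using the same helper utilities the notebooks in this PR rely on:

```python
# Hedged sketch (assumptions: 8 GPUs, placeholder model path and port; not from this PR):
# launch a DeepSeek-series model with Data Parallelism Attention, then send one request.
import requests
from sglang.utils import execute_shell_command, wait_for_server, terminate_process

server_process = execute_shell_command(
    "python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite-Chat "
    "--tp 8 --dp 8 --enable-dp-attention --trust-remote-code --port 30000"
)
wait_for_server("http://localhost:30000")

response = requests.post(
    "http://localhost:30000/generate",
    json={"text": "Hello", "sampling_params": {"max_new_tokens": 8}},
)
print(response.json()["text"])

terminate_process(server_process)
```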

+ **Reference**: Check [Blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models). ## Multi Node Tensor Parallelism diff --git a/docs/references/sampling_params.md b/docs/references/sampling_params.md index 5dad3fd1259..77d7c9f82e7 100644 --- a/docs/references/sampling_params.md +++ b/docs/references/sampling_params.md @@ -32,6 +32,20 @@ class GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output. stream: bool = False + # Whether to log metrics for this request (e.g. health_generate calls do not log metrics) + log_metrics: bool = True + + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None + # LoRA related + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + + # Session info for continual prompting + session_params: Optional[Union[List[Dict], Dict]] = None + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. + custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None ``` The `sampling_params` follows this format @@ -90,6 +104,14 @@ repetition_penalty: float = 1.0, # difficult to infer the correct token ID by given `stop` strings. # Must be 0 <= value < max_new_tokens. Setting to 0 (default) will disable this penalty. min_new_tokens: int = 0, + + +## Custom Parameters for Custom Logit Processor. +# A dictionary of custom parameters for the custom logit processor. +# The custom logit processor takes a list of dictionaries as input, where each +# dictionary is the custom parameters for one token in a batch of the input. +# See also python/sglang/srt/sampling/custom_logit_processor.py +custom_params: Optional[Dict[str, Any]] = None, ``` ## Examples @@ -189,7 +211,7 @@ You can specify a JSON schema, regular expression or [EBNF](https://en.wikipedia SGLang supports two grammar backends: - [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints. -- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints. +- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints. - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) Initialize the XGrammar backend using `--grammar-backend xgrammar` flag @@ -253,3 +275,49 @@ response = requests.post( ) print(response.json()) ``` +### Custom Logit Processor +Launch a server with `--enable-custom-logit-processor` flag on. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --enable-custom-logit-processor +``` + +Define a custom logit processor that will always sample a specific token id. +```python +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor + +class DeterministicLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits to always + sample the given token id. 
+ """ + + def __call__(self, logits, custom_param_list): + # Check that the number of logits matches the number of custom parameters + assert logits.shape[0] == len(custom_param_list) + key = "token_id" + + for i, param_dict in enumerate(custom_param_list): + # Mask all other tokens + logits[i, :] = -float("inf") + # Assign highest probability to the specified token + logits[i, param_dict[key]] = 0.0 + return logits +``` + +Send a request +```python +import requests + +response = requests.post( + "http://localhost:30000/generate", + json={ + "text": "The capital of France is", + "custom_logit_processor": DeterministicLogitProcessor().to_str(), + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": 32, + "custom_params": {"token_id": 5}, + }, + }, +) +print(response.json()) +``` diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 860841816e0..93c4273765d 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -28,6 +28,7 @@ - XVERSE / XVERSE MoE - SmolLM - GLM-4 +- Phi-3 / Phi-4 - Phi-3-Small - IBM Granite 3 @@ -77,10 +78,12 @@ Another valuable resource is the [vLLM Models Directory](https://github.com/vllm To port a model from vLLM to SGLang, you can compare these two files [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of Attention with RadixAttention. The other parts are almost identical. Specifically, - Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`. - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`. + - Replace Multi-headed `Attention` of ViT with SGLang's `VisionAttention`. - Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`). - Remove `Sample`. - Change `forward()` functions, and add `forward_batch`. - Add `EntryClass` at the end. + - Please ensure the new implementation uses **only SGLang components and does not rely on any vLLM components**. ### Registering an external model implementation @@ -90,7 +93,7 @@ Here is how you can do it: ```python from sglang.srt.models.registry import ModelRegistry -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server # for a single model, you can add it to the registry ModelRegistry.models[model_name] = model_class diff --git a/docs/start/install.md b/docs/start/install.md index 8b84527c4ff..90964ac6b6c 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -5,6 +5,7 @@ You can install SGLang using any of the methods below. 
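Whichever method below you choose, a quick import check afterwards confirms the package is installed; this small sketch is not part of the PR, and the expected version string is simply the one set in this release's `pyproject.toml`.

```python
# Post-install sanity check (assumes one of the installation methods below succeeded).
import sglang

print(sglang.__version__)  # expected to print 0.4.2 for this release
```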
## Method 1: With pip ``` pip install --upgrade pip +pip install sgl-kernel --force-reinstall --no-deps pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ ``` @@ -13,23 +14,25 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/ ## Method 2: From source ``` # Use the last release branch -git clone -b v0.4.1.post6 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip +pip install sgl-kernel --force-reinstall --no-deps pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ ``` -Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. +Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. If you run into an issue like **ImportError: cannot import name `_grouped_size_compiled_for_decode_kernels`**, installing an older FlashInfer version such as 0.1.6 instead of the latest one may resolve it. Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: ``` # Use the last release branch -git clone -b v0.4.1.post6 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip +pip install sgl-kernel --force-reinstall --no-deps pip install -e "python[all_hip]" ``` @@ -51,7 +54,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.4.1.post6 -t v0.4.1.post6-rocm620 -f Dockerfile.rocm . +docker build --build-arg SGL_BRANCH=v0.4.2 -t v0.4.2-rocm620 -f Dockerfile.rocm .
alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -60,11 +63,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.4.1.post6-rocm620 \ + v0.4.2-rocm620 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.4.1.post6-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.2-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/examples/frontend_language/usage/json_decode.py b/examples/frontend_language/usage/json_decode.py index ce8f5ba7062..5dc3522d512 100644 --- a/examples/frontend_language/usage/json_decode.py +++ b/examples/frontend_language/usage/json_decode.py @@ -9,7 +9,7 @@ from pydantic import BaseModel import sglang as sgl -from sglang.srt.constrained import build_regex_from_object +from sglang.srt.constrained.outlines_backend import build_regex_from_object character_regex = ( r"""\{\n""" diff --git a/examples/frontend_language/usage/triton/models/character_generation/1/model.py b/examples/frontend_language/usage/triton/models/character_generation/1/model.py index 5550e93984b..4bf86f1b691 100644 --- a/examples/frontend_language/usage/triton/models/character_generation/1/model.py +++ b/examples/frontend_language/usage/triton/models/character_generation/1/model.py @@ -3,8 +3,8 @@ from pydantic import BaseModel import sglang as sgl -from sglang import function, set_default_backend -from sglang.srt.constrained import build_regex_from_object +from sglang import function +from sglang.srt.constrained.outlines_backend import build_regex_from_object sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) diff --git a/examples/runtime/async_io_api.py b/examples/runtime/async_io_api.py deleted file mode 100644 index 23d3d0b90bf..00000000000 --- a/examples/runtime/async_io_api.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Usage: - -python3 async_io.py -""" - -import asyncio - -from sglang import Runtime - - -async def generate( - engine, - prompt, - sampling_params, -): - tokenizer = engine.get_tokenizer() - - messages = [ - { - "role": "system", - "content": "You will be given question answer tasks.", - }, - {"role": "user", "content": prompt}, - ] - - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - - stream = engine.add_request(prompt, sampling_params) - - async for output in stream: - print(output, end="", flush=True) - print() - - -if __name__ == "__main__": - runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf") - print("--- runtime ready ---\n") - - prompt = "Who is Alan Turing?" 
- sampling_params = {"max_new_tokens": 128} - asyncio.run(generate(runtime, prompt, sampling_params)) - - runtime.shutdown() diff --git a/examples/runtime/engine/offline_batch_inference.py b/examples/runtime/engine/offline_batch_inference.py index 724051eab53..92e68dcd72c 100644 --- a/examples/runtime/engine/offline_batch_inference.py +++ b/examples/runtime/engine/offline_batch_inference.py @@ -1,3 +1,8 @@ +""" +Usage: +python3 offline_batch_inference.py --model meta-llama/Llama-3.1-8B-Instruct +""" + import argparse import dataclasses diff --git a/python/pyproject.toml b/python/pyproject.toml index 379a4c9acf8..11c984f82d7 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.4.1.post6" +version = "0.4.2" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" @@ -23,11 +23,11 @@ runtime_common = [ "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop", - "xgrammar>=0.1.6" + "xgrammar>=0.1.10" ] srt = [ "sglang[runtime_common]", "cuda-python", - "sgl-kernel>=0.0.2.post14", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", + "sgl-kernel>=0.0.3", "torch", "vllm==0.6.4.post1", "flashinfer==0.1.6" ] diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index de9134857a6..70d58043d40 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -1,5 +1,6 @@ -# SGL API Components +# SGLang public APIs +# Frontend Language APIs from sglang.api import ( Engine, Runtime, @@ -23,16 +24,26 @@ user_end, video, ) +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.choices import ( greedy_token_selection, token_length_normalized, unconditional_likelihood_normalized, ) +from sglang.utils import LazyImport + +Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") +LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") +OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") +VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") + +# Other configs +from sglang.global_config import global_config +from sglang.version import __version__ -# SGLang DSL APIs __all__ = [ - "Runtime", "Engine", + "Runtime", "assistant", "assistant_begin", "assistant_end", @@ -52,27 +63,14 @@ "user_begin", "user_end", "video", + "RuntimeEndpoint", "greedy_token_selection", "token_length_normalized", "unconditional_likelihood_normalized", + "Anthropic", + "LiteLLM", + "OpenAI", + "VertexAI", + "global_config", + "__version__", ] - -# Global Configurations -from sglang.global_config import global_config - -__all__ += ["global_config"] - -from sglang.version import __version__ - -__all__ += ["__version__"] - -# SGLang Backends -from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.utils import LazyImport - -Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") -LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") -OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") -VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") - -__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"] diff --git a/python/sglang/api.py b/python/sglang/api.py index 9a30ad492da..7ef306380a9 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -1,6 +1,5 @@ """Public APIs of the 
language.""" -import os import re from typing import Callable, List, Optional, Union @@ -33,19 +32,15 @@ def decorator(func): def Runtime(*args, **kwargs): - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - # Avoid importing unnecessary dependency - from sglang.srt.server import Runtime + from sglang.lang.backend.runtime_endpoint import Runtime return Runtime(*args, **kwargs) def Engine(*args, **kwargs): - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - # Avoid importing unnecessary dependency - from sglang.srt.server import Engine + from sglang.srt.entrypoints.engine import Engine return Engine(*args, **kwargs) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index 54b042c115d..9d56ff07c8b 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -27,7 +27,8 @@ sample_random_requests, set_ulimit, ) -from sglang.srt.server import Engine, Runtime +from sglang.lang.backend.runtime_endpoint import Runtime +from sglang.srt.entrypoints.engine import Engine from sglang.srt.server_args import ServerArgs @@ -48,12 +49,13 @@ class BenchArgs: gsp_system_prompt_len: int = 2048 gsp_question_len: int = 128 gsp_output_len: int = 256 + seed: int = 1 disable_ignore_eos: bool = False extra_request_body: Optional[str] = None - seed: int = 1 + apply_chat_template: bool = False + profile: bool = False skip_warmup: bool = False do_not_exit: bool = False - profile: bool = False @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -140,20 +142,31 @@ def add_cli_args(parser: argparse.ArgumentParser): default=BenchArgs.gsp_output_len, help="Target length in tokens for outputs in generated-shared-prefix dataset", ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--disable-ignore-eos", - type=bool, - default=BenchArgs.disable_ignore_eos, + action="store_true", help="Disable ignore EOS token", ) parser.add_argument( "--extra-request-body", metavar='{"key1": "value1", "key2": "value2"}', type=str, + default=BenchArgs.extra_request_body, help="Append given JSON object to the request payload. You can use this to specify" "additional generate params like sampling params.", ) - parser.add_argument("--seed", type=int, default=1, help="The random seed.") + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + ) parser.add_argument( "--skip-warmup", action="store_true", @@ -164,12 +177,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Do not exit the program. This is useful for nsys profile with --duration and --delay.", ) - parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. 
The endpoint must be launched with " - "SGLANG_TORCH_PROFILER_DIR to enable profiler.", - ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 99fba8be913..de846066e63 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -57,15 +57,21 @@ import torch.distributed as dist from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.entrypoints.engine import _set_envs_and_config from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.server import _set_envs_and_config from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers +from sglang.srt.utils import ( + configure_logger, + get_bool_env_var, + kill_process_tree, + set_gpu_proc_affinity, + suppress_other_loggers, +) @dataclasses.dataclass @@ -99,10 +105,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument("--correctness-test", action="store_true") parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len) parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + "--profile", action="store_true", help="Use Torch Profiler." ) parser.add_argument( "--profile-filename-prefix", @@ -232,6 +235,7 @@ def extend(reqs, model_runner): model_config=model_runner.model_config, enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE, + enable_custom_logit_processor=False, ) batch.prepare_for_extend() model_worker_batch = batch.get_model_worker_batch() @@ -380,6 +384,7 @@ def latency_test_run_once( parent_dir = os.path.dirname(os.path.abspath(profile_filename)) os.makedirs(parent_dir, exist_ok=True) profiler.export_chrome_trace(profile_filename) + rank_print(f"torch profiler chrome trace saved to {profile_filename}") # Record decode timing from 2nd output if output_len > 1: @@ -406,6 +411,10 @@ def latency_test( bench_args, tp_rank, ): + # Set CPU affinity + if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): + set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank) + # Configure the logger configure_logger(server_args, prefix=f" TP{tp_rank}") rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None @@ -450,7 +459,7 @@ def latency_test( il, ol, server_args.device, - bench_args.profile, + bench_args.profile if tp_rank == 0 else None, bench_args.profile_filename_prefix, ) if ret is not None: diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py index 01cc561e1ce..5f0759a7ce1 100644 --- a/python/sglang/bench_one_batch_server.py +++ b/python/sglang/bench_one_batch_server.py @@ -22,7 +22,7 @@ import numpy as np import requests -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 991b4ddcf1a..10ce965be74 100644 --- 
a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -453,6 +453,7 @@ def get_dataset(args, tokenizer): tokenizer=tokenizer, fixed_output_len=args.sharegpt_output_len, context_len=args.sharegpt_context_len, + apply_chat_template=args.apply_chat_template, ) elif args.dataset_name == "random": input_requests = sample_random_requests( @@ -517,6 +518,7 @@ class BenchmarkMetrics: median_e2e_latency_ms: float std_e2e_latency_ms: float p99_e2e_latency_ms: float + concurrency: float SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" @@ -562,6 +564,7 @@ def sample_sharegpt_requests( tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, context_len: Optional[int] = None, + apply_chat_template=False, ) -> List[Tuple[str, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -592,6 +595,15 @@ def sample_sharegpt_requests( # Tokenize the prompts and completions. prompt = dataset[i][0] + + if apply_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + prompt = prompt.replace(tokenizer.bos_token, "") + prompt_token_ids = tokenizer.encode(prompt) completion = dataset[i][1] completion_token_ids = tokenizer.encode(completion) @@ -600,7 +612,7 @@ def sample_sharegpt_requests( len(completion_token_ids) if fixed_output_len is None else fixed_output_len ) - if prompt_len < 1 or output_len < 1: + if prompt_len < 2 or output_len < 2: # Prune too short sequences. continue @@ -880,6 +892,7 @@ def calculate_metrics( median_e2e_latency_ms=np.median(e2e_latencies) * 1000, std_e2e_latency_ms=np.std(e2e_latencies) * 1000, p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000, + concurrency=np.sum(e2e_latencies) / dur_s, ) return metrics, output_lens @@ -1031,6 +1044,7 @@ async def limited_request_func(request_func_input, pbar): "Total token throughput (tok/s):", metrics.total_throughput ) ) + print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency)) print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) print( "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) @@ -1062,13 +1076,24 @@ async def limited_request_func(request_func_input, pbar): and metrics.output_throughput is not None ): result = { + # Arguments "backend": args.backend, "dataset_name": args.dataset_name, "request_rate": request_rate, "max_concurrency": max_concurrency, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + # Results + "duration": benchmark_duration, + "completed": metrics.completed, "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, "median_e2e_latency_ms": metrics.median_e2e_latency_ms, "std_e2e_latency_ms": metrics.std_e2e_latency_ms, @@ -1085,14 +1110,7 @@ async def limited_request_func(request_func_input, pbar): "median_itl_ms": metrics.median_itl_ms, "std_itl_ms": metrics.std_itl_ms, "p99_itl_ms": metrics.p99_itl_ms, - "input_throughput": 
metrics.input_throughput, - "output_throughput": metrics.output_throughput, - "sharegpt_output_len": args.sharegpt_output_len, - "random_input_len": args.random_input_len, - "random_output_len": args.random_output_len, - "random_range_ratio": args.random_range_ratio, - "duration": benchmark_duration, - "completed": metrics.completed, + "concurrency": metrics.concurrency, } else: print(f"Error running benchmark for request rate: {request_rate}") @@ -1112,36 +1130,16 @@ async def limited_request_func(request_func_input, pbar): with open(output_file_name, "a") as file: file.write(json.dumps(result) + "\n") - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "total_output_tokens_retokenized": metrics.total_output_retokenized, - "request_throughput": metrics.request_throughput, - "input_throughput": metrics.input_throughput, - "output_throughput": metrics.output_throughput, - "mean_ttft_ms": metrics.mean_ttft_ms, - "median_ttft_ms": metrics.median_ttft_ms, - "std_ttft_ms": metrics.std_ttft_ms, - "p99_ttft_ms": metrics.p99_ttft_ms, - "mean_tpot_ms": metrics.mean_tpot_ms, - "median_tpot_ms": metrics.median_tpot_ms, - "std_tpot_ms": metrics.std_tpot_ms, - "p99_tpot_ms": metrics.p99_tpot_ms, - "mean_itl_ms": metrics.mean_itl_ms, - "median_itl_ms": metrics.median_itl_ms, - "std_itl_ms": metrics.std_itl_ms, - "p99_itl_ms": metrics.p99_itl_ms, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, - "median_e2e_latency_ms": metrics.median_e2e_latency_ms, - } + result.update( + { + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + ) return result @@ -1422,7 +1420,6 @@ def set_ulimit(target_soft_limit=65535): "actual request rate may be lower than specified with --request-rate, " "if the server is not processing requests fast enough to keep up.", ) - parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--multi", action="store_true", @@ -1446,14 +1443,15 @@ def set_ulimit(target_soft_limit=65535): help="Disable streaming mode.", ) parser.add_argument( - "--disable-ignore-eos", + "--return-logprob", action="store_true", - help="Disable ignoring EOS.", + help="Return logprob.", ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( - "--return-logprob", + "--disable-ignore-eos", action="store_true", - help="Return logprob.", + help="Disable ignoring EOS.", ) parser.add_argument( "--extra-request-body", @@ -1462,6 +1460,11 @@ def set_ulimit(target_soft_limit=65535): help="Append given JSON object to the request payload. 
You can use this to specify" "additional generate params like sampling params.", ) + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) parser.add_argument( "--profile", action="store_true", diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index a0032591226..01f10b9f063 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -1,6 +1,11 @@ +import atexit import json +import multiprocessing import warnings -from typing import List, Optional +from typing import Dict, List, Optional, Union + +import aiohttp +import requests from sglang.global_config import global_config from sglang.lang.backend.base_backend import BaseBackend @@ -325,3 +330,171 @@ def _assert_success(self, res): def compute_normalized_prompt_logprobs(input_logprobs): values = [x[0] for x in input_logprobs if x[0]] return sum(values) / len(values) + + +class Runtime: + """ + A wrapper for the HTTP server. + This is used for launching the server in a python program without + using the commond line interface. + + It is mainly used for the frontend language. + You should use the Engine class if you want to do normal offline processing without the frontend language. + """ + + def __init__( + self, + log_level: str = "error", + *args, + **kwargs, + ): + """See the arguments in server_args.py::ServerArgs""" + # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run + # client code without installing SRT server and its dependency if they want. + from sglang.srt.entrypoints.http_server import launch_server + from sglang.srt.server_args import ServerArgs + from sglang.srt.utils import is_port_available + + self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) + + # Pre-allocate ports + for port in range(self.server_args.port, 40000): + if is_port_available(port): + break + self.server_args.port = port + + self.url = self.server_args.url() + self.generate_url = self.url + "/generate" + + # NOTE: We store pid instead of proc to fix some issues during __delete__ + self.pid = None + pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False) + + proc = multiprocessing.Process( + target=launch_server, + args=(self.server_args, pipe_writer), + ) + proc.start() + pipe_writer.close() + self.pid = proc.pid + + # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() + atexit.register(self.shutdown) + + # TODO: remove this pipe_writer mechanism and use `/health_generate` instead. + try: + init_state = pipe_reader.recv() + except EOFError: + init_state = "" + + if init_state != "ready": + self.shutdown() + raise RuntimeError( + "Initialization failed. Please see the error messages above." 
+ ) + + self.endpoint = RuntimeEndpoint(self.url) + + def shutdown(self): + from sglang.srt.utils import kill_process_tree + + if self.pid is not None: + kill_process_tree(self.pid) + self.pid = None + + def cache_prefix(self, prefix: str): + self.endpoint.cache_prefix(prefix) + + def get_tokenizer(self): + from sglang.srt.hf_transformers_utils import get_tokenizer + + return get_tokenizer( + self.server_args.tokenizer_path, + tokenizer_mode=self.server_args.tokenizer_mode, + trust_remote_code=self.server_args.trust_remote_code, + revision=self.server_args.revision, + ) + + async def async_generate( + self, + prompt: str, + sampling_params: Optional[Dict] = None, + ): + if self.server_args.skip_tokenizer_init: + json_data = { + "input_ids": prompt, + "sampling_params": sampling_params, + "stream": True, + } + else: + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "stream": True, + } + pos = 0 + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.post(self.generate_url, json=json_data) as response: + async for chunk, _ in response.content.iter_chunks(): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]\n\n": + break + data = json.loads(chunk[5:].strip("\n")) + if "text" in data: + cur = data["text"][pos:] + if cur: + yield cur + pos += len(cur) + else: + yield data + + add_request = async_generate + + def generate( + self, + prompt: Union[str, List[str]], + sampling_params: Optional[Dict] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + ): + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + "lora_path": lora_path, + } + assert not isinstance(lora_path, list) or len(lora_path) == len(prompt) + response = requests.post( + self.url + "/generate", + json=json_data, + ) + return json.dumps(response.json()) + + def encode( + self, + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + ): + json_data = {"text": prompt} + response = requests.post(self.url + "/encode", json=json_data) + return json.dumps(response.json()) + + async def get_server_info(self): + async with aiohttp.ClientSession() as session: + async with session.get(f"{self.url}/get_server_info") as response: + if response.status == 200: + return await response.json() + else: + error_data = await response.json() + raise RuntimeError( + f"Failed to get server info. 
{error_data['error']['message']}" + ) + + def __del__(self): + self.shutdown() diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 845e1e52dda..a2c91c561c2 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -354,6 +354,37 @@ def get_chat_template_by_model_path(model_path): ) +register_chat_template( + ChatTemplate( + name="deepseek-v3", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ( + "", + "", + ), + "user": ( + "<|User|>", + "", + ), + "assistant": ( + "<|Assistant|>", + "<|end▁of▁sentence|>", + ), + }, + stop_str=("<|end▁of▁sentence|>",), + ) +) + + +@register_chat_template_matching_function +def match_deepseek(model_path: str): + if ( + "deepseek-v3" in model_path.lower() or "deepseek-r1" in model_path.lower() + ) and "base" not in model_path.lower(): + return get_chat_template("deepseek-v3") + + @register_chat_template_matching_function def match_dbrx(model_path: str): if "dbrx" in model_path.lower() and "instruct" in model_path.lower(): diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 6b0c25711c6..caae7b0f6cc 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -3,7 +3,7 @@ import os import sys -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import prepare_server_args from sglang.srt.utils import kill_process_tree diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py deleted file mode 100644 index 138c2127e16..00000000000 --- a/python/sglang/launch_server_llavavid.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Launch the inference server for Llava-video model.""" - -import json -import sys - -from sglang.srt.server import launch_server, prepare_server_args - -if __name__ == "__main__": - server_args = prepare_server_args(sys.argv[1:]) - - model_override_args = {} - model_override_args["mm_spatial_pool_stride"] = 2 - model_override_args["architectures"] = ["LlavaVidForCausalLM"] - model_override_args["num_frames"] = 16 - model_override_args["model_type"] = "llavavid" - if model_override_args["num_frames"] == 32: - model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"} - model_override_args["max_sequence_length"] = 4096 * 2 - model_override_args["tokenizer_model_max_length"] = 4096 * 2 - model_override_args["model_max_length"] = 4096 * 2 - if "34b" in server_args.model_path.lower(): - model_override_args["image_token_index"] = 64002 - server_args.json_model_override_args = json.dumps(model_override_args) - - launch_server(server_args) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index f59f67605b3..3cb313b9133 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -3,6 +3,7 @@ import functools import importlib import logging +import os from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch @@ -11,12 +12,19 @@ from sglang.srt.utils import is_hpu logger = logging.getLogger(__name__) +use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True) if not is_hpu(): - try: - import sgl_kernel - except ImportError as e: - logger.warning("Failed to import from custom_ar with %r", e) + if use_vllm_custom_allreduce: + try: + import vllm._C + except ImportError as e: + logger.warning("Failed to import from vllm._C with %r", e) + else: + try: + import sgl_kernel + except 
ImportError as e: + logger.warning("Failed to import from custom_ar with %r", e) def hint_on_error(fn): @@ -48,43 +56,78 @@ def wrapper(*args, **kwargs): return wrapper -# custom ar -def init_custom_ar( - rank_id: int, - world_size: int, - rank_data_base: torch.Tensor, - buffers: List[int], - tmp_result_buffers: List[int], - barrier_in: List[int], - barrier_out: List[int], -) -> int: - return sgl_kernel.ops.init_custom_reduce( - rank_id, - world_size, - rank_data_base, - buffers, - tmp_result_buffers, - barrier_in, - barrier_out, - ) - - -def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: - sgl_kernel.ops.custom_reduce(fa, inp, out) - +if use_vllm_custom_allreduce: + # custom ar + def init_custom_ar( + ipc_tensors: List[torch.Tensor], + rank_data: torch.Tensor, + rank: int, + full_nvlink: bool, + ) -> int: + return torch.ops._C_custom_ar.init_custom_ar( + ipc_tensors, rank_data, rank, full_nvlink + ) -def dispose(fa: int) -> None: - sgl_kernel.ops.custom_dispose(fa) + def all_reduce( + fa: int, + inp: torch.Tensor, + out: torch.Tensor, + reg_buffer: int, + reg_buffer_sz_bytes: int, + ) -> None: + torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes) + + def dispose(fa: int) -> None: + torch.ops._C_custom_ar.dispose(fa) + + def meta_size() -> int: + return torch.ops._C_custom_ar.meta_size() + + def register_buffer(fa: int, ipc_tensors: List[int]) -> None: + return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) + + def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: + return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) + + def register_graph_buffers( + fa: int, handles: List[List[int]], offsets: List[List[int]] + ) -> None: + torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + +else: + # custom ar + def init_custom_ar( + rank_id: int, + world_size: int, + rank_data_base: torch.Tensor, + buffers: List[int], + tmp_result_buffers: List[int], + barrier_in: List[int], + barrier_out: List[int], + ) -> int: + return sgl_kernel.ops.init_custom_reduce( + rank_id, + world_size, + rank_data_base, + buffers, + tmp_result_buffers, + barrier_in, + barrier_out, + ) + def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: + sgl_kernel.ops.custom_reduce(fa, inp, out) -def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: - return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa) + def dispose(fa: int) -> None: + sgl_kernel.ops.custom_dispose(fa) + def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: + return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa) -def register_graph_buffers( - fa: int, handles: List[List[int]], offsets: List[List[int]] -) -> None: - sgl_kernel.ops.register_graph_buffers(fa, handles, offsets) + def register_graph_buffers( + fa: int, handles: List[List[int]], offsets: List[List[int]] + ) -> None: + sgl_kernel.ops.register_graph_buffers(fa, handles, offsets) # temporary fix for https://github.com/vllm-project/vllm/issues/5456 diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 2b2b341faeb..6cb35ab47c6 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum): GGUF = "gguf" BITSANDBYTES = "bitsandbytes" MISTRAL = "mistral" + LAYERED = "layered" @dataclass diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py deleted file mode 100644 index 
458d1925241..00000000000 --- a/python/sglang/srt/constrained/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# TODO(lmzheng): make this an optional dependency -from sglang.srt.constrained.outlines_backend import build_regex_from_object diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index 7c88229cf16..6f304ea171e 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -18,6 +18,8 @@ from threading import Event, Lock from typing import Any, Optional, Tuple +from sglang.srt.server_args import ServerArgs + @dataclass class CacheEntry: @@ -69,3 +71,22 @@ def get_future_value(self, key: Tuple[str, str]) -> Future: def reset(self): with self.cache_lock: self.cache.clear() + + +def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size): + if server_args.grammar_backend == "outlines": + from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend + + grammar_backend = OutlinesGrammarBackend( + tokenizer, + whitespace_pattern=server_args.constrained_json_whitespace_pattern, + allow_jump_forward=not server_args.disable_jump_forward, + ) + elif server_args.grammar_backend == "xgrammar": + from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend + + grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size) + else: + raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}") + + return grammar_backend diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index b0b2c31c2ac..c423a567eda 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -19,6 +19,7 @@ import torch from xgrammar import ( CompiledGrammar, + Grammar, GrammarCompiler, GrammarMatcher, TokenizerInfo, @@ -133,10 +134,13 @@ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}") return None elif key_type == "regex": - logger.warning( - "regex hasn't been supported by xgrammar yet. This is skipped." 
- ) - return None + try: + ctx = self.grammar_compiler.compile_grammar( + Grammar.from_regex(key_string) + ) + except RuntimeError as e: + logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") + return None else: raise ValueError(f"Invalid key_type: {key_type}") diff --git a/python/sglang/srt/distributed/__init__.py b/python/sglang/srt/distributed/__init__.py index db325cfabf5..12f802055c5 100644 --- a/python/sglang/srt/distributed/__init__.py +++ b/python/sglang/srt/distributed/__init__.py @@ -1,3 +1,3 @@ -from .communication_op import * -from .parallel_state import * -from .utils import * +from sglang.srt.distributed.communication_op import * +from sglang.srt.distributed.parallel_state import * +from sglang.srt.distributed.utils import * diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py index ddf3b8ef568..95600edfb41 100644 --- a/python/sglang/srt/distributed/communication_op.py +++ b/python/sglang/srt/distributed/communication_op.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py + from typing import Any, Dict, Optional, Union import torch diff --git a/python/sglang/srt/distributed/device_communicators/__init__.py b/python/sglang/srt/distributed/device_communicators/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py index ab4ee33fcfc..c902f314112 100644 --- a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py + """This file is a pure Python wrapper for the cudart library. It avoids the need to compile a separate shared library, and is convenient for use when we just need to call a few functions. 
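The `_custom_ops.py` change above switches between two custom all-reduce backends based on the `USE_VLLM_CUSTOM_ALLREDUCE` environment variable: when it is enabled (the default), the ops come from `vllm._C` through `torch.ops._C_custom_ar`, otherwise they come from `sgl_kernel`. A minimal sketch of that env-gated import pattern follows; the helper name `_load_custom_ar_backend` and the explicit string parsing of the flag are illustrative assumptions, not part of the patch, which reads the variable directly with `os.environ.get`.

import logging
import os
from typing import Optional

logger = logging.getLogger(__name__)

def _load_custom_ar_backend() -> Optional[str]:
    # Sketch: prefer vLLM's compiled custom all-reduce ops unless the flag is
    # explicitly disabled, then fall back to sgl_kernel (module names as in the patch).
    use_vllm = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", "1").lower() not in ("0", "false")
    if use_vllm:
        try:
            import vllm._C  # noqa: F401  -- registers the torch.ops._C_custom_ar namespace
            return "vllm"
        except ImportError as e:
            logger.warning("Failed to import vllm._C with %r", e)
    try:
        import sgl_kernel  # noqa: F401  -- provides sgl_kernel.ops.custom_reduce and friends
        return "sgl_kernel"
    except ImportError as e:
        logger.warning("Failed to import sgl_kernel with %r", e)
    return None

The same flag is what the `custom_all_reduce.py` diff below keys on when choosing buffer layouts, registration calls, and kernel entry points.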
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index ba9feb59d0c..faeac0bbae9 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py + import ctypes import logging import os @@ -6,7 +7,6 @@ from functools import wraps from typing import Callable, List, Optional, TypeVar, Union -import pynvml import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -20,9 +20,19 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as from sglang.srt.utils import cuda_device_count_stateless, is_cuda -try: - import sgl_kernel +logger = logging.getLogger(__name__) +if is_cuda(): + try: + import pynvml + except ImportError as e: + logger.warning("Failed to import pynvml with %r", e) + +try: + if ops.use_vllm_custom_allreduce: + ops.meta_size() + else: + import sgl_kernel custom_ar = True except Exception: # For AMD GPUs and CPUs @@ -175,9 +185,12 @@ def __init__( # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - assert is_cuda() + if is_cuda(): + assert is_cuda() - full_nvlink = is_full_nvlink(physical_device_ids) + full_nvlink = is_full_nvlink(physical_device_ids) + else: + full_nvlink = False if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" @@ -201,33 +214,58 @@ def __init__( self.world_size = world_size self.full_nvlink = full_nvlink - # From TensorRT-LLM getMaxRequiredWorkspaceSize - self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024] - - # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE; - self.barrier_max_size = 8 * (36 + 2) * 8 - - self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) - self.tmp_result_buffer_ptrs = self.create_shared_buffer(max_size, group=group) - self.rank_data_base = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=self.device - ) - self.barrier_in_ptrs = self.create_shared_buffer( - self.barrier_max_size, group=group - ) - self.barrier_out_ptrs = self.create_shared_buffer( - self.barrier_max_size, group=group - ) - - self._ptr = ops.init_custom_ar( - rank, - world_size, - self.rank_data_base, - self.buffer_ptrs, - self.tmp_result_buffer_ptrs, - self.barrier_in_ptrs, - self.barrier_out_ptrs, - ) + if ops.use_vllm_custom_allreduce: + # Buffers memory are owned by this Python class and passed to C++. + # Meta data composes of two parts: meta data for synchronization and a + # temporary buffer for storing intermediate allreduce results. + self.meta_ptrs = self.create_shared_buffer( + ops.meta_size() + max_size, group=group + ) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. 
Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) + self._ptr = ops.init_custom_ar( + self.meta_ptrs, self.rank_data, rank, self.full_nvlink + ) + ops.register_buffer(self._ptr, self.buffer_ptrs) + else: + # From TensorRT-LLM getMaxRequiredWorkspaceSize + self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024] + + # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE; + self.barrier_max_size = 8 * (36 + 2) * 8 + + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + self.tmp_result_buffer_ptrs = self.create_shared_buffer( + max_size, group=group + ) + self.rank_data_base = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) + self.barrier_in_ptrs = self.create_shared_buffer( + self.barrier_max_size, group=group + ) + self.barrier_out_ptrs = self.create_shared_buffer( + self.barrier_max_size, group=group + ) + + self._ptr = ops.init_custom_ar( + rank, + world_size, + self.rank_data_base, + self.buffer_ptrs, + self.tmp_result_buffer_ptrs, + self.barrier_in_ptrs, + self.barrier_out_ptrs, + ) self.disabled = False @staticmethod @@ -307,6 +345,11 @@ def should_custom_ar(self, inp: torch.Tensor): return False # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. + if ops.use_vllm_custom_allreduce: + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False + if self.world_size == 2: return ( inp_size < self.max_size @@ -326,6 +369,7 @@ def all_reduce( inp: torch.Tensor, *, out: torch.Tensor = None, + registered: bool = False, ): """Performs an out-of-place all reduce. @@ -335,7 +379,15 @@ def all_reduce( """ if out is None: out = torch.empty_like(inp) - ops.all_reduce(self._ptr, inp, out) + if ops.use_vllm_custom_allreduce: + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + ops.all_reduce( + self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size + ) + else: + ops.all_reduce(self._ptr, inp, out) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: @@ -345,21 +397,25 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: return None if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): - return self.all_reduce(input) + return self.all_reduce(input, registered=True) else: # If warm up, mimic the allocation pattern since custom # allreduce is out-of-place. 
return torch.empty_like(input) else: - return self.all_reduce(input) + return self.all_reduce(input, registered=False) def close(self): if not self.disabled and self._ptr: ops.dispose(self._ptr) - self.free_shared_buffer(self.buffer_ptrs) - self.free_shared_buffer(self.tmp_result_buffer_ptrs) - self.free_shared_buffer(self.barrier_in_ptrs) - self.free_shared_buffer(self.barrier_out_ptrs) + if ops.use_vllm_custom_allreduce: + self.free_shared_buffer(self.meta_ptrs) + self.free_shared_buffer(self.buffer_ptrs) + else: + self.free_shared_buffer(self.buffer_ptrs) + self.free_shared_buffer(self.tmp_result_buffer_ptrs) + self.free_shared_buffer(self.barrier_in_ptrs) + self.free_shared_buffer(self.barrier_out_ptrs) self._ptr = 0 def __del__(self): diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index d807dfd5ce5..4073491aa62 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py + import ctypes import json import logging @@ -7,7 +8,6 @@ import subprocess import sys import tempfile -from functools import lru_cache from itertools import product from typing import Dict, List, Optional, Sequence diff --git a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py index 72ef3889e01..722e494cf77 100644 --- a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py + import torch import torch.distributed as dist from torch.distributed import ProcessGroup diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index baee270da90..9f65939f6d9 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -1,8 +1,10 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py + import logging from contextlib import contextmanager from typing import Optional, Union +# ===================== import region ===================== import torch import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp @@ -143,6 +145,57 @@ def all_reduce( cudaStream_t(stream.cuda_stream), ) + def all_gather( + self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # 
otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def reduce_scatter( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return @@ -179,6 +232,32 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): cudaStream_t(stream.cuda_stream), ) + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = self.stream + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast( + sendbuff, + recvbuff, + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + @contextmanager def change_state( self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py index e72284f5117..afb47733476 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py # This file is a pure Python wrapper for the NCCL library. # The main purpose is to use NCCL combined with CUDA graph. 
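The `pynccl.py` additions above give the NCCL wrapper `all_gather`, `reduce_scatter`, and `broadcast` methods alongside the existing all-reduce/send/recv, and the `pynccl_wrapper.py` hunks that follow register the matching `ncclAllGather`, `ncclReduceScatter`, and `ncclBroadcast` ctypes bindings. A hedged usage sketch is below; it assumes the communicator class is `PyNcclCommunicator` (as in the upstream vLLM file this is adapted from) and that an instance has already been created for the current process group, which these hunks do not show.

import torch
from sglang.srt.distributed.device_communicators.pynccl import PyNcclCommunicator

def demo_collectives(comm: PyNcclCommunicator, world_size: int, rank: int) -> None:
    # Sketch only: `comm` is an already-initialized communicator bound to the
    # current CUDA device; tensors must live on comm.device or the wrapper asserts.
    device = comm.device
    n = 4

    # all_gather: each rank contributes n elements and receives world_size * n.
    inp = torch.full((n,), float(rank), device=device)
    gathered = torch.empty(world_size * n, device=device)
    comm.all_gather(gathered, inp)

    # reduce_scatter: each rank contributes world_size * n elements and receives
    # the summed slice of length n that belongs to it.
    contrib = torch.ones(world_size * n, device=device)
    reduced = torch.empty(n, device=device)
    comm.reduce_scatter(reduced, contrib)

    # broadcast: rank 0's buffer is replicated to every other rank in place.
    buf = torch.arange(n, dtype=torch.float32, device=device) if rank == 0 else torch.empty(n, device=device)
    comm.broadcast(buf, src=0)

The output-first argument order in the sketch mirrors the wrapper signatures added above (`all_gather(output_tensor, input_tensor)`, `reduce_scatter(output_tensor, input_tensor)`).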
@@ -57,7 +57,7 @@ def find_nccl_library() -> str: so_file = "librccl.so.1" else: raise ValueError("NCCL only supports CUDA and ROCm backends.") - logger.info("Found nccl from library %s", so_file) + logger.debug("Found nccl from library %s", so_file) return so_file @@ -187,6 +187,43 @@ class NCCLLibrary: cudaStream_t, ], ), + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllGather", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclReduceScatter", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); @@ -217,6 +254,23 @@ class NCCLLibrary: cudaStream_t, ], ), + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function( + "ncclBroadcast", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. 
# because Python object destruction can happen in random order, @@ -321,6 +375,46 @@ def ncclAllReduce( ) ) + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclReduceScatter"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclAllGather"]( + sendbuff, recvbuff, count, datatype, comm, stream + ) + ) + def ncclSend( self, sendbuff: buffer_type, @@ -347,6 +441,22 @@ def ncclRecv( self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) ) + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclBroadcast"]( + sendbuff, recvbuff, count, datatype, root, comm, stream + ) + ) + def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index 1afe6fca526..7a3b22e27a8 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -1,11 +1,9 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py -import ipaddress +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/shm_broadcast.py + import logging import os import pickle -import socket import time -import warnings from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory @@ -18,6 +16,8 @@ from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore +from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address + # SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60 SGLANG_RINGBUFFER_WARNING_INTERVAL = int( os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60") @@ -26,73 +26,6 @@ logger = logging.getLogger(__name__) -def get_ip() -> str: - # SGLANG_HOST_IP env can be ignore - host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "") - if host_ip: - return host_ip - - # IP is not set, try to get it from the network interface - - # try ipv4 - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable - return s.getsockname()[0] - except Exception: - pass - - # try ipv6 - try: - s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) - # Google's public DNS server, see - # https://developers.google.com/speed/public-dns/docs/using#addresses - 
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable - return s.getsockname()[0] - except Exception: - pass - - warnings.warn( - "Failed to get the IP address, using 0.0.0.0 by default." - "The value can be set by the environment variable" - " SGLANG_HOST_IP or HOST_IP.", - stacklevel=2, - ) - return "0.0.0.0" - - -def get_open_port() -> int: - - port = os.getenv("SGLANG_PORT") - if port is not None: - while True: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", port)) - return port - except OSError: - port += 1 # Increment port number if already in use - logger.info("Port %d is already in use, trying port %d", port - 1, port) - # try ipv4 - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - except OSError: - # try ipv6 - with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -def is_valid_ipv6_address(address: str) -> bool: - try: - ipaddress.IPv6Address(address) - return True - except ValueError: - return False - - class ShmRingBuffer: def __init__( @@ -313,7 +246,7 @@ def __init__( remote_subscribe_port=remote_subscribe_port, ) - logger.info("vLLM message queue communication handle: %s", self.handle) + logger.debug("Message queue communication handle: %s", self.handle) def export_handle(self) -> Handle: return self.handle diff --git a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py index ff0981b80bc..532279f70c3 100644 --- a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/xpu_communicator.py + import torch import torch.distributed as dist from torch.distributed import ProcessGroup diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 26d04b04ce9..c6d1a830781 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/parallel_state.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/parallel_state.py # Copyright 2023 The vLLM team. # Adapted from diff --git a/python/sglang/srt/distributed/utils.py b/python/sglang/srt/distributed/utils.py index a225fbb9182..e117aa30d07 100644 --- a/python/sglang/srt/distributed/utils.py +++ b/python/sglang/srt/distributed/utils.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/utils.py + # Copyright 2023 The vLLM team. 
# Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py new file mode 100644 index 00000000000..098a3d1e325 --- /dev/null +++ b/python/sglang/srt/entrypoints/engine.py @@ -0,0 +1,452 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +The entry point of inference server. (SRT = SGLang Runtime) + +This file implements python APIs for the inference engine. +""" + +import asyncio +import atexit +import dataclasses +import logging +import multiprocessing as mp +import os +import signal +import threading +from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union + +# Fix a bug of Python threading +setattr(threading, "_register_atexit", lambda *args, **kwargs: None) + +import torch +import uvloop + +from sglang.srt.managers.data_parallel_controller import ( + run_data_parallel_controller_process, +) +from sglang.srt.managers.detokenizer_manager import run_detokenizer_process +from sglang.srt.managers.io_struct import ( + EmbeddingReqInput, + GenerateReqInput, + GetWeightsByNameReqInput, + InitWeightsUpdateGroupReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, +) +from sglang.srt.managers.scheduler import run_scheduler_process +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils import ( + MultiprocessingSerializer, + assert_pkg_version, + configure_logger, + kill_process_tree, + launch_dummy_health_check_server, + maybe_set_triton_cache_manager, + prepare_model_and_tokenizer, + set_prometheus_multiproc_dir, + set_ulimit, +) +from sglang.version import __version__ + +logger = logging.getLogger(__name__) +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +class Engine: + """ + The entry point to the inference engine. + + - The engine consists of three components: + 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. + 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. + 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. + + Note: + 1. The HTTP server, Engine, and TokenizerManager both run in the main process. + 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. + """ + + def __init__(self, **kwargs): + """ + The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`. 
+ Please refer to `ServerArgs` for the documentation. + """ + if "server_args" in kwargs: + # Directly load server_args + server_args = kwargs["server_args"] + else: + # Construct server_args from kwargs + if "log_level" not in kwargs: + # Do not print logs by default + kwargs["log_level"] = "error" + server_args = ServerArgs(**kwargs) + + # Shutdown the subprocesses automatically when the program exists + atexit.register(self.shutdown) + + # Launch subprocesses + tokenizer_manager, scheduler_info = _launch_subprocesses( + server_args=server_args + ) + self.tokenizer_manager = tokenizer_manager + self.scheduler_info = scheduler_info + + def generate( + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + stream: bool = False, + ) -> Union[Dict, Iterator[Dict]]: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. + Please refer to `GenerateReqInput` for the documentation. + """ + obj = GenerateReqInput( + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + lora_path=lora_path, + custom_logit_processor=custom_logit_processor, + stream=stream, + ) + loop = asyncio.get_event_loop() + generator = self.tokenizer_manager.generate_request(obj, None) + + if stream: + + def generator_wrapper(): + while True: + try: + chunk = loop.run_until_complete(generator.__anext__()) + yield chunk + except StopAsyncIteration: + break + + return generator_wrapper() + else: + ret = loop.run_until_complete(generator.__anext__()) + return ret + + async def async_generate( + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + stream: bool = False, + ) -> Union[Dict, AsyncIterator[Dict]]: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. + Please refer to `GenerateReqInput` for the documentation. 
+ """ + obj = GenerateReqInput( + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + lora_path=lora_path, + stream=stream, + custom_logit_processor=custom_logit_processor, + ) + generator = self.tokenizer_manager.generate_request(obj, None) + + if stream is True: + return generator + else: + return await generator.__anext__() + + def encode( + self, + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + ) -> Dict: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. + Please refer to `EmbeddingReqInput` for the documentation. + """ + + obj = EmbeddingReqInput(text=prompt) + loop = asyncio.get_event_loop() + generator = self.tokenizer_manager.generate_request(obj, None) + ret = loop.run_until_complete(generator.__anext__()) + return ret + + def shutdown(self): + """Shutdown the engine""" + kill_process_tree(os.getpid(), include_parent=False) + + def start_profile(self): + self.tokenizer_manager.start_profile() + + def stop_profile(self): + self.tokenizer_manager.stop_profile() + + def get_server_info(self): + return { + **dataclasses.asdict(self.tokenizer_manager.server_args), # server args + **self.scheduler_info, + "version": __version__, + } + + def init_weights_update_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + ): + """Initialize parameter update group.""" + obj = InitWeightsUpdateGroupReqInput( + master_address=master_address, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, + group_name=group_name, + backend=backend, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.init_weights_update_group(obj, None) + ) + + def update_weights_from_distributed(self, name: str, dtype, shape): + """Update weights from distributed source.""" + obj = UpdateWeightsFromDistributedReqInput( + name=name, + dtype=dtype, + shape=shape, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_weights_from_distributed(obj, None) + ) + + def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): + """Update weights from distributed source.""" + obj = UpdateWeightsFromTensorReqInput( + serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors) + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_weights_from_tensor(obj, None) + ) + + def get_weights_by_name(self, name: str, truncate_size: int = 100): + """Get weights by parameter name.""" + obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.get_weights_by_name(obj, None) + ) + + def release_memory_occupation(self): + """Release GPU occupation temporarily.""" + obj = ReleaseMemoryOccupationReqInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.release_memory_occupation(obj, None) + ) + + def resume_memory_occupation(self): + """Resume GPU occupation.""" + obj = ResumeMemoryOccupationReqInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.resume_memory_occupation(obj, None) + ) + + +def _set_envs_and_config(server_args: ServerArgs): + # Set global 
environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + + # Set ulimit + set_ulimit() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. + maybe_set_triton_cache_manager() + + # Check flashinfer version + if server_args.attention_backend == "flashinfer": + assert_pkg_version( + "flashinfer", + "0.1.6", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) + + # Register the signal handler. + # The child processes will send SIGQUIT to this process when any error happens + # This process then clean up the whole process tree + def sigquit_handler(signum, frame): + logger.error( + "Received sigquit from a child proces. It usually means the child failed." + ) + kill_process_tree(os.getpid()) + + signal.signal(signal.SIGQUIT, sigquit_handler) + + # Set mp start method + mp.set_start_method("spawn", force=True) + + +def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]: + """ + Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. + """ + # Configure global environment + configure_logger(server_args) + server_args.check_server_args() + _set_envs_and_config(server_args) + + # Allocate ports for inter-process communications + port_args = PortArgs.init_new(server_args) + logger.info(f"{server_args=}") + + # If using model from www.modelscope.cn, first download the model. + server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( + server_args.model_path, server_args.tokenizer_path + ) + + scheduler_procs = [] + if server_args.dp_size == 1: + # Launch tensor parallel scheduler processes + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + + scheduler_pipe_readers = [] + tp_size_per_node = server_args.tp_size // server_args.nnodes + tp_rank_range = range( + tp_size_per_node * server_args.node_rank, + tp_size_per_node * (server_args.node_rank + 1), + ) + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node + proc = mp.Process( + target=run_scheduler_process, + args=(server_args, port_args, gpu_id, tp_rank, None, writer), + ) + with memory_saver_adapter.configure_subprocess(): + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + else: + # Launch the data parallel controller + reader, writer = mp.Pipe(duplex=False) + scheduler_pipe_readers = [reader] + proc = mp.Process( + target=run_data_parallel_controller_process, + args=(server_args, port_args, writer), + ) + proc.start() + scheduler_procs.append(proc) + + if server_args.node_rank >= 1: + # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, + # so they can just wait here. + + for reader in scheduler_pipe_readers: + data = reader.recv() + assert data["status"] == "ready" + + if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": + # When using `Engine` as a Python API, we don't want to block here. 
+ return None, None + + launch_dummy_health_check_server(server_args.host, server_args.port) + + for proc in scheduler_procs: + proc.join() + logger.error( + f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" + ) + return None, None + + # Launch detokenizer process + detoken_proc = mp.Process( + target=run_detokenizer_process, + args=( + server_args, + port_args, + ), + ) + detoken_proc.start() + + # Launch tokenizer process + tokenizer_manager = TokenizerManager(server_args, port_args) + if server_args.chat_template: + load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) + + # Wait for the model to finish loading + scheduler_infos = [] + for i in range(len(scheduler_pipe_readers)): + try: + data = scheduler_pipe_readers[i].recv() + except EOFError: + logger.error( + f"Rank {i} scheduler is dead. Please check if there are relevant logs." + ) + scheduler_procs[i].join() + logger.error(f"Exit code: {scheduler_procs[i].exitcode}") + raise + + if data["status"] != "ready": + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + scheduler_infos.append(data) + + # Assume all schedulers have the same scheduler_info + scheduler_info = scheduler_infos[0] + tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + return tokenizer_manager, scheduler_info diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py new file mode 100644 index 00000000000..36f8b6e1971 --- /dev/null +++ b/python/sglang/srt/entrypoints/http_server.py @@ -0,0 +1,624 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +The entry point of inference server. (SRT = SGLang Runtime) + +This file implements HTTP APIs for the inferenc engine via fastapi. 
+""" + +import asyncio +import dataclasses +import logging +import multiprocessing as multiprocessing +import os +import threading +import time +from http import HTTPStatus +from typing import AsyncIterator, Dict, Optional + +# Fix a bug of Python threading +setattr(threading, "_register_atexit", lambda *args, **kwargs: None) + +import orjson +import requests +import uvicorn +import uvloop +from fastapi import FastAPI, File, Form, Request, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import ORJSONResponse, Response, StreamingResponse + +from sglang.srt.entrypoints.engine import _launch_subprocesses +from sglang.srt.function_call_parser import FunctionCallParser +from sglang.srt.managers.io_struct import ( + CloseSessionReqInput, + ConfigureLoggingReq, + EmbeddingReqInput, + FunctionCallReqInput, + GenerateReqInput, + GetWeightsByNameReqInput, + InitWeightsUpdateGroupReqInput, + OpenSessionReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + UpdateWeightFromDiskReqInput, + UpdateWeightsFromDistributedReqInput, +) +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.metrics.func_timer import enable_func_timer +from sglang.srt.openai_api.adapter import ( + v1_batches, + v1_cancel_batch, + v1_chat_completions, + v1_completions, + v1_delete_file, + v1_embeddings, + v1_files_create, + v1_retrieve_batch, + v1_retrieve_file, + v1_retrieve_file_content, +) +from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import ( + add_api_key_middleware, + add_prometheus_middleware, + delete_directory, + kill_process_tree, + set_uvicorn_logging_configs, +) +from sglang.utils import get_exception_traceback +from sglang.version import __version__ + +logger = logging.getLogger(__name__) +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +# Fast API +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# Store global states +@dataclasses.dataclass +class _GlobalState: + tokenizer_manager: TokenizerManager + scheduler_info: Dict + + +_global_state: Optional[_GlobalState] = None + + +def set_global_state(global_state: _GlobalState): + global _global_state + _global_state = global_state + + +##### Native API endpoints ##### + + +@app.get("/health") +async def health() -> Response: + """Check the health of the http server.""" + return Response(status_code=200) + + +@app.get("/health_generate") +async def health_generate(request: Request) -> Response: + """Check the health of the inference server by generating one token.""" + + sampling_params = {"max_new_tokens": 1, "temperature": 0.7} + + if _global_state.tokenizer_manager.is_generation: + gri = GenerateReqInput( + input_ids=[0], sampling_params=sampling_params, log_metrics=False + ) + else: + gri = EmbeddingReqInput( + input_ids=[0], sampling_params=sampling_params, log_metrics=False + ) + + try: + async for _ in _global_state.tokenizer_manager.generate_request(gri, request): + break + return Response(status_code=200) + except Exception as e: + logger.exception(e) + return Response(status_code=503) + + +@app.get("/get_model_info") +async def get_model_info(): + """Get the model information.""" + result = { + "model_path": _global_state.tokenizer_manager.model_path, + "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path, + "is_generation": 
_global_state.tokenizer_manager.is_generation, + } + return result + + +@app.get("/get_server_info") +async def get_server_info(): + return { + **dataclasses.asdict(_global_state.tokenizer_manager.server_args), + **_global_state.scheduler_info, + "version": __version__, + } + + +# fastapi implicitly converts json in the request to obj (dataclass) +@app.api_route("/generate", methods=["POST", "PUT"]) +async def generate_request(obj: GenerateReqInput, request: Request): + """Handle a generate request.""" + if obj.stream: + + async def stream_results() -> AsyncIterator[bytes]: + try: + async for out in _global_state.tokenizer_manager.generate_request( + obj, request + ): + yield b"data: " + orjson.dumps( + out, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + except ValueError as e: + out = {"error": {"message": str(e)}} + yield b"data: " + orjson.dumps( + out, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + yield b"data: [DONE]\n\n" + + return StreamingResponse( + stream_results(), + media_type="text/event-stream", + background=_global_state.tokenizer_manager.create_abort_task(obj), + ) + else: + try: + ret = await _global_state.tokenizer_manager.generate_request( + obj, request + ).__anext__() + return ret + except ValueError as e: + logger.error(f"Error: {e}") + return _create_error_response(e) + + +@app.api_route("/encode", methods=["POST", "PUT"]) +async def encode_request(obj: EmbeddingReqInput, request: Request): + """Handle an embedding request.""" + try: + ret = await _global_state.tokenizer_manager.generate_request( + obj, request + ).__anext__() + return ret + except ValueError as e: + return _create_error_response(e) + + +@app.api_route("/classify", methods=["POST", "PUT"]) +async def classify_request(obj: EmbeddingReqInput, request: Request): + """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" + try: + ret = await _global_state.tokenizer_manager.generate_request( + obj, request + ).__anext__() + return ret + except ValueError as e: + return _create_error_response(e) + + +@app.post("/flush_cache") +async def flush_cache(): + """Flush the radix cache.""" + _global_state.tokenizer_manager.flush_cache() + return Response( + content="Cache flushed.\nPlease check backend logs for more details. " + "(When there are running or waiting requests, the operation will not be performed.)\n", + status_code=200, + ) + + +@app.api_route("/start_profile", methods=["GET", "POST"]) +async def start_profile_async(): + """Start profiling.""" + _global_state.tokenizer_manager.start_profile() + return Response( + content="Start profiling.\n", + status_code=200, + ) + + +@app.api_route("/stop_profile", methods=["GET", "POST"]) +async def stop_profile_async(): + """Stop profiling.""" + _global_state.tokenizer_manager.stop_profile() + return Response( + content="Stop profiling. 
This will take some time.\n", + status_code=200, + ) + + +@app.post("/update_weights_from_disk") +async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): + """Update the weights from disk in-place without re-launching the server.""" + success, message = await _global_state.tokenizer_manager.update_weights_from_disk( + obj, request + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse( + content, + status_code=HTTPStatus.OK, + ) + else: + return ORJSONResponse( + content, + status_code=HTTPStatus.BAD_REQUEST, + ) + + +@app.post("/init_weights_update_group") +async def init_weights_update_group( + obj: InitWeightsUpdateGroupReqInput, request: Request +): + """Initialize the parameter update group.""" + success, message = await _global_state.tokenizer_manager.init_weights_update_group( + obj, request + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + +@app.post("/update_weights_from_distributed") +async def update_weights_from_distributed( + obj: UpdateWeightsFromDistributedReqInput, request: Request +): + """Update model parameter from distributed online.""" + success, message = ( + await _global_state.tokenizer_manager.update_weights_from_distributed( + obj, request + ) + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + +@app.api_route("/get_weights_by_name", methods=["GET", "POST"]) +async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): + """Get model parameter by name.""" + try: + ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request) + if ret is None: + return _create_error_response("Get parameter by name failed") + else: + return ORJSONResponse(ret, status_code=200) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/release_memory_occupation", methods=["GET", "POST"]) +async def release_memory_occupation( + obj: ReleaseMemoryOccupationReqInput, request: Request +): + """Release GPU occupation temporarily""" + try: + await _global_state.tokenizer_manager.release_memory_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/resume_memory_occupation", methods=["GET", "POST"]) +async def resume_memory_occupation( + obj: ResumeMemoryOccupationReqInput, request: Request +): + """Resume GPU occupation""" + try: + await _global_state.tokenizer_manager.resume_memory_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/open_session", methods=["GET", "POST"]) +async def open_session(obj: OpenSessionReqInput, request: Request): + """Open a session, and return its unique session id.""" + try: + session_id = await _global_state.tokenizer_manager.open_session(obj, request) + if session_id is None: + raise Exception( + "Failed to open the session. Check if a session with the same id is still open." 
+ ) + return session_id + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/close_session", methods=["GET", "POST"]) +async def close_session(obj: CloseSessionReqInput, request: Request): + """Close the session""" + try: + await _global_state.tokenizer_manager.close_session(obj, request) + return Response(status_code=200) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/configure_logging", methods=["GET", "POST"]) +async def configure_logging(obj: ConfigureLoggingReq, request: Request): + """Close the session""" + _global_state.tokenizer_manager.configure_logging(obj) + return Response(status_code=200) + + +@app.post("/function_call") +async def function_call_request(obj: FunctionCallReqInput, request: Request): + """ + A native API endpoint to parse function calls from a text. + """ + # 1) Initialize the parser based on the request body + parser = FunctionCallParser(tools=obj.tools, tool_call_parser=obj.tool_call_parser) + + # 2) Call the non-stream parsing method (non-stream) + normal_text, calls = parser.parse_non_stream(obj.text) + + # 3) Organize the response content + response_data = { + "normal_text": normal_text, + "calls": [ + call.model_dump() for call in calls + ], # Convert pydantic objects to dictionaries + } + + return ORJSONResponse(content=response_data, status_code=200) + + +##### OpenAI-compatible API endpoints ##### + + +@app.post("/v1/completions") +async def openai_v1_completions(raw_request: Request): + return await v1_completions(_global_state.tokenizer_manager, raw_request) + + +@app.post("/v1/chat/completions") +async def openai_v1_chat_completions(raw_request: Request): + return await v1_chat_completions(_global_state.tokenizer_manager, raw_request) + + +@app.post("/v1/embeddings", response_class=ORJSONResponse) +async def openai_v1_embeddings(raw_request: Request): + response = await v1_embeddings(_global_state.tokenizer_manager, raw_request) + return response + + +@app.get("/v1/models", response_class=ORJSONResponse) +def available_models(): + """Show available models.""" + served_model_names = [_global_state.tokenizer_manager.served_model_name] + model_cards = [] + for served_model_name in served_model_names: + model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) + return ModelList(data=model_cards) + + +@app.post("/v1/files") +async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): + return await v1_files_create( + file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth + ) + + +@app.delete("/v1/files/{file_id}") +async def delete_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/delete + return await v1_delete_file(file_id) + + +@app.post("/v1/batches") +async def openai_v1_batches(raw_request: Request): + return await v1_batches(_global_state.tokenizer_manager, raw_request) + + +@app.post("/v1/batches/{batch_id}/cancel") +async def cancel_batches(batch_id: str): + # https://platform.openai.com/docs/api-reference/batch/cancel + return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id) + + +@app.get("/v1/batches/{batch_id}") +async def retrieve_batch(batch_id: str): + return await v1_retrieve_batch(batch_id) + + +@app.get("/v1/files/{file_id}") +async def retrieve_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve + return await v1_retrieve_file(file_id) + + +@app.get("/v1/files/{file_id}/content") +async def retrieve_file_content(file_id: str): + # 
https://platform.openai.com/docs/api-reference/files/retrieve-contents
+    return await v1_retrieve_file_content(file_id)
+
+
+def _create_error_response(e):
+    return ORJSONResponse(
+        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+    )
+
+
+def launch_server(
+    server_args: ServerArgs,
+    pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None,
+):
+    """
+    Launch SRT (SGLang Runtime) Server.
+
+    The SRT server consists of an HTTP server and an SRT engine.
+
+    - HTTP server: A FastAPI server that routes requests to the engine.
+    - The engine consists of three components:
+        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
+        2. Scheduler (subprocess): Receives requests from the TokenizerManager, schedules batches, forwards them, and sends the output tokens to the DetokenizerManager.
+        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the TokenizerManager.
+
+    Note:
+    1. The HTTP server, Engine, and TokenizerManager all run in the main process.
+    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
+    """
+    tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args)
+    set_global_state(
+        _GlobalState(
+            tokenizer_manager=tokenizer_manager,
+            scheduler_info=scheduler_info,
+        )
+    )
+
+    # Add api key authorization
+    if server_args.api_key:
+        add_api_key_middleware(app, server_args.api_key)
+
+    # Add prometheus middleware
+    if server_args.enable_metrics:
+        add_prometheus_middleware(app)
+        enable_func_timer()
+
+    # Send a warmup request
+    t = threading.Thread(
+        target=_wait_and_warmup,
+        args=(
+            server_args,
+            pipe_finish_writer,
+            _global_state.tokenizer_manager.image_token_id,
+        ),
+    )
+    t.start()
+
+    try:
+        # Update logging configs
+        set_uvicorn_logging_configs()
+
+        # Listen for HTTP requests
+        uvicorn.run(
+            app,
+            host=server_args.host,
+            port=server_args.port,
+            log_level=server_args.log_level_http or server_args.log_level,
+            timeout_keep_alive=5,
+            loop="uvloop",
+        )
+    finally:
+        t.join()
+
+
+def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
+    headers = {}
+    url = server_args.url()
+    if server_args.api_key:
+        headers["Authorization"] = f"Bearer {server_args.api_key}"
+
+    # Wait until the server is launched
+    success = False
+    for _ in range(120):
+        time.sleep(1)
+        try:
+            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            assert res.status_code == 200, f"{res=}, {res.text=}"
+            success = True
+            break
+        except (AssertionError, requests.exceptions.RequestException):
+            last_traceback = get_exception_traceback()
+            pass
+
+    if not success:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(last_traceback)
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_process_tree(os.getpid())
+        return
+
+    model_info = res.json()
+
+    # Send a warmup request
+    request_name = "/generate" if model_info["is_generation"] else "/encode"
+    max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        # json_data["text"] = "The capital city of France is"
+        target_length = int(os.getenv("SRT_WARMUP_PASSKEY_LENGTH", "35000"))
+        json_data["text"] = (
+            "You need to find the passkey. 
Read carefully following text, and remember the passkey\n\n" + ) + filler = "Sky is blue, grass is green, sun is red. And here we go again" + json_data["text"] += filler * (target_length // 35) + json_data[ + "text" + ] += "\n\nThe passkey is $000310$. Remember, the passkey is $000310$.\n\n" + json_data[ + "text" + ] += "\n\nThe passkey is $000310$. Remember, the passkey is $000310$.\n\n" + json_data[ + "text" + ] += "\n\nThe passkey is $000310$. Remember, the passkey is $000310$.\n\n" + json_data["text"] += filler * (target_length // 35) + json_data["text"] += "What was the passkey? The passkey is" + + try: + for _ in range(server_args.dp_size): + res = requests.post( + url + request_name, + json=json_data, + headers=headers, + timeout=3600, + ) + assert res.status_code == 200, f"{res}" + logger.info(f"Warmup response: {res.json()}") + if os.getenv("SRT_EXIT_AFTER_WARMUP", "0") == "1": + logger.error(f"Initialization canceled. SRT_EXIT_AFTER_WARMUP") + kill_process_tree(os.getpid()) + except Exception: + last_traceback = get_exception_traceback() + if pipe_finish_writer is not None: + pipe_finish_writer.send(last_traceback) + logger.error(f"Initialization failed. warmup error: {last_traceback}") + kill_process_tree(os.getpid()) + return + + # Debug print + # logger.info(f"{res.json()=}") + + logger.info("The server is fired up and ready to roll!") + if pipe_finish_writer is not None: + pipe_finish_writer.send("ready") + + if server_args.delete_ckpt_after_loading: + delete_directory(server_args.model_path) diff --git a/python/sglang/srt/function_call_parser.py b/python/sglang/srt/function_call_parser.py new file mode 100644 index 00000000000..3def4e1eb27 --- /dev/null +++ b/python/sglang/srt/function_call_parser.py @@ -0,0 +1,494 @@ +import json +import re +from abc import ABC, abstractmethod +from json import JSONDecodeError, JSONDecoder +from typing import Any, Dict, List, Optional, Tuple + +import partial_json_parser +from partial_json_parser.core.options import Allow +from pydantic import BaseModel, Field + +TOOLS_TAG_LIST = [ + "<|plugin|>", + "", + "<|python_tag|>", + "[TOOL_CALLS]", +] + + +class Function(BaseModel): + """Function Tool Template.""" + + description: Optional[str] = Field(default=None, examples=[None]) + name: Optional[str] = None + parameters: Optional[object] = None + + +class ToolCallItem(BaseModel): + """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts.""" + + tool_index: int + name: Optional[str] = None + parameters: str # JSON string + + +def _find_common_prefix(s1: str, s2: str) -> str: + prefix = "" + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: + try: + return (partial_json_parser.loads(input_str, flags), len(input_str)) + except JSONDecodeError as e: + if "Extra data" in e.msg: + dec = JSONDecoder() + return dec.raw_decode(input_str) + raise + + +def _is_complete_json(input_str: str) -> bool: + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False + + +class StreamingParseResult: + """Result of streaming incremental parsing.""" + + def __init__( + self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None + ): + self.normal_text = normal_text + self.calls = calls or [] + + +class BaseFormatDetector: + """Base class providing two sets of interfaces: one-time and streaming incremental.""" + + def 
__init__(self): + # initialize properties used for state when parsing tool calls in + self._buffer = "" + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = ( + [] + ) # map what has been streamed for each tool so far to a list + self.bot_token = "" + self.eot_token = "" + + def parse_base_json(self, action: Dict, tools: List[Function]): + name, parameters = action["name"], json.dumps( + action.get("parameters", action.get("arguments", {})), + ensure_ascii=False, + ) + tool_index = [tool.function.name for tool in tools].index(name) + tool_call_item = ToolCallItem( + tool_index=tool_index, name=name, parameters=parameters + ) + calls = [tool_call_item] + return calls + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + Parses the text in one go. Returns success=True if the format matches, otherwise False. + Note that leftover_text here represents "content that this parser will not consume further". + """ + action = json.loads(text) + return self.parse_base_json(action, tools) + + def parse_streaming_increment( + self, new_text: str, tools: List[Function] + ) -> StreamingParseResult: + """ + Streaming incremental parsing, referencing the logic of Llama32Detector. + We partially parse JSON within ..., and handle + incremental argument output. + """ + # Append new text to buffer + self._buffer += new_text + current_text = self._buffer + if not (self.bot_token in current_text or current_text.startswith("{")): + self._buffer = "" + if self.eot_token in new_text: + new_text = new_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=new_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = ( + len(self.bot_token) + if current_text.startswith(self.bot_token) + else 0 + ) + while start_idx < len(current_text): + (obj, end_idx) = _partial_json_loads( + current_text[start_idx:], flags + ) + is_complete.append( + _is_complete_json(current_text[start_idx : start_idx + end_idx]) + ) + start_idx += end_idx + len("; ") + # depending on the prompt Llama can use + # either arguments or parameters + if "parameters" in obj: + assert ( + "arguments" not in obj + ), "model generated both parameters and arguments" + obj["arguments"] = obj["parameters"] + tool_call_arr.append(obj) + + except partial_json_parser.core.exceptions.MalformedJSON: + # not enough tokens to parse into JSON yet + return StreamingParseResult() + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) + + # case -- if no tokens have been streamed for the tool, e.g. 
+ # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return StreamingParseResult() + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + res = StreamingParseResult( + normal_text=None, + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name="", + parameters=argument_diff, + ) + ], + ) + self.streamed_args_for_tool[ + self.current_tool_id + ] += argument_diff + else: + res = StreamingParseResult() + else: + res = StreamingParseResult() + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + print("starting on new tool %d", self.current_tool_id) + return res + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + res = StreamingParseResult( + normal_text=None, + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name=function_name, + parameters="", + ) + ], + ) + self.current_tool_name_sent = True + else: + res = StreamingParseResult() + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + res = StreamingParseResult() + + if cur_arguments: + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + self._buffer = "" + self.prev_tool_call_arr[self.current_tool_id].clear() + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool[self.current_tool_id] = "" + + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + + prefix = _find_common_prefix(prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + res = StreamingParseResult( + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name="", + parameters=argument_diff, + ) + ], + ) + if not is_complete[self.current_tool_id]: + self.streamed_args_for_tool[ + self.current_tool_id + ] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return res + + except Exception as e: + print(e) + # Skipping chunk as a result of tool streaming extraction error + return StreamingParseResult() + + +class Qwen25Detector(BaseFormatDetector): + """ + Detector for Qwen 2.5 models. + Assumes function call format: + {"name":"xxx", "arguments":{...}} + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. 
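        For reference, an illustrative example of the raw model output this detector
        targets (the concrete values are made up); one-shot parsing goes through
        detect_and_parse(), while streaming chunks go through the inherited
        BaseFormatDetector.parse_streaming_increment():

            example_output = (
                '<tool_call>{"name": "get_current_weather", "arguments": {"city": "Boston"}}</tool_call>'
            )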
+ """ + super().__init__() + self.bot_token = "" + self.eot_token = "" + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. + """ + if "" not in text: + return [] + pattern = r"(.*?)" + match_result_list = re.findall(pattern, text, re.DOTALL) + calls = [] + for match_result in match_result_list: + match_result = json.loads(match_result) + calls.extend(self.parse_base_json(match_result, tools)) + return calls + + +class MistralDetector(BaseFormatDetector): + """ + Detector for Mistral models. + Assumes function call format: + <|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|> + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. + """ + super().__init__() + self.bot_token = "[TOOL_CALLS] [" + self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) + + def _clean_text(self, text: str) -> str: + """ + clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]' + for example, + text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.' + return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]' + The key pattern is [TOOL_CALLS] [...] + """ + find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL) + if len(find_results) > 0: + return find_results[0] + else: + return "" + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. + """ + text = self._clean_text(text) + tool_content = text.replace("[TOOL_CALLS]", "").strip() + raw_tool_calls = self.tool_call_regex.findall(tool_content) + calls = [] + if len(raw_tool_calls) > 0: + raw_tool_call = raw_tool_calls[0] + function_call_arr = json.loads(raw_tool_call) + for match_result in function_call_arr: + calls.extend(self.parse_base_json(match_result, tools)) + return calls + + +class Llama32Detector(BaseFormatDetector): + """ + Detector for Llama 3.2 models. + Assumes function call format: + <|python_tag|>{"name":"xxx", "arguments":{...}} + Does not require a closing tag "", + relies on json.loads(...) success to determine if JSON is complete. + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. + """ + super().__init__() + self.bot_token = "<|python_tag|>" + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. 
+ """ + + if "<|python_tag|>" not in text: + return [] + _, action = text.split("<|python_tag|>") + action = json.loads(action) + return self.parse_base_json(action, tools) + + +class MultiFormatParser: + def __init__(self, detectors: List[BaseFormatDetector]): + """ + :param detectors: A series of available Detector instances passed in + """ + self.detectors = detectors + + def parse_once(self, text: str, tools: List[Function]): + """ + One-time parsing: Loop through detectors until there are no new matches or text is exhausted + Return: (final_text, all_calls) + - final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text) + - all_calls: All calls parsed by the Detectors + """ + final_calls = [] + final_normal_text = text + for detector in self.detectors: + tool_call_list = detector.detect_and_parse(text, tools) + if len(tool_call_list) > 0: # parsed successfully + final_calls = tool_call_list + break + + # leftover_text is the normal text not consumed by any Detector + return final_normal_text, final_calls + + def parse_streaming_increment(self, new_text: str, tools: List[Function]): + """ + Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment + and merge their produced normal_text/calls to return. + (The logic here can be "priority-based" or "parallel parsing" based on your needs) + """ + final_normal_text = "" + final_calls = [] + + for detector in self.detectors: + sp_result = detector.parse_streaming_increment(new_text, tools) + # Merge normal_text and calls + # If one sp_result contains result call, this should be a successful parse + # If one sp_result only contains normal_text, this can either be a successful + # parse or it is not using the desired parsing tool. + if sp_result.normal_text: + final_normal_text = sp_result.normal_text + if sp_result.calls: + final_calls.extend(sp_result.calls) + final_normal_text = sp_result.normal_text + break + + return final_normal_text, final_calls + + +class FunctionCallParser: + """ + In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment + and returns the resulting normal_text and calls to the upper layer (or SSE). 
+ """ + + ToolCallParserEnum: Dict[str, BaseFormatDetector] = { + "llama3": Llama32Detector, + "qwen25": Qwen25Detector, + "mistral": MistralDetector, + } + + def __init__(self, tools: List[Function], tool_call_parser: str = None): + detectors = [] + if tool_call_parser: + detector_class = self.ToolCallParserEnum.get(tool_call_parser) + if detector_class: + detectors.append(detector_class()) + else: + raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}") + else: + raise ValueError("Tool Call Parser Not Given!") + + self.multi_format_parser = MultiFormatParser(detectors) + self.tools = tools + + def parse_non_stream(self, full_text: str): + """ + Non-streaming call: one-time parsing + """ + full_normal_text, calls = self.multi_format_parser.parse_once( + full_text, self.tools + ) + return full_normal_text, calls + + def parse_stream_chunk(self, chunk_text: str): + """ + Streaming call: incremental parsing + """ + normal_text, calls = self.multi_format_parser.parse_streaming_increment( + chunk_text, self.tools + ) + return normal_text, calls diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index ebb0652c5d2..d69d854ab2e 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -20,10 +20,10 @@ import torch.nn as nn import torch.nn.functional as F -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +if is_cuda_available(): + from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul from vllm.model_executor.custom_op import CustomOp @@ -149,8 +149,8 @@ def get_act_fn( return act_fn -if not is_flashinfer_available(): +if not is_cuda_available(): logger.info( - "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." 
) from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py index 9163eba68de..d022b972147 100644 --- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py @@ -166,6 +166,12 @@ def _fwd_kernel( def context_attention_fwd( q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True ): + """ + q, k, v: [b * s, head, head_dim] + b_start_loc: [b] + b_seq_len: [b] + out: [b * s, head, head_dim] + """ if is_cuda_available and CUDA_CAPABILITY[0] > 8: BLOCK = 128 else: diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index f66456b0437..03c4cfb46a8 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -4,10 +4,11 @@ import torch import torch.nn as nn +import torch.nn.functional as F from einops import rearrange, repeat -from vllm.distributed import parallel_state -from vllm.distributed import utils as dist_utils +from sglang.srt.distributed import parallel_state +from sglang.srt.distributed import utils as dist_utils from sglang.srt.layers.attention.triton_ops.prefill_attention import ( context_attention_fwd, ) @@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T class VisionAttention(nn.Module): - """Multi-headed attention without any cache, mostly used for ViT.""" + r""" + Multi-headed attention without any cache, mostly used for ViT. + + + Args: + use_qkv_parallel (bool, optional): If True, use QKV-parallel attention. + use_context_forward (bool, default to True): + if ``True``, a flash_attn style attention will be applied + Otherwise, a full-sequence attention will be applied. 
+ use_full_precision_softmax (bool, default to False): + if ``True``, the softmax will be performed in full-precision + Otherwise, it will be performed in half-precision + + """ def __init__( self, @@ -72,25 +86,39 @@ def __init__( projection_size: int, use_qkv_parallel: bool, quant_config: Optional[QuantizationConfig] = None, + dropout: float = 0.0, + use_context_forward: bool = True, + use_full_precision_softmax: bool = False, + flatten_batch: bool = False, prefix: str = "", ): super().__init__() + self.use_context_forward = use_context_forward world_size = parallel_state.get_tensor_model_parallel_world_size() - + self.dropout = dropout + self.head_size = embed_dim // num_heads self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads ) self.num_attention_heads_per_partition = dist_utils.divide( num_heads, world_size ) - # self.tp_size = get_tensor_model_parallel_world_size() - # num_heads = self.num_heads_per_partition + + if self.use_context_forward: + self.qkv_backend = VisionTritonAttention() + else: + self.qkv_backend = VisionSdpaAttention( + head_size=self.head_size, + dropout=dropout, + flatten_batch=flatten_batch, + use_full_precision_softmax=use_full_precision_softmax, + ) + self.use_qkv_parallel = use_qkv_parallel if use_qkv_parallel: - self.head_dim = embed_dim // num_heads self.qkv_proj = QKVParallelLinear( hidden_size=embed_dim, - head_size=self.head_dim, + head_size=self.head_size, total_num_heads=num_heads, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", @@ -114,12 +142,15 @@ def forward( x: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, rotary_pos_emb: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: + r""" + Args: + x: [b, s, embed_dim] + cu_seqlens: [b] + Returns: + [s, b, num_heads * head] """ - Input shape: [b, s, embed_dim] - Output shape: [s, b, num_heads * head_size] - """ - bsz, s, _ = x.shape if self.use_qkv_parallel: # [b, s, embed_dim] --> [b, s, embed_dim] @@ -136,19 +167,19 @@ def forward( else: # [b, s, embed_dim] --> [s, b, embed_dim] x = rearrange(x, "b s ... -> s b ...") - # [s, b, embed_dim] --> [s, b, head * 3 * head_dim] + # [s, b, embed_dim] --> [s, b, head * 3 * head_size] qkv, _ = self.qkv_proj(x) - # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] + # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size] new_x_shape = qkv.size()[:-1] + ( self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head, ) qkv = qkv.view(*new_x_shape) - # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] + # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size] q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3) - # [s, b, head, head_dim] --> [b, s, head, head_dim] + # [s, b, head, head_size] --> [b, s, head, head_size] q, k, v = [ rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) ] @@ -160,45 +191,217 @@ def forward( if self.use_qkv_parallel: pass else: - # [b, s, head, head_dim] --> [b * s, head, head_dim] + # [b, s, head, head_size] --> [b * s, head, head_size] q, k, v = [rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]] - # [b * s, num_heads, head_size] - output = torch.empty_like(q) - - seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda() - max_seqlen = seq_lens.max().item() - - context_attention_fwd( - q, - k, - v, - output, - cu_seqlens.cuda(), - seq_lens, - max_seqlen, - is_causal=False, - ) + output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask) if self.use_qkv_parallel: - - # [b * s, head, head_dim] --> [b, s, head * head_dim] + # [b * s, h, head_size] --> [b, s, h * head_size] output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz) - # [b, s, head, head_dim] --> [b, s, head, head_dim] + # [b, s, h * head_size] --> [b, s, h * head_size] output, _ = self.proj(output) else: - # [b * s, head, head_dim] --> [b, s, head, head_dim] - context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz) - - # [s, b, num_heads * head_size] + # [b * s, h, head_size] --> [s, b, h * head_size] context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" + output, "(b s) h d -> s b (h d)", b=bsz, s=s ).contiguous() - # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size] + # [s, b, h * head_size] --> [s, b, h * head_size] output, _ = self.proj(context_layer) + # [s, b, h * head_size] --> [b, s, h * head_size] output = output.view(bsz, s, -1) return output + + +class VisionSdpaAttention(nn.Module): + r""" + Scaled Dot Product Attention inner product + + """ + + # TODO: Should it be released after used? + _mask_cache = {} + + def __init__( + self, + head_size: int, + dropout: float = 0.0, + flatten_batch: bool = False, + use_full_precision_softmax: bool = False, + ): + super().__init__() + self.head_size = head_size + self.flatten_batch = flatten_batch + self.use_full_precision_softmax = use_full_precision_softmax + self.dropout = dropout + + def generate_patch_attention_mask( + self, + s: int, + bsz: int, + device, + cu_seqlens: Optional[torch.Tensor], + flatten_batch: bool = False, + dtype=torch.bfloat16, + ) -> torch.Tensor: + r""" + Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`. + + When `flatten_batch` is True: + - All sequences in the batch are flattened into a single dimension + - `s` represents the total number of tokens across all sequences in the batch + - Returns a unified mask of shape `(1, 1, s, s)` + + When `flatten_batch` is False: + - Each sequence has its own attention mask + - `s` represents the maximum sequence length in the batch + - Returns separate masks of shape `(b, 1, s, s)` + + Args: + flatten_batch: (bool): + If True, treats all sequences in the batch as a single flattened sequence + If False, generates separate masks for each sequence + + Returns: + Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`. 
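        For instance, with two packed sequences of lengths 2 and 3 (made-up numbers),
        the flatten_batch branch below builds a block-diagonal boolean mask; a
        standalone sketch of that intermediate step:

            import torch

            cu_seqlens = torch.tensor([0, 2, 5])
            s = int(cu_seqlens[-1])  # 5 tokens in total across the flattened batch
            mask = torch.zeros([1, s, s], dtype=torch.bool)
            for i in range(1, len(cu_seqlens)):
                start, end = cu_seqlens[i - 1], cu_seqlens[i]
                mask[..., start:end, start:end] = True
            # Tokens 0-1 attend only to each other and tokens 2-4 only to each other; the
            # method then maps False entries to -inf (an additive mask) and caches the result.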
+ """ + + cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist())) + + if cache_key in VisionSdpaAttention._mask_cache: + cached_mask = VisionSdpaAttention._mask_cache[cache_key] + # print(f"cache hit for key: {cache_key}") + return cached_mask.to(device=device, dtype=dtype) + + if cu_seqlens is None: + raise ValueError("Internal Error: cu_seqlens cannot be None") + + if flatten_batch: + mask = torch.zeros([1, s, s], device=device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + start = cu_seqlens[i - 1] + end = cu_seqlens[i] + mask[ + ..., + start:end, + start:end, + ] = True + else: + # [1, 1, 1, s] + row_indices = torch.arange(s, device=device).view(1, 1, 1, s) + # [1, 1, s, 1] + col_indices = torch.arange(s, device=device).view(1, 1, s, 1) + # [b, 1, 1, 1] + seq_lens = ( + (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1) + ) + + mask = (row_indices < seq_lens) & (col_indices < seq_lens) + + # Convert to attention mask format (False -> 0, True -> -inf) + mask = (~mask).to(dtype) * torch.finfo(dtype).min + + VisionSdpaAttention._mask_cache[cache_key] = mask + + return mask + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + cu_seqlens: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + cu_seqlens: [b] + Returns: + [b * s, h, head_size] + """ + + s = q.shape[0] // bsz + + # [b, 1, s, s] + if attention_mask is None: + attention_mask = self.generate_patch_attention_mask( + s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype + ) + q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]] + # [b, 1, s] + if self.use_full_precision_softmax: + scale = self.head_size**-0.5 + k_transposed = rearrange(k, "b h s d -> b h d s") + attn_weights = torch.matmul(q, k_transposed) * scale + del k, k_transposed + attn_weights = attn_weights + attention_mask + del attention_mask + # full-precision + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=False + ) + output = torch.matmul(attn_weights, v) + del attn_weights, v + else: + # SDPA + # [b, h, s, head_size] + output = F.scaled_dot_product_attention( + q, k, v, attention_mask, dropout_p=self.dropout + ) + + # [b, h, s, head_size] --> [b * s, h, head_size] + output = rearrange(output, "b h s d -> (b s) h d") + + return output + + +class VisionTritonAttention(nn.Module): + """ + Triton-implemented attention without a causal mask + """ + + def __init__( + self, + ): + super().__init__() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + _bsz: int, + cu_seqlens: Optional[torch.Tensor], + **kwargs, + ) -> torch.Tensor: + r""" + Args: + cu_seqlens: [b] + Returns: + [b * s, h, head_size] + """ + + # [b * s, head, head_size] + output = torch.empty_like(q) + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + context_attention_fwd( + q, + k, + v, + output, + cu_seqlens.cuda(), + seq_lens.cuda(), + max_seqlen, + is_causal=False, + ) + + return output diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 65efa0feb84..36b87ca0ba0 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si def 
initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size): global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE + from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP + _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info( enable_dp_attention, tp_rank, tp_size, dp_size ) @@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size): ], tp_rank, torch.distributed.get_backend(tp_group.device_group), - False, + SYNC_TOKEN_IDS_ACROSS_TP, False, False, False, diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index bd95b9bccce..207ba8d1b7a 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -19,10 +19,10 @@ import torch import torch.nn as nn -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer.norm import ( +if is_cuda_available(): + from sgl_kernel import ( fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, @@ -121,8 +121,8 @@ def forward_cuda( return out -if not is_flashinfer_available(): +if not is_cuda_available(): logger.info( - "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." ) from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index bfa5d2b6654..64daf79c50f 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -329,12 +329,14 @@ def __init__( prefix: str = "", tp_rank: Optional[int] = None, tp_size: Optional[int] = None, + use_presharded_weights: bool = False, ): super().__init__( input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix ) self.gather_output = gather_output + self.use_presharded_weights = use_presharded_weights # Divide the weight matrix along the last dimension. if tp_rank is None: @@ -402,7 +404,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if output_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[output_dim] start_idx = self.tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). 
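Context for the use_presharded_weights changes in this file: when a checkpoint already stores one shard per tensor-parallel rank, the loader must skip the per-rank narrow(). A rough standalone sketch of the two paths (the function name and shapes are illustrative, not the actual sglang API):

import torch

def load_column_shard(loaded_weight, tp_rank, shard_size, output_dim=0, use_presharded_weights=False):
    """Return the slice of loaded_weight that this TP rank should copy into its parameter."""
    if use_presharded_weights:
        # The checkpoint already contains only this rank's shard: use it as-is.
        return loaded_weight
    # Full (unsharded) checkpoint: cut out this rank's slice along the output dimension.
    return loaded_weight.narrow(output_dim, tp_rank * shard_size, shard_size)

full_weight = torch.randn(8, 4)  # unsharded checkpoint weight
presharded = torch.randn(2, 4)   # one slice of a 4-way presharded checkpoint
assert load_column_shard(full_weight, tp_rank=1, shard_size=2).shape == (2, 4)
assert load_column_shard(presharded, tp_rank=1, shard_size=2, use_presharded_weights=True).shape == (2, 4)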
@@ -418,7 +421,11 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): if len(loaded_weight.shape) == 0: assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank) + param.load_column_parallel_weight( + loaded_weight, + tp_rank=self.tp_rank, + use_presharded_weights=self.use_presharded_weights, + ) def forward(self, input_): bias = self.bias if not self.skip_bias_add else None @@ -499,7 +506,9 @@ def __init__( prefix=prefix, tp_rank=tp_rank, tp_size=tp_size, + use_presharded_weights=use_presharded_weights, ) + self.prefix = prefix def weight_loader( self, @@ -743,6 +752,7 @@ def __init__( prefix: str = "", tp_rank: Optional[int] = None, tp_size: Optional[int] = None, + load_presharded_attn: bool = False, ): self.hidden_size = hidden_size self.head_size = head_size @@ -772,6 +782,7 @@ def __init__( self.num_kv_heads * self.head_size * tp_size, # k_proj self.num_kv_heads * self.head_size * tp_size, # v_proj ] + self.use_presharded_weights = load_presharded_attn super().__init__( input_size=input_size, @@ -784,6 +795,7 @@ def __init__( prefix=prefix, tp_rank=tp_rank, tp_size=tp_size, + use_presharded_weights=self.use_presharded_weights, ) def _get_shard_offset_mapping(self, loaded_shard_id: str): @@ -842,9 +854,10 @@ def _load_fused_module_from_checkpoint( shard_size=shard_size, shard_offset=shard_offset ) - loaded_weight_shard = loaded_weight.narrow( - param.output_dim, shard_offset, shard_size - ) + if not self.use_presharded_weights: + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) self.weight_loader_v2(param, loaded_weight_shard, shard_id) def weight_loader_v2( @@ -882,6 +895,7 @@ def weight_loader_v2( shard_offset=shard_offset, shard_size=shard_size, tp_rank=self.tp_rank, + use_presharded_weights=self.use_presharded_weights, ) def weight_loader( @@ -987,9 +1001,10 @@ def weight_loader( param, orig_qkv_offsets, shard_id ) - loaded_weight_shard = loaded_weight.narrow( - output_dim, shard_offset, shard_size - ) + if not self.use_presharded_weights: + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) self.weight_loader(param, loaded_weight_shard, shard_id) return @@ -1049,7 +1064,7 @@ def weight_loader( # bitsandbytes loads the weights of the specific portion # no need to narrow here - if not use_bitsandbytes_4bit: + if not use_bitsandbytes_4bit and not self.use_presharded_weights: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for for AQLM codebooks. diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 10f26467787..08ee5a3509b 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -14,6 +14,7 @@ """Logits processing.""" import dataclasses +import logging from typing import List, Optional, Union import torch @@ -32,6 +33,8 @@ ForwardMode, ) +logger = logging.getLogger(__name__) + @dataclasses.dataclass class LogitsProcessorOutput: @@ -136,50 +139,61 @@ def forward( logits_metadata.forward_mode.is_decode_or_idle() or logits_metadata.forward_mode.is_target_verify() ): - last_index = None - last_hidden = hidden_states - else: + pruned_states = hidden_states + sample_indices = None + elif ( + logits_metadata.forward_mode.is_extend() + and not logits_metadata.extend_return_logprob + ): + # Prefill without input logprobs. 
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 - last_hidden = hidden_states[last_index] + pruned_states = hidden_states[last_index] + sample_indices = None + else: + # Slice the requested tokens to compute logprob + sample_index_pt = -1 + sample_indices = [] + pt, pruned_states, pruned_input_ids = 0, [], [] + for start_len, extend_len in zip( + logits_metadata.extend_logprob_start_lens_cpu, + logits_metadata.extend_seq_lens_cpu, + ): + pruned_states.append(hidden_states[pt + start_len : pt + extend_len]) + sample_index_pt += extend_len - start_len + sample_indices.append(sample_index_pt) + pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) + pt += extend_len + + pruned_states = torch.cat(pruned_states) + + # Compute logits for both input and sampled tokens. + logits = self._get_logits(pruned_states, lm_head, logits_metadata) + sampled_logits = ( + logits[sample_indices] if sample_indices is not None else logits + ) - # Compute logits - last_logits = self._get_logits(last_hidden, lm_head) if ( not logits_metadata.extend_return_logprob or logits_metadata.capture_hidden_mode.need_capture() ): # Decode mode or extend mode without return_logprob. return LogitsProcessorOutput( - next_token_logits=last_logits, + next_token_logits=sampled_logits, hidden_states=( hidden_states if logits_metadata.capture_hidden_mode.is_full() else ( - last_hidden + pruned_states if logits_metadata.capture_hidden_mode.is_last() else None ) ), ) else: - # Slice the requested tokens to compute logprob - pt, pruned_states, pruned_input_ids = 0, [], [] - for start_len, extend_len in zip( - logits_metadata.extend_logprob_start_lens_cpu, - logits_metadata.extend_seq_lens_cpu, - ): - pruned_states.append(hidden_states[pt + start_len : pt + extend_len]) - pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) - pt += extend_len - - # Compute the logits of all required tokens - pruned_states = torch.cat(pruned_states) - del hidden_states - input_token_logits = self._get_logits(pruned_states, lm_head) - del pruned_states + input_logprobs = logits + del hidden_states, logits # Normalize the logprob w/o temperature, top-p - input_logprobs = input_token_logits input_logprobs = self.compute_temp_top_p_normalized_logprobs( input_logprobs, logits_metadata ) @@ -194,17 +208,17 @@ def forward( input_top_logprobs_val = input_top_logprobs_idx = None input_token_logprobs = input_logprobs[ - torch.arange(input_logprobs.shape[0], device="cuda"), + torch.arange(input_logprobs.shape[0], device=input_logprobs.device), torch.cat( [ torch.cat(pruned_input_ids)[1:], - torch.tensor([0], device="cuda"), + torch.tensor([0], device=input_logprobs.device), ] ), ] return LogitsProcessorOutput( - next_token_logits=last_logits, + next_token_logits=sampled_logits, input_token_logprobs=input_token_logprobs, input_top_logprobs_val=input_top_logprobs_val, input_top_logprobs_idx=input_top_logprobs_idx, @@ -214,8 +228,11 @@ def _get_logits( self, hidden_states: torch.Tensor, lm_head: VocabParallelEmbedding, + logits_metadata: LogitsMetadata, embedding_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """Get logits from hidden_states.""" + if hasattr(lm_head, "weight"): logits = torch.matmul(hidden_states, lm_head.weight.T) else: @@ -279,7 +296,7 @@ def fused_softcap_kernel( n_elements, BLOCK_SIZE: tl.constexpr, ): - pid = tl.program_id(0) + pid = tl.program_id(0).to(tl.int64) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements diff --git 
a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 8f5a71dff8c..bc927621a84 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -114,6 +114,8 @@ def __init__( tp_size: Optional[int] = None, prefix: str = "", correction_bias: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + activation: str = "silu", ): super().__init__() @@ -140,6 +142,8 @@ def __init__( self.num_expert_group = num_expert_group self.topk_group = topk_group self.correction_bias = correction_bias + self.custom_routing_function = custom_routing_function + self.activation = activation if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() @@ -166,6 +170,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None + assert self.activation == "silu" if self.grouped_gemm_runner is None: self.grouped_gemm_runner = GroupedGemmRunner( @@ -181,6 +186,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): topk_group=self.topk_group, num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, + custom_routing_function=self.custom_routing_function, ) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( @@ -254,16 +260,20 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): dtype=torch.float32, device=hidden_states.device, ) - silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( - gateup_output, - down_input, - gateup_output.shape[1], - reorder_topk_ids, - self.w2_input_scale, - self.start_expert_id, - self.end_expert_id, - BLOCK_SIZE=512, - ) + + if self.activation == "silu": + silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + self.w2_input_scale, + self.start_expert_id, + self.end_expert_id, + BLOCK_SIZE=512, + ) + else: + raise ValueError(f"Unsupported activation: {self.activation=}") # GroupGemm-1 down_output = torch.empty( @@ -309,7 +319,6 @@ def make_expert_params_mapping( ckpt_up_proj_name: str, num_experts: int, ) -> List[Tuple[str, str, int, str]]: - return [ # (param_name, weight_name, expert_id, shard_id) ( @@ -354,7 +363,6 @@ def weight_loader( ) return - expert_data = param.data[expert_id] if shard_id == "w2": param.data[expert_id] = loaded_weight elif shard_id == "w1": diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py index 0703e840ca6..042c0a52c56 100644 --- a/python/sglang/srt/layers/moe/fused_moe_native.py +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -8,7 +8,7 @@ import torch from torch.nn import functional as F -from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.activation import GeluAndMul, SiluAndMul from sglang.srt.layers.moe.topk import select_experts @@ -23,6 +23,7 @@ def fused_moe_forward_native( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: topk_weights, topk_ids = select_experts( hidden_states=x, @@ -41,7 +42,12 @@ def fused_moe_forward_native( w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) w2_weights = layer.w2_weight[topk_ids] x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) - x1 = F.silu(x1) + if activation == "silu": + x1 = F.silu(x1) 
+ elif activation == "gelu": + x1 = F.gelu(x1) + else: + raise ValueError(f"Unsupported activation: {activation=}") x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) @@ -58,6 +64,7 @@ def moe_forward_native( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: topk_weights, topk_ids = select_experts( @@ -84,6 +91,13 @@ def moe_forward_native( sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() + if activation == "silu": + act = SiluAndMul() + elif activation == "gelu": + act = GeluAndMul() + else: + raise ValueError(f"Unsupported activation: {activation=}") + outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): @@ -96,7 +110,7 @@ def moe_forward_native( layer_w2_weight = layer.w2_weight[i] gate_up = F.linear(tokens_for_this_expert, layer_w13_weight) - gate_up = SiluAndMul()(gate_up) + gate_up = act(gate_up) expert_out = F.linear(gate_up, layer_w2_weight) outputs.append(expert_out) start_idx = end_idx diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..a7be90051f8 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + 
"waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 01ecce1a6ed..32c8fcbb625 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -15,18 +15,18 @@ from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip +from sglang.srt.utils import ( + direct_register_custom_op, + get_device_name, + is_cuda_available, + is_hip, +) -is_hip_flag = False -if not is_hip(): - if torch.cuda.is_available(): - from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size - else: - sgl_moe_align_block_size = None +is_cuda = is_cuda_available() +is_hip_flag = is_hip() +if is_cuda: + from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size - is_hip_flag = False -else: - is_hip_flag = True logger = logging.getLogger(__name__) padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 @@ -711,6 +711,7 @@ def inplace_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -726,6 +727,7 @@ def inplace_fused_experts( topk_weights, topk_ids, True, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, @@ -742,6 +744,7 @@ def inplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -767,6 +770,7 @@ def outplace_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -782,6 +786,7 @@ def outplace_fused_experts( topk_weights, topk_ids, False, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, @@ -798,6 +803,7 @@ def outplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ 
-824,6 +830,7 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -839,6 +846,7 @@ def fused_experts( w2, topk_weights, topk_ids, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, @@ -855,6 +863,7 @@ def fused_experts( w2, topk_weights, topk_ids, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, @@ -872,6 +881,7 @@ def fused_experts_impl( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -986,7 +996,12 @@ def fused_experts_impl( block_shape=block_shape, ) - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + if activation == "silu": + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + elif activation == "gelu": + ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported activation: {activation=}") invoke_fused_moe_kernel( intermediate_cache2, @@ -1042,6 +1057,7 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, + activation: str = "silu", use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, @@ -1111,6 +1127,7 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, + activation=activation, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, w1_scale=w1_scale, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 75d4c5ead65..b71a878a0ba 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -126,6 +126,7 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: return self.forward( x=x, @@ -138,6 +139,7 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, correction_bias=correction_bias, + activation=activation, ) def forward_cuda( @@ -152,6 +154,7 @@ def forward_cuda( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: topk_weights, topk_ids = select_experts( hidden_states=x, @@ -169,6 +172,8 @@ def forward_cuda( import ater from ater.fused_moe import fused_experts_ck + assert activation == "silu", f"{activation=} is not supported." 
+ return fused_experts_ck( hidden_states=x, w1=layer.w13_weight, @@ -184,6 +189,7 @@ def forward_cuda( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, ) def forward_cpu( @@ -256,6 +262,7 @@ def __init__( prefix: str = "", custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", use_presharded_weights: bool = False, ): super().__init__() @@ -279,6 +286,7 @@ def __init__( self.topk_group = topk_group self.custom_routing_function = custom_routing_function self.correction_bias = correction_bias + self.activation = activation if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -589,6 +597,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, correction_bias=self.correction_bias, + activation=self.activation, ) if self.reduce_results and self.tp_size > 1: diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index d99b2efe85f..78be6798254 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -124,7 +124,13 @@ def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs): + def load_qkv_weight( + self, + loaded_weight: torch.Tensor, + tp_rank: int, + use_presharded_weights: bool = False, + **kwargs, + ): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") @@ -142,11 +148,14 @@ def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs): param_data = self.data shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) - loaded_weight = loaded_weight.narrow( - self.output_dim, shard_id * shard_size, shard_size - ) + if not use_presharded_weights: + loaded_weight = loaded_weight.narrow( + self.output_dim, shard_id * shard_size, shard_size + ) - assert param_data.shape == loaded_weight.shape + assert ( + param_data.shape == loaded_weight.shape + ), f"{param_data.shape=}, {loaded_weight.shape=}" param_data.copy_(loaded_weight) @@ -292,7 +301,7 @@ def __init__( packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, - **kwargs + **kwargs, ): self._packed_factor = packed_factor self._packed_dim = packed_dim @@ -336,7 +345,7 @@ def __init__( packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, - **kwargs + **kwargs, ): self._packed_factor = packed_factor self._packed_dim = packed_dim diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..c098ef2dbb9 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..6f5adbb9361 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..4225c78eb72 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..5e6789d00e0 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 
+ }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..49ac14d2a57 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + 
"waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..dcbb0efc53e --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..dfe5c1e43d6 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..a87f5de1b18 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..468f9e78da0 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index bd59352a796..b0b5b8952a1 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -763,8 +763,8 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.topk import select_experts @@ -785,6 +785,8 @@ def apply( import ater from ater.fused_moe import fused_experts_ck + assert activation == "silu", f"{activation=} is not supported." 
+ return fused_experts_ck( x, layer.w13_weight, @@ -815,6 +817,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_fp8_w8a8=True, w1_scale=( layer.w13_weight_scale_inv diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 822b344feab..94094420596 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -54,8 +54,8 @@ def __init__( self.logit_cap = logit_cap self.sliding_window_size = sliding_window_size or -1 self.is_cross_attention = is_cross_attention - self.k_scale = 1.0 - self.v_scale = 1.0 + self.k_scale = None + self.v_scale = None self.orig_context_len = orig_context_len diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 7c18c683e96..7093bb90d81 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -6,8 +6,16 @@ import torch import torch.nn as nn +from vllm import _custom_ops as ops from vllm.model_executor.custom_op import CustomOp +from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.utils import is_cuda_available + +_is_cuda_available = is_cuda_available() +if _is_cuda_available: + from sgl_kernel import apply_rope_with_cos_sin_cache_inplace + def _rotate_neox(x: torch.Tensor) -> torch.Tensor: x1 = x[..., : x.shape[-1] // 2] @@ -51,7 +59,7 @@ def _apply_rotary_emb( return torch.stack((o1, o2), dim=-1).flatten(-2) -@CustomOp.register("rotary_embedding") +@register_custom_op("sglang_rotary_embedding") class RotaryEmbedding(CustomOp): """Original rotary positional embedding.""" @@ -73,7 +81,9 @@ def __init__( self.dtype = dtype cache = self._compute_cos_sin_cache() - cache = cache.to(dtype) + # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability + if not _is_cuda_available: + cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) @@ -139,23 +149,17 @@ def forward_cuda( key: torch.Tensor, offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - from vllm import _custom_ops as ops - - self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. - if offsets is not None: - ops.batched_rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - self.rotary_dim, - offsets, + if _is_cuda_available: + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=self.head_size, + cos_sin_cache=self.cos_sin_cache, + is_neox=self.is_neox_style, ) else: + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) ops.rotary_embedding( positions, query, @@ -176,28 +180,14 @@ def forward_xpu( from vllm._ipex_ops import ipex_ops as ops self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype) - # ops.rotary_embedding()/batched_rotary_embedding() - # are in-place operations that update the query and key tensors. 
- if offsets is not None: - ops.batched_rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - self.rotary_dim, - offsets, - ) - else: - ops.rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - ) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key def forward_hpu( @@ -664,6 +654,7 @@ def __init__( beta_slow: int = 1, mscale: float = 1, mscale_all_dim: float = 0, + device: Optional[str] = "cuda", ) -> None: self.scaling_factor = scaling_factor self.extrapolation_factor = extrapolation_factor @@ -676,13 +667,14 @@ def __init__( / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * attn_factor ) + self.device = device super().__init__( head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype ) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: pos_freqs = self.base ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda") + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) / self.rotary_dim ) inv_freq_extrapolation = 1.0 / pos_freqs @@ -710,7 +702,7 @@ def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) t = torch.arange( self.max_position_embeddings * self.scaling_factor, - device="cuda", + device=self.device, dtype=torch.float32, ) freqs = torch.einsum("i,j -> ij", t, inv_freq) @@ -1042,7 +1034,12 @@ def get_rope( head_size, rotary_dim, max_position, base, is_neox_style, dtype ) else: - scaling_type = rope_scaling["rope_type"] + if "rope_type" in rope_scaling: + scaling_type = rope_scaling["rope_type"] + elif "type" in rope_scaling: + scaling_type = rope_scaling["type"] + else: + raise ValueError("Unknown RoPE scaling type") if scaling_type == "llama3": scaling_factor = rope_scaling["factor"] @@ -1174,3 +1171,111 @@ def get_rope( raise ValueError(f"Unknown RoPE scaling type {scaling_type}") _ROPE_DICT[key] = rotary_emb return rotary_emb + + +def get_rope_cpu( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, + device: Optional[str] = None, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dtype, + ) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + assert rope_scaling is not None + scaling_type = rope_scaling["rope_type"] + assert ( + scaling_type == "deepseek_yarn" + ), "Only deepseek_yarn is supported for CPU for now" + + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ) + } + extra_kwargs["device"] = device + rotary_emb = 
DeepseekScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + + _ROPE_DICT[key] = rotary_emb + return rotary_emb + + +def get_rope_wrapper( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, + device: Optional[str] = None, +): + if device != "cpu": + return get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling, + dtype, + partial_rotary_factor, + ) + + return get_rope_cpu( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling, + dtype, + partial_rotary_factor, + device, + ) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 23037650a31..b24bfc8dacf 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -2,15 +2,18 @@ from typing import List import torch +import torch.distributed as dist from torch import nn +from sglang.srt.distributed import get_tensor_model_parallel_group +from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.utils import crash_on_warnings, is_flashinfer_available +from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available -if is_flashinfer_available(): - from flashinfer.sampling import ( +if is_cuda_available(): + from sgl_kernel import ( min_p_sampling_from_probs, top_k_renorm_prob, top_k_top_p_sampling_from_probs, @@ -20,11 +23,17 @@ logger = logging.getLogger(__name__) +SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP") + class Sampler(nn.Module): def __init__(self): super().__init__() self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"] + self.tp_sync_group = get_tensor_model_parallel_group().device_group + + if global_server_args_dict["enable_dp_attention"]: + self.tp_sync_group = get_attention_tp_group().device_group def forward( self, @@ -35,6 +44,10 @@ def forward( ): logits = logits_output.next_token_logits + # Apply the custom logit processors if registered in the sampling info. + if sampling_info.has_custom_logit_processor: + self._apply_custom_logit_processor(logits, sampling_info) + if self.use_nan_detectioin and torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling! NaN in the logits.") logits = torch.where( @@ -104,8 +117,6 @@ def forward( f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" ) - batch_next_token_ids = batch_next_token_ids.to(torch.int32) - # Attach logprobs to logits_output (in-place modification) if return_logprob: if any(x > 0 for x in top_logprobs_nums): @@ -119,7 +130,54 @@ def forward( batch_next_token_ids, ] - return batch_next_token_ids + if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars: + # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default. + # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators: + # the last all-reduce, the last lm_head matmul, and all sampling kernels. 
+ # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic. + # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized. + # When using xgrammar, this becomes more likely so we also do the sync when grammar is used. + + torch.distributed.all_reduce( + batch_next_token_ids, + op=dist.ReduceOp.MIN, + group=self.tp_sync_group, + ) + + return batch_next_token_ids.to(torch.int32) + + def _apply_custom_logit_processor( + self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo + ): + """Apply custom logit processors to the logits. + This function will modify the logits in-place.""" + + assert logits.shape[0] == len(sampling_batch_info), ( + f"The batch size of logits ({logits.shape[0]}) does not match the batch size of " + f"sampling_batch_info ({len(sampling_batch_info)})" + ) + + for _, ( + processor, + batch_mask, + ) in sampling_batch_info.custom_logit_processor.items(): + # Get the batch indices that need to be processed + batch_indices = batch_mask.nonzero(as_tuple=True)[0] + + assert batch_mask.shape[0] == len(sampling_batch_info), ( + f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of " + f"sampling_batch_info ({len(sampling_batch_info)})" + ) + + # Apply the processor to the logits + logits[batch_mask] = processor( + logits[batch_mask], + [sampling_batch_info.custom_params[i] for i in batch_indices], + ) + + logger.debug( + f"Custom logit processor {processor.__class__.__name__} is applied." + ) def top_k_top_p_min_p_sampling_from_probs_torch( diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index c5bca25df37..e08abd5ae1d 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -5,6 +5,7 @@ import logging import os import pwd +from typing import Callable, Optional import torch @@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool: return True +def proj_filter( + module: torch.nn.Module, + fqn: str, +): + """Filter function for quantizing projection layers.""" + return "proj" in fqn + + def apply_torchao_config_to_model( - model: torch.nn.Module, torchao_config: str, filter_fn=None + model: torch.nn.Module, + torchao_config: str, + filter_fn: Optional[Callable] = proj_filter, ): """Quantize a modelwith torchao quantization specified by torchao_config @@ -49,11 +60,6 @@ def apply_torchao_config_to_model( ) from torchao.quantization.observer import PerRow, PerTensor - if filter_fn is None: - - def filter_fn(module, fqn): - return "proj" in fqn - if torchao_config == "" or torchao_config is None: return model elif "int8wo" in torchao_config: diff --git a/python/sglang/srt/managers/configure_logging.py b/python/sglang/srt/managers/configure_logging.py index 3351cdc400c..187af4d9c08 100644 --- a/python/sglang/srt/managers/configure_logging.py +++ b/python/sglang/srt/managers/configure_logging.py @@ -27,6 +27,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--url", type=str, default="http://localhost:30000") + parser.add_argument("--log-requests", action="store_true") parser.add_argument( "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump" ) @@ -36,6 +37,8 @@ response = requests.post( args.url + "/configure_logging", json={ + "log_requests": args.log_requests, + "log_requests_level": 1, # Log full requests "dump_requests_folder": args.dump_requests_folder, 
"dump_requests_threshold": args.dump_requests_threshold, }, diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index f0605ee1fea..a8ded73bccc 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -15,6 +15,7 @@ import dataclasses import logging +import os import signal from collections import OrderedDict from typing import Dict, List, Union @@ -35,6 +36,12 @@ logger = logging.getLogger(__name__) +# Maximum number of request states that detokenizer can hold. When exceeded, +# oldest request states will be evicted. Default: 65536 (1<<16). +# For more details, see: https://github.com/sgl-project/sglang/issues/2812 +# Use power of 2 values for better memory allocation. +DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16)) + @dataclasses.dataclass class DecodeStatus: @@ -71,9 +78,10 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) - self.decode_status = LimitedCapacityDict() + self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES) def trim_matched_stop( self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool @@ -155,7 +163,17 @@ def event_loop(self): # Incremental decoding output_strs = [] for i in range(bs): - s = self.decode_status[recv_obj.rids[i]] + try: + s = self.decode_status[recv_obj.rids[i]] + except KeyError: + raise RuntimeError( + f"Decode status not found for request {recv_obj.rids[i]}. " + "It may be due to the request being evicted from the decode status due to memory pressure. " + "Please increase the maximum number of requests by setting " + "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. " + f"The current value is {DETOKENIZER_MAX_STATES}. 
" + "For more details, see: https://github.com/sgl-project/sglang/issues/2812" + ) new_text = read_texts[i][len(surr_texts[i]) :] if recv_obj.finished_reasons[i] is None: # Streaming chunk: update the decode status @@ -183,6 +201,7 @@ def event_loop(self): prompt_tokens=recv_obj.prompt_tokens, completion_tokens=recv_obj.completion_tokens, cached_tokens=recv_obj.cached_tokens, + spec_verify_ct=recv_obj.spec_verify_ct, input_token_logprobs_val=recv_obj.input_token_logprobs_val, input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, output_token_logprobs_val=recv_obj.output_token_logprobs_val, @@ -196,7 +215,7 @@ def event_loop(self): class LimitedCapacityDict(OrderedDict): - def __init__(self, capacity=1 << 15, *args, **kwargs): + def __init__(self, capacity: int, *args, **kwargs): super().__init__(*args, **kwargs) self.capacity = capacity diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index c8ebbed783a..f43ecb18c16 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -240,6 +240,7 @@ async def process_images_async( class MiniCPMVImageProcessor(BaseImageProcessor): def __init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) + self.IMAGE_TOKEN = "(./)" @staticmethod def _process_images_task(images, input_text): @@ -271,7 +272,7 @@ async def _process_images(self, images, input_text): async def process_images_async( self, image_data: List[Union[str, bytes]], - input_text, + input_ids, request_obj, max_req_input_len, ): @@ -282,28 +283,49 @@ async def process_images_async( image_data = [image_data] image_hashes, image_sizes = [], [] - raw_images = [] - IMAGE_TOKEN = "(./)" + all_frames = [] - # roughly calculate the max number of frames - # TODO: the process should be applied to all the visual inputs + # roughly calculate the max number of frames under the max_req_input_len limit def calculate_max_num_frames() -> int: # Model-specific NUM_TOKEN_PER_FRAME = 330 - ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME + ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME return min(ret, 100) - # if cuda OOM set a smaller number MAX_NUM_FRAMES = calculate_max_num_frames() - print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}") - def encode_video(video_path): + # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}") + + def get_estimated_frames_list(): + """ + estimate the total frame count from all visual input + """ + # Before processing inputs + estimated_frames_list = [] + for image in image_data: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + # Estimate frames for the video + vr = VideoReader(path, ctx=cpu(0)) + num_frames = len(vr) + else: + # For images, each contributes one frame + num_frames = 1 + estimated_frames_list.append(num_frames) + + return estimated_frames_list + + estimated_frames_list = get_estimated_frames_list() + total_frame_count = sum(estimated_frames_list) + scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count) + + def encode_video(video_path, frame_count_limit=None): if not os.path.exists(video_path): logger.error(f"Video {video_path} does not exist") return [] - if MAX_NUM_FRAMES == 0: + if frame_count_limit == 0: return [] def uniform_sample(l, n): @@ -314,45 +336,63 @@ def uniform_sample(l, n): vr = VideoReader(video_path, ctx=cpu(0)) sample_fps = round(vr.get_avg_fps() / 1) # FPS frame_idx = [i for i in range(0, len(vr), sample_fps)] - if 
len(frame_idx) > MAX_NUM_FRAMES: - frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES) + if frame_count_limit is not None and len(frame_idx) > frame_count_limit: + frame_idx = uniform_sample(frame_idx, frame_count_limit) frames = vr.get_batch(frame_idx).asnumpy() frames = [Image.fromarray(v.astype("uint8")) for v in frames] return frames - if isinstance(input_text, list): - assert len(input_text) and isinstance(input_text[0], int) - input_text = self._processor.tokenizer.decode(input_text) - + if isinstance(input_ids, list): + assert len(input_ids) and isinstance(input_ids[0], int) + input_text = self._processor.tokenizer.decode(input_ids) + else: + input_text = input_ids # MiniCPMV requires each frame of video as a single image token - text_parts = input_text.split(IMAGE_TOKEN) + text_parts = input_text.split(self.IMAGE_TOKEN) new_text_parts = [] - for image_index, image in enumerate(image_data): - try: - if isinstance(image, str) and image.startswith("video:"): - path = image[len("video:") :] - frames = encode_video(path) - else: - raw_image, size = load_image(image) - frames = [raw_image] - if len(frames) == 0: - continue - except FileNotFoundError as e: - print(e) - return None - - image_sizes += frames[0].size * len(frames) - image_hashes += [hash(image)] * len(frames) - raw_images += frames + # Process each input with allocated frames + for image_index, (image, estimated_frames) in enumerate( + zip(image_data, estimated_frames_list) + ): + if len(all_frames) >= MAX_NUM_FRAMES: + frames_to_process = 0 + else: + frames_to_process = max(1, int(estimated_frames * scaling_factor)) + + if frames_to_process == 0: + frames = [] + else: + try: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + frames = encode_video(path, frame_count_limit=frames_to_process) + else: + raw_image, _size = load_image(image) + frames = [raw_image] + if len(frames) == 0: + continue + except FileNotFoundError as e: + print(e) + return None + image_sizes += frames[0].size * len(frames) + image_hashes += [hash(image)] * len(frames) + all_frames += frames + + assert frames_to_process == len(frames) + new_text_parts.append(text_parts[image_index]) - new_text_parts.append(IMAGE_TOKEN * len(frames)) + + if frames_to_process != 0: + new_text_parts.append(self.IMAGE_TOKEN * len(frames)) new_text_parts.append(text_parts[-1]) + input_text = "".join(new_text_parts) - if len(raw_images) == 0: + + if len(all_frames) == 0: return None - res = await self._process_images(images=raw_images, input_text=input_text) + res = await self._process_images(images=all_frames, input_text=input_text) pixel_values = res["pixel_values"] tgt_sizes = res["tgt_sizes"] input_ids = res["input_ids"] @@ -364,7 +404,6 @@ def uniform_sample(l, n): if tokenizer.slice_start_id: slice_start_id = [tokenizer.slice_start_id] slice_end_id = [tokenizer.slice_end_id] - return { "input_ids": input_ids.flatten().tolist(), "pixel_values": pixel_values, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 7f07055132f..f7419d04f33 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -17,7 +17,7 @@ """ import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import Dict, List, Optional, Union @@ -69,6 +69,10 @@ class GenerateReqInput: # Session info for continual prompting session_params: Optional[Union[List[Dict], Dict]] = None + # Custom logit processor for 
advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. + custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None def normalize_batch_and_arguments(self): if ( @@ -183,6 +187,13 @@ def normalize_batch_and_arguments(self): else: assert self.parallel_sample_num == 1 + if self.custom_logit_processor is None: + self.custom_logit_processor = [None] * num + elif not isinstance(self.custom_logit_processor, list): + self.custom_logit_processor = [self.custom_logit_processor] * num + else: + assert self.parallel_sample_num == 1 + def regenerate_rid(self): self.rid = uuid.uuid4().hex return self.rid @@ -202,6 +213,11 @@ def __getitem__(self, i): log_metrics=self.log_metrics, modalities=self.modalities[i] if self.modalities else None, lora_path=self.lora_path[i] if self.lora_path is not None else None, + custom_logit_processor=( + self.custom_logit_processor[i] + if self.custom_logit_processor is not None + else None + ), ) @@ -234,6 +250,11 @@ class TokenizedGenerateReqInput: # Session info for continual prompting session_params: Optional[SessionParams] = None + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. + custom_logit_processor: Optional[str] = None + @dataclass class EmbeddingReqInput: @@ -333,10 +354,13 @@ class BatchTokenIDOut: skip_special_tokens: List[bool] spaces_between_special_tokens: List[bool] no_stop_trim: List[bool] + # Token counts prompt_tokens: List[int] completion_tokens: List[int] cached_tokens: List[int] + spec_verify_ct: List[int] + # Logprobs input_token_logprobs_val: List[float] input_token_logprobs_idx: List[int] @@ -361,6 +385,7 @@ class BatchStrOut: prompt_tokens: List[int] completion_tokens: List[int] cached_tokens: List[int] + spec_verify_ct: List[int] # Logprobs input_token_logprobs_val: List[float] @@ -495,6 +520,7 @@ class ProfileReq(Enum): @dataclass class ConfigureLoggingReq: log_requests: Optional[bool] = None + log_requests_level: Optional[int] = None dump_requests_folder: Optional[str] = None dump_requests_threshold: Optional[int] = None @@ -514,3 +540,27 @@ class CloseSessionReqInput: class OpenSessionReqOutput: session_id: Optional[str] success: bool + + +@dataclass +class Function: + description: Optional[str] = None + name: Optional[str] = None + parameters: Optional[object] = None + + +@dataclass +class Tool: + function: Function + type: Optional[str] = "function" + + +@dataclass +class FunctionCallReqInput: + text: str # The text to parse. + tools: List[Tool] = field( + default_factory=list + ) # A list of available function tools (name, parameters, etc.). + tool_call_parser: Optional[str] = ( + None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all. 
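The new `custom_logit_processor` field on `GenerateReqInput` is broadcast by `normalize_batch_and_arguments` the same way `lora_path` is. A rough standalone restatement of that broadcast rule, for illustration only (the helper name below is not part of the codebase):

from typing import List, Optional, Union

def broadcast_custom_logit_processor(
    value: Optional[Union[List[Optional[str]], str]],
    num: int,
    parallel_sample_num: int,
) -> List[Optional[str]]:
    # None -> a list of Nones; a single serialized processor -> repeated for
    # every request in the batch; an explicit per-request list is only valid
    # when parallel sampling is disabled.
    if value is None:
        return [None] * num
    if not isinstance(value, list):
        return [value] * num
    assert parallel_sample_num == 1
    return value

# Example: one serialized processor shared by a batch of three prompts.
print(broadcast_custom_logit_processor("<serialized CustomLogitProcessor>", 3, 1))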
+ ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 4784a98d968..d51a8e45a08 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -158,7 +158,6 @@ class ImageInputs: im_end_id: Optional[torch.Tensor] = None slice_start_id: Optional[torch.Tensor] = None slice_end_id: Optional[torch.Tensor] = None - tgt_sizes: Optional[list] = None @staticmethod @@ -233,6 +232,7 @@ def __init__( lora_path: Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, session_id: Optional[str] = None, + custom_logit_processor: Optional[str] = None, eos_token_ids: Optional[Set[int]] = None, ): # Input and output info @@ -247,12 +247,13 @@ def __init__( # Each decode stage's output ids self.output_ids = [] # fill_ids = origin_input_ids + output_ids. Updated if chunked. + self.fill_ids = None self.session_id = session_id self.input_embeds = input_embeds # Sampling info self.sampling_params = sampling_params - self.lora_path = lora_path + self.custom_logit_processor = custom_logit_processor # Memory pool info self.req_pool_idx = None @@ -299,7 +300,7 @@ def __init__( self.logprob_start_len = 0 self.top_logprobs_num = top_logprobs_num - # Logprobs (return value) + # Logprobs (return values) self.input_token_logprobs_val: Optional[List[float]] = None self.input_token_logprobs_idx: Optional[List[int]] = None self.input_top_logprobs_val: Optional[List[float]] = None @@ -328,8 +329,14 @@ def __init__( # Constrained decoding self.grammar: Optional[BaseGrammarObject] = None - # The number of cached tokens, that were already cached in the KV cache + # The number of cached tokens that were already cached in the KV cache self.cached_tokens = 0 + self.already_computed = 0 + + # The number of verification forward passes in the speculative decoding. + # This is used to compute the average acceptance length per request. 
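`spec_verify_ct` counts the verification forward passes of a single request, and later in this diff it is surfaced to clients as `meta_info["spec_verify_ct"]` when a speculative algorithm is configured. A hedged client-side sketch of the per-request average acceptance length it enables; this is only an estimate, assuming one verification pass per decode step:

def average_accept_length(meta_info: dict) -> float:
    # meta_info is the per-request dict returned by the server; with speculative
    # decoding enabled it carries "spec_verify_ct" alongside "completion_tokens".
    verify_ct = meta_info.get("spec_verify_ct", 0)
    if verify_ct == 0:
        return 1.0  # no verification passes recorded (speculative decoding off)
    return meta_info["completion_tokens"] / verify_ct

# Illustrative numbers: 120 output tokens over 40 verification passes -> 3.0.
print(average_accept_length({"completion_tokens": 120, "spec_verify_ct": 40}))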
+ self.spec_verify_ct = 0 + self.lora_path = lora_path def extend_image_inputs(self, image_inputs): if self.image_inputs is None: @@ -549,13 +556,13 @@ class ScheduleBatch: next_batch_sampling_info: SamplingBatchInfo = None # Batched arguments to model runner - input_ids: torch.Tensor = None - input_embeds: torch.Tensor = None - req_pool_indices: torch.Tensor = None - seq_lens: torch.Tensor = None + input_ids: torch.Tensor = None # shape: [b], int32 + input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32 + req_pool_indices: torch.Tensor = None # shape: [b], int32 + seq_lens: torch.Tensor = None # shape: [b], int64 # The output locations of the KV cache - out_cache_loc: torch.Tensor = None - output_ids: torch.Tensor = None + out_cache_loc: torch.Tensor = None # shape: [b], int32 + output_ids: torch.Tensor = None # shape: [b], int32 # The sum of all sequence lengths seq_lens_sum: int = None @@ -594,6 +601,9 @@ class ScheduleBatch: spec_algorithm: SpeculativeAlgorithm = None spec_info: Optional[SpecInfo] = None + # Enable custom logit processor + enable_custom_logit_processor: bool = False + @classmethod def init_new( cls, @@ -604,6 +614,7 @@ def init_new( model_config: ModelConfig, enable_overlap: bool, spec_algorithm: SpeculativeAlgorithm, + enable_custom_logit_processor: bool, ): return cls( reqs=reqs, @@ -617,6 +628,7 @@ def init_new( has_grammar=any(req.grammar for req in reqs), device=req_to_token_pool.device, spec_algorithm=spec_algorithm, + enable_custom_logit_processor=enable_custom_logit_processor, ) def batch_size(self): @@ -744,13 +756,6 @@ def prepare_for_extend(self): pt = 0 for i, req in enumerate(reqs): - already_computed = ( - req.extend_logprob_start_len + 1 + req.cached_tokens - if req.extend_logprob_start_len > 0 - else 0 - ) - req.cached_tokens += len(req.prefix_indices) - already_computed - req.req_pool_idx = req_pool_indices[i] pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) seq_lens.append(seq_len) @@ -766,15 +771,20 @@ def prepare_for_extend(self): # If req.input_embeds is already a list, append its content directly input_embeds.extend(req.input_embeds) # Use extend to avoid nesting - # Compute the relative logprob_start_len in an extend batch - if req.logprob_start_len >= pre_len: - extend_logprob_start_len = min( - req.logprob_start_len - pre_len, req.extend_input_len - 1 - ) - else: - extend_logprob_start_len = req.extend_input_len - 1 + if req.return_logprob: + # Compute the relative logprob_start_len in an extend batch + if req.logprob_start_len >= pre_len: + extend_logprob_start_len = min( + req.logprob_start_len - pre_len, req.extend_input_len - 1 + ) + else: + raise RuntimeError( + f"This should never happen. 
{req.logprob_start_len=}, {pre_len=}" + ) + req.extend_logprob_start_len = extend_logprob_start_len - req.extend_logprob_start_len = extend_logprob_start_len + req.cached_tokens += pre_len - req.already_computed + req.already_computed = seq_len req.is_retracted = False pre_lens.append(pre_len) @@ -1020,7 +1030,7 @@ def prepare_for_idle(self): self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device) self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device) self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device) - self.req_pool_indices = torch.empty(0, dtype=torch.int64, device=self.device) + self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) self.seq_lens_sum = 0 self.extend_num_tokens = 0 self.sampling_info = SamplingBatchInfo.from_schedule_batch( @@ -1106,6 +1116,8 @@ def filter_batch( self.has_grammar = any(req.grammar for req in self.reqs) self.sampling_info.filter_batch(keep_indices, new_indices) + if self.spec_info: + self.spec_info.filter_batch(new_indices) def merge_batch(self, other: "ScheduleBatch"): # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because @@ -1200,6 +1212,7 @@ def copy(self): return_logprob=self.return_logprob, decoding_reqs=self.decoding_reqs, spec_algorithm=self.spec_algorithm, + enable_custom_logit_processor=self.enable_custom_logit_processor, ) def __str__(self): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index bdcc56186c8..79d4db114e8 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -34,6 +34,7 @@ from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.constrained.base_grammar_backend import create_grammar_backend from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput @@ -97,7 +98,7 @@ set_random_seed, suppress_other_loggers, ) -from sglang.utils import get_exception_traceback +from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) @@ -148,10 +149,9 @@ def __init__( if not self.spec_algorithm.is_none() else 1 ) + self.enable_hierarchical_cache = server_args.enable_hierarchical_cache - # Init inter-process communication - context = zmq.Context(2) - + # Distributed rank info self.dp_size = server_args.dp_size self.attn_tp_rank, self.attn_tp_size, self.dp_rank = ( compute_dp_attention_world_info( @@ -162,6 +162,8 @@ def __init__( ) ) + # Init inter-process communication + context = zmq.Context(2) if self.attn_tp_rank == 0: self.recv_from_tokenizer = get_zmq_socket( context, zmq.PULL, port_args.scheduler_input_ipc_name, False @@ -206,6 +208,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.tokenizer = self.processor.tokenizer else: @@ -213,6 +216,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) # Check whether overlap can be enabled @@ -241,7 +245,7 @@ def __init__( nccl_port=port_args.nccl_port, ) - # Launch worker for speculative decoding if need + # Launch a worker for speculative decoding if needed if self.spec_algorithm.is_eagle(): 
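The `prepare_for_extend` hunk above replaces the old one-shot cached-token estimate with incremental bookkeeping: each chunk credits only the prefix tokens not already counted (`pre_len - already_computed`) and then advances `already_computed` to the full fill length. A toy trace of that rule with made-up chunk sizes:

def simulate_cached_tokens(chunks):
    # chunks: (pre_len, seq_len) pairs seen by prepare_for_extend for one
    # request across successive chunked-prefill steps.
    cached_tokens = 0
    already_computed = 0
    for pre_len, seq_len in chunks:
        cached_tokens += pre_len - already_computed  # only newly reused prefix
        already_computed = seq_len
    return cached_tokens

# A 24-token prompt prefilled in chunks with a 4-token radix-cache hit:
# later chunks reuse only tokens this request itself just computed, so they add 0.
print(simulate_cached_tokens([(4, 12), (12, 20), (20, 24)]))  # -> 4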
from sglang.srt.speculative.eagle_worker import EAGLEWorker @@ -278,6 +282,7 @@ def __init__( # Print debug info logger.info( f"max_total_num_tokens={self.max_total_num_tokens}, " + f"chunked_prefill_size={server_args.chunked_prefill_size}, " f"max_prefill_tokens={self.max_prefill_tokens}, " f"max_running_requests={self.max_running_requests}, " f"context_len={self.model_config.context_len}" @@ -314,6 +319,8 @@ def __init__( self.forward_ct = 0 self.forward_ct_decode = 0 self.num_generated_tokens = 0 + self.spec_num_total_accepted_tokens = 0 + self.spec_num_total_forward_ct = 0 self.last_decode_stats_tic = time.time() self.stream_interval = server_args.stream_interval self.current_stream = torch.get_device_module(self.device).current_stream() @@ -335,28 +342,9 @@ def __init__( # Init the grammar backend for constrained generation self.grammar_queue: List[Req] = [] if not server_args.skip_tokenizer_init: - if server_args.grammar_backend == "outlines": - from sglang.srt.constrained.outlines_backend import ( - OutlinesGrammarBackend, - ) - - self.grammar_backend = OutlinesGrammarBackend( - self.tokenizer, - whitespace_pattern=server_args.constrained_json_whitespace_pattern, - allow_jump_forward=not server_args.disable_jump_forward, - ) - elif server_args.grammar_backend == "xgrammar": - from sglang.srt.constrained.xgrammar_backend import ( - XGrammarGrammarBackend, - ) - - self.grammar_backend = XGrammarGrammarBackend( - self.tokenizer, vocab_size=self.model_config.vocab_size - ) - else: - raise ValueError( - f"Invalid grammar backend: {server_args.grammar_backend}" - ) + self.grammar_backend = create_grammar_backend( + server_args, self.tokenizer, self.model_config.vocab_size + ) else: self.grammar_backend = None @@ -422,6 +410,40 @@ def __init__( }, ) + # The largest prefill length of a single request + self._largest_prefill_len: int = 0 + # The largest context length (prefill + generation) of a single request + self._largest_prefill_decode_len: int = 0 + + # Init request dispatcher + self._request_dispatcher = TypeBasedDispatcher( + [ + (TokenizedGenerateReqInput, self.handle_generate_request), + (TokenizedEmbeddingReqInput, self.handle_embedding_request), + (FlushCacheReq, self.flush_cache_wrapped), + (AbortReq, self.abort_request), + (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), + (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), + ( + UpdateWeightsFromDistributedReqInput, + self.update_weights_from_distributed, + ), + (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), + (GetWeightsByNameReqInput, self.get_weights_by_name), + (ProfileReq, self.profile), + (OpenSessionReqInput, self.open_session), + (CloseSessionReqInput, self.close_session), + ( + ReleaseMemoryOccupationReqInput, + lambda _: self.release_memory_occupation(), + ), + ( + ResumeMemoryOccupationReqInput, + lambda _: self.resume_memory_occupation(), + ), + ] + ) + def watchdog_thread(self): """A watch dog thread that will try to kill the server itself if one batch takes too long.""" self.watchdog_last_forward_ct = 0 @@ -450,10 +472,6 @@ def event_loop_normal(self): self.process_input_requests(recv_reqs) batch = self.get_next_batch_to_run() - - if self.server_args.enable_dp_attention: # TODO: simplify this - batch = self.prepare_dp_attn_batch(batch) - self.cur_batch = batch if batch: @@ -469,25 +487,21 @@ def event_loop_normal(self): @torch.no_grad() def event_loop_overlap(self): """A scheduler loop that overlaps the CPU processing and GPU computation.""" - result_queue = 
deque() + self.result_queue = deque() while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) batch = self.get_next_batch_to_run() - - if self.server_args.enable_dp_attention: # TODO: simplify this - batch = self.prepare_dp_attn_batch(batch) - self.cur_batch = batch if batch: result = self.run_batch(batch) - result_queue.append((batch.copy(), result)) + self.result_queue.append((batch.copy(), result)) if self.last_batch is None: - # Create a dummy first batch to start the pipeline for overlap scheduler. + # Create a dummy first batch to start the pipeline for overlap schedule. # It is now used for triggering the sampling_info_done event. tmp_batch = ScheduleBatch( reqs=None, @@ -498,7 +512,7 @@ def event_loop_overlap(self): if self.last_batch: # Process the results of the last batch - tmp_batch, tmp_result = result_queue.popleft() + tmp_batch, tmp_result = self.result_queue.popleft() tmp_batch.next_batch_sampling_info = ( self.tp_worker.cur_sampling_info if batch else None ) @@ -563,57 +577,9 @@ def recv_requests(self) -> List[Req]: def process_input_requests(self, recv_reqs: List): for recv_req in recv_reqs: - if isinstance(recv_req, TokenizedGenerateReqInput): - self.handle_generate_request(recv_req) - elif isinstance(recv_req, TokenizedEmbeddingReqInput): - self.handle_embedding_request(recv_req) - elif isinstance(recv_req, FlushCacheReq): - self.flush_cache() - elif isinstance(recv_req, AbortReq): - self.abort_request(recv_req) - elif isinstance(recv_req, UpdateWeightFromDiskReqInput): - success, message = self.update_weights_from_disk(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightFromDiskReqOutput(success, message) - ) - elif isinstance(recv_req, InitWeightsUpdateGroupReqInput): - success, message = self.init_weights_update_group(recv_req) - self.send_to_tokenizer.send_pyobj( - InitWeightsUpdateGroupReqOutput(success, message) - ) - elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput): - success, message = self.update_weights_from_distributed(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightsFromDistributedReqOutput(success, message) - ) - elif isinstance(recv_req, UpdateWeightsFromTensorReqInput): - success, message = self.update_weights_from_tensor(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightsFromTensorReqOutput(success, message) - ) - elif isinstance(recv_req, GetWeightsByNameReqInput): - parameter = self.get_weights_by_name(recv_req) - self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter)) - elif isinstance(recv_req, ReleaseMemoryOccupationReqInput): - self.release_memory_occupation() - self.send_to_tokenizer.send_pyobj(ReleaseMemoryOccupationReqOutput()) - elif isinstance(recv_req, ResumeMemoryOccupationReqInput): - self.resume_memory_occupation() - self.send_to_tokenizer.send_pyobj(ResumeMemoryOccupationReqOutput()) - elif isinstance(recv_req, ProfileReq): - if recv_req == ProfileReq.START_PROFILE: - self.start_profile() - else: - self.stop_profile() - elif isinstance(recv_req, OpenSessionReqInput): - session_id, success = self.open_session(recv_req) - self.send_to_tokenizer.send_pyobj( - OpenSessionReqOutput(session_id=session_id, success=success) - ) - elif isinstance(recv_req, CloseSessionReqInput): - self.close_session(recv_req) - else: - raise ValueError(f"Invalid request: {recv_req}") + output = self._request_dispatcher(recv_req) + if output is not None: + self.send_to_tokenizer.send_pyobj(output) def handle_generate_request( self, @@ -632,6 +598,19 @@ def 
handle_generate_request( fake_input_ids = [1] * seq_length recv_req.input_ids = fake_input_ids + # Handle custom logit processor passed to the request + custom_logit_processor = recv_req.custom_logit_processor + if ( + not self.server_args.enable_custom_logit_processor + and custom_logit_processor is not None + ): + logger.warning( + "The SGLang server is not configured to enable custom logit processor." + "The custom logit processor passed in will be ignored." + "Please set --enable-custom-logits-processor to enable this feature." + ) + custom_logit_processor = None + req = Req( recv_req.rid, recv_req.input_text, @@ -642,6 +621,7 @@ def handle_generate_request( stream=recv_req.stream, lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, + custom_logit_processor=custom_logit_processor, eos_token_ids=self.model_config.hf_eos_token_id, ) req.tokenizer = self.tokenizer @@ -663,7 +643,7 @@ def handle_generate_request( self.waiting_queue.append(req) return - # Handle image inputs + # Handle multimodal inputs if recv_req.image_inputs is not None: image_inputs = ImageInputs.from_dict(recv_req.image_inputs) # Expand a single image token into multiple dummy tokens for receiving image embeddings @@ -687,24 +667,23 @@ def handle_generate_request( self.waiting_queue.append(req) return - # Copy more attributes - req.logprob_start_len = recv_req.logprob_start_len - - if req.logprob_start_len == -1: - # By default, only return the logprobs for output tokens - req.logprob_start_len = len(req.origin_input_ids) - 1 - # Validate prompts length error_msg = validate_input_length( req, self.max_req_input_len, self.server_args.allow_auto_truncate, ) - if error_msg: self.waiting_queue.append(req) return + # Copy more attributes + if recv_req.logprob_start_len == -1: + # By default, only return the logprobs for output tokens + req.logprob_start_len = len(req.origin_input_ids) - 1 + else: + req.logprob_start_len = recv_req.logprob_start_len + req.sampling_params.max_new_tokens = min( ( req.sampling_params.max_new_tokens @@ -752,15 +731,26 @@ def handle_embedding_request( req.tokenizer = self.tokenizer # Validate prompts length - validate_input_length( + error_msg = validate_input_length( req, self.max_req_input_len, self.server_args.allow_auto_truncate, ) + if error_msg: + self.waiting_queue.append(req) + return + # Copy more attributes + req.logprob_start_len = len(req.origin_input_ids) - 1 self.waiting_queue.append(req) - def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked): + def log_prefill_stats( + self, + adder: PrefillAdder, + can_run_list: List[Req], + running_bs: ScheduleBatch, + has_being_chunked: bool, + ): self.tree_cache_metrics["total"] += ( adder.log_input_tokens + adder.log_hit_tokens ) / 10**9 @@ -802,31 +792,56 @@ def log_decode_stats(self): self.num_generated_tokens = 0 self.last_decode_stats_tic = time.time() num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 - logger.info( - f"Decode batch. " - f"#running-req: {num_running_reqs}, " - f"#token: {num_used}, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"gen throughput (token/s): {gen_throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}" - ) + if self.spec_algorithm.is_none(): + msg = ( + f"Decode batch. 
" + f"#running-req: {num_running_reqs}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"gen throughput (token/s): {gen_throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}" + ) + spec_accept_length = 0 + else: + spec_accept_length = ( + self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct + ) + self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 + msg = ( + f"Decode batch. " + f"#running-req: {num_running_reqs}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"accept len: {spec_accept_length:.2f}, " + f"gen throughput (token/s): {gen_throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}" + ) + + logger.info(msg) if self.enable_metrics: self.stats.num_running_reqs = num_running_reqs self.stats.num_used_tokens = num_used self.stats.token_usage = num_used / self.max_total_num_tokens self.stats.gen_throughput = gen_throughput self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.spec_accept_length = spec_accept_length self.metrics_collector.log_stats(self.stats) def check_memory(self): available_size = ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) - if available_size != self.max_total_num_tokens: + protected_size = self.tree_cache.protected_size() + memory_leak = available_size != ( + self.max_total_num_tokens + if not self.enable_hierarchical_cache + else self.max_total_num_tokens - protected_size + ) + if memory_leak: msg = ( "KV cache pool leak detected!" - f"{available_size=}, {self.max_total_num_tokens=}\n" + f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n" ) warnings.warn(msg) if crash_on_warnings(): @@ -859,16 +874,23 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: else: self.running_batch.merge_batch(self.last_batch) - # Run prefill first if possible new_batch = self.get_new_batch_prefill() if new_batch is not None: - return new_batch + # Run prefill first if possible + ret = new_batch + else: + # Run decode + if self.running_batch is None: + ret = None + else: + self.running_batch = self.update_running_batch(self.running_batch) + ret = self.running_batch - # Run decode - if self.running_batch is None: - return None - self.running_batch = self.update_running_batch(self.running_batch) - return self.running_batch + # Handle DP attention + if self.server_args.enable_dp_attention: + ret = self.prepare_dp_attn_batch(ret) + + return ret def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Check if the grammar is ready in the grammar queue @@ -934,7 +956,14 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: res = adder.add_one_req(req) if res != AddReqResult.CONTINUE: if res == AddReqResult.NO_TOKEN: - self.batch_is_full = True + if self.enable_hierarchical_cache: + # Set batch_is_full after making sure there are requests that can be served + self.batch_is_full = len(adder.can_run_list) > 0 or ( + self.running_batch is not None + and not self.running_batch.is_empty() + ) + else: + self.batch_is_full = True break if self.server_args.prefill_only_one_req: break @@ -967,6 +996,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.model_config, self.enable_overlap, self.spec_algorithm, + self.server_args.enable_custom_logit_processor, ) new_batch.prepare_for_extend() @@ -1023,7 +1053,7 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: ) # Check for jump-forward - if not self.disable_jump_forward: + if not 
self.disable_jump_forward and batch.has_grammar: jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) self.waiting_queue.extend(jump_forward_reqs) if batch.is_empty(): @@ -1044,37 +1074,23 @@ def run_batch( self.forward_ct += 1 if self.is_generation: - if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0: - if self.spec_algorithm.is_none(): - model_worker_batch = batch.get_model_worker_batch() - logits_output, next_token_ids = ( - self.tp_worker.forward_batch_generation(model_worker_batch) - # model_worker_batch = batch.get_model_worker_batch() - # if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: - # # FIXME(geon): handle hip refresh_interval here - # logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - # model_worker_batch - # ) - # elif batch.forward_mode.is_idle(): - # model_worker_batch = batch.get_model_worker_batch() - # self.tp_worker.forward_batch_idle(model_worker_batch) - # return - # else: - # logits_output = None - # if self.skip_tokenizer_init: - # next_token_ids = torch.full( - # (batch.batch_size(),), self.tokenizer.eos_token_id - ) - else: - ( - logits_output, - next_token_ids, - model_worker_batch, - num_accepted_tokens, - ) = self.draft_worker.forward_batch_speculative_generation(batch) - self.num_generated_tokens += num_accepted_tokens + if self.spec_algorithm.is_none(): + model_worker_batch = batch.get_model_worker_batch() + logits_output, next_token_ids = self.tp_worker.forward_batch_generation( + model_worker_batch + ) else: - assert False, "batch.extend_num_tokens == 0, this is unexpected!" + ( + logits_output, + next_token_ids, + model_worker_batch, + num_accepted_tokens, + ) = self.draft_worker.forward_batch_speculative_generation(batch) + self.spec_num_total_accepted_tokens += ( + num_accepted_tokens + batch.batch_size() + ) + self.spec_num_total_forward_ct += batch.batch_size() + self.num_generated_tokens += num_accepted_tokens batch.output_ids = next_token_ids ret = GenerationBatchResult( @@ -1083,7 +1099,6 @@ def run_batch( bid=model_worker_batch.bid, ) else: # embedding or reward model - assert batch.extend_num_tokens != 0 model_worker_batch = batch.get_model_worker_batch() embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) ret = EmbeddingBatchResult( @@ -1382,6 +1397,7 @@ def stream_output( prompt_tokens = [] completion_tokens = [] cached_tokens = [] + spec_verify_ct = [] if return_logprob: input_token_logprobs_val = [] @@ -1435,6 +1451,9 @@ def stream_output( completion_tokens.append(len(req.output_ids)) cached_tokens.append(req.cached_tokens) + if not self.spec_algorithm.is_none(): + spec_verify_ct.append(req.spec_verify_ct) + if return_logprob: input_token_logprobs_val.append(req.input_token_logprobs_val) input_token_logprobs_idx.append(req.input_token_logprobs_idx) @@ -1462,6 +1481,7 @@ def stream_output( prompt_tokens, completion_tokens, cached_tokens, + spec_verify_ct, input_token_logprobs_val, input_token_logprobs_idx, output_token_logprobs_val, @@ -1532,6 +1552,7 @@ def get_idle_batch(self): self.model_config, self.enable_overlap, self.spec_algorithm, + self.server_args.enable_custom_logit_processor, ) idle_batch.prepare_for_idle() return idle_batch @@ -1560,6 +1581,9 @@ def move_ready_grammar_requests(self): self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs]) self.grammar_queue = self.grammar_queue[num_ready_reqs:] + def flush_cache_wrapped(self, recv_req: FlushCacheReq): + self.flush_cache() + def flush_cache(self): """Flush the memory 
pool and cache.""" if len(self.waiting_queue) == 0 and ( @@ -1571,6 +1595,15 @@ def flush_cache(self): self.grammar_backend.reset() self.req_to_token_pool.clear() self.token_to_kv_pool.clear() + + if not self.spec_algorithm.is_none(): + self.draft_worker.model_runner.req_to_token_pool.clear() + self.draft_worker.model_runner.token_to_kv_pool.clear() + + self.num_generated_tokens = 0 + self.forward_ct_decode = 0 + self.spec_num_total_accepted_tokens = 0 + self.spec_num_total_forward_ct = 0 torch.cuda.empty_cache() logger.info("Cache flushed successfully!") if_success = True @@ -1612,12 +1645,12 @@ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightFromDiskReqOutput(success, message) def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): """Initialize the online model parameter update group.""" success, message = self.tp_worker.init_weights_update_group(recv_req) - return success, message + return InitWeightsUpdateGroupReqOutput(success, message) def update_weights_from_distributed( self, @@ -1630,7 +1663,7 @@ def update_weights_from_distributed( assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightsFromDistributedReqOutput(success, message) def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): """Update the online model parameter from tensors.""" @@ -1641,11 +1674,11 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightsFromTensorReqOutput(success, message) def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) - return parameter + return GetWeightsByNameReqOutput(parameter) def release_memory_occupation(self): self.stashed_model_static_state = _export_static_state( @@ -1653,6 +1686,7 @@ def release_memory_occupation(self): ) self.memory_saver_adapter.pause() self.flush_cache() + return ReleaseMemoryOccupationReqOutput() def resume_memory_occupation(self): self.memory_saver_adapter.resume() @@ -1660,6 +1694,13 @@ def resume_memory_occupation(self): self.tp_worker.worker.model_runner.model, self.stashed_model_static_state ) del self.stashed_model_static_state + return ResumeMemoryOccupationReqOutput() + + def profile(self, recv_req: ProfileReq): + if recv_req == ProfileReq.START_PROFILE: + self.start_profile() + else: + self.stop_profile() def start_profile(self) -> None: if self.profiler is None: @@ -1675,20 +1716,20 @@ def stop_profile(self) -> None: ) logger.info("Profiler is done") - def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]: + def open_session(self, recv_req: OpenSessionReqInput): # handle error session_id = recv_req.session_id if session_id in self.sessions: logger.warning(f"session id {session_id} already exist, cannot open.") - return session_id, False + return OpenSessionReqOutput(session_id, False) elif session_id is None: logger.warning(f"session id is None, cannot open.") - return session_id, False + return OpenSessionReqOutput(session_id, False) else: self.sessions[session_id] = Session( recv_req.capacity_of_str_len, session_id ) - return session_id, True + return 
OpenSessionReqOutput(session_id, True) def close_session(self, recv_req: CloseSessionReqInput): # handle error diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index e9c0c909d52..4f4af636757 100644 --- a/python/sglang/srt/managers/session_controller.py +++ b/python/sglang/srt/managers/session_controller.py @@ -131,6 +131,7 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer): sampling_params=req.sampling_params, lora_path=req.lora_path, session_id=self.session_id, + custom_logit_processor=req.custom_logit_processor, ) if last_req is not None: new_req.image_inputs = last_req.image_inputs diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 8e7f21d9565..da665ea24bc 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -80,7 +80,7 @@ get_zmq_socket, kill_process_tree, ) -from sglang.utils import get_exception_traceback +from sglang.utils import TypeBasedDispatcher, get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -117,6 +117,7 @@ def __init__( self.server_args = server_args self.enable_metrics = server_args.enable_metrics self.log_requests = server_args.log_requests + self.log_requests_level = 0 # Init inter-process communication context = zmq.asyncio.Context(2) @@ -157,6 +158,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.tokenizer = self.processor.tokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -170,10 +172,11 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) # Store states - self.to_create_loop = True + self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} self.dump_requests_folder = "" # By default do not dump self.dump_requests_threshold = 1000 @@ -221,6 +224,44 @@ def __init__( }, ) + self._result_dispatcher = TypeBasedDispatcher( + [ + ( + (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut), + self._handle_batch_output, + ), + (OpenSessionReqOutput, self._handle_open_session_req_output), + ( + UpdateWeightFromDiskReqOutput, + self._handle_update_weights_from_disk_req_output, + ), + ( + InitWeightsUpdateGroupReqOutput, + self.init_weights_update_group_communicator.handle_recv, + ), + ( + UpdateWeightsFromDistributedReqOutput, + self.update_weights_from_distributed_communicator.handle_recv, + ), + ( + UpdateWeightsFromTensorReqOutput, + self.update_weights_from_tensor_communicator.handle_recv, + ), + ( + GetWeightsByNameReqOutput, + self.get_weights_by_name_communicator.handle_recv, + ), + ( + ReleaseMemoryOccupationReqOutput, + self.release_memory_occupation_communicator.handle_recv, + ), + ( + ResumeMemoryOccupationReqOutput, + self.resume_memory_occupation_communicator.handle_recv, + ), + ] + ) + async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -239,7 +280,10 @@ async def generate_request( obj.normalize_batch_and_arguments() if self.log_requests: - logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}") + max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + logger.info( + f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" + ) async with self.model_update_lock.reader_lock: is_single = obj.is_single @@ 
-348,6 +392,7 @@ async def _tokenize_one_request( lora_path=obj.lora_path, input_embeds=input_embeds, session_params=session_params, + custom_logit_processor=obj.custom_logit_processor, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( @@ -392,7 +437,8 @@ async def _wait_one_response( state.out_list = [] if state.finished: if self.log_requests: - msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}" + max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}" logger.info(msg) del self.rid_to_state[obj.rid] @@ -649,12 +695,13 @@ async def open_session( async def close_session( self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None ): - assert not self.to_create_loop, "close session should not be the first request" await self.send_to_scheduler.send_pyobj(obj) def configure_logging(self, obj: ConfigureLoggingReq): if obj.log_requests is not None: self.log_requests = obj.log_requests + if obj.log_requests_level is not None: + self.log_requests_level = obj.log_requests_level if obj.dump_requests_folder is not None: self.dump_requests_folder = obj.dump_requests_folder if obj.dump_requests_threshold is not None: @@ -676,10 +723,10 @@ async def abort_request(): return background_tasks def auto_create_handle_loop(self): - if not self.to_create_loop: + if self.no_create_loop: return - self.to_create_loop = False + self.no_create_loop = True loop = asyncio.get_event_loop() self.asyncio_tasks.add( loop.create_task(print_exception_wrapper(self.handle_loop)) @@ -722,110 +769,68 @@ async def handle_loop(self): """The event loop that handles requests""" while True: - recv_obj: Union[ - BatchStrOut, - BatchEmbeddingOut, - BatchTokenIDOut, - UpdateWeightFromDiskReqOutput, - UpdateWeightsFromDistributedReqOutput, - GetWeightsByNameReqOutput, - InitWeightsUpdateGroupReqOutput, - ReleaseMemoryOccupationReqOutput, - ResumeMemoryOccupationReqOutput, - ] = await self.recv_from_detokenizer.recv_pyobj() - - if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)): - for i, rid in enumerate(recv_obj.rids): - state = self.rid_to_state.get(rid, None) - if state is None: - continue - - meta_info = { - "id": rid, - "finish_reason": recv_obj.finished_reasons[i], - "prompt_tokens": recv_obj.prompt_tokens[i], - } + recv_obj = await self.recv_from_detokenizer.recv_pyobj() + self._result_dispatcher(recv_obj) - if getattr(state.obj, "return_logprob", False): - self.convert_logprob_style( - meta_info, - state.obj.top_logprobs_num, - state.obj.return_text_in_logprobs, - recv_obj, - i, - ) - - if not isinstance(recv_obj, BatchEmbeddingOut): - meta_info.update( - { - "completion_tokens": recv_obj.completion_tokens[i], - "cached_tokens": recv_obj.cached_tokens[i], - } - ) - - if isinstance(recv_obj, BatchStrOut): - out_dict = { - "text": recv_obj.output_strs[i], - "meta_info": meta_info, - } - elif isinstance(recv_obj, BatchTokenIDOut): - out_dict = { - "token_ids": recv_obj.output_ids[i], - "meta_info": meta_info, - } - else: - assert isinstance(recv_obj, BatchEmbeddingOut) - out_dict = { - "embedding": recv_obj.embeddings[i], - "meta_info": meta_info, - } - state.out_list.append(out_dict) - state.finished = recv_obj.finished_reasons[i] is not None - state.event.set() - - if self.enable_metrics and state.obj.log_metrics: - self.collect_metrics(state, recv_obj, i) - if ( - 
self.dump_requests_folder - and state.finished - and state.obj.log_metrics - ): - self.dump_requests(state, out_dict) - elif isinstance(recv_obj, OpenSessionReqOutput): - self.session_futures[recv_obj.session_id].set_result( - recv_obj.session_id if recv_obj.success else None + def _handle_batch_output( + self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] + ): + for i, rid in enumerate(recv_obj.rids): + state = self.rid_to_state.get(rid, None) + if state is None: + continue + + meta_info = { + "id": rid, + "finish_reason": recv_obj.finished_reasons[i], + "prompt_tokens": recv_obj.prompt_tokens[i], + } + + if getattr(state.obj, "return_logprob", False): + self.convert_logprob_style( + meta_info, + state.obj.top_logprobs_num, + state.obj.return_text_in_logprobs, + recv_obj, + i, ) - elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput): - if self.server_args.dp_size == 1: - self.model_update_result.set_result(recv_obj) - else: # self.server_args.dp_size > 1 - self.model_update_tmp.append(recv_obj) - # set future if the all results are recevied - if len(self.model_update_tmp) == self.server_args.dp_size: - self.model_update_result.set_result(self.model_update_tmp) - elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" - self.init_weights_update_group_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for update weights from distributed" - self.update_weights_from_distributed_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for update weights from distributed" - self.update_weights_from_tensor_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, GetWeightsByNameReqOutput): - self.get_weights_by_name_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput): - self.release_memory_occupation_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput): - self.resume_memory_occupation_communicator.handle_recv(recv_obj) + + if self.server_args.speculative_algorithm: + meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] + + if not isinstance(recv_obj, BatchEmbeddingOut): + meta_info.update( + { + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + } + ) + + if isinstance(recv_obj, BatchStrOut): + out_dict = { + "text": recv_obj.output_strs[i], + "meta_info": meta_info, + } + elif isinstance(recv_obj, BatchTokenIDOut): + out_dict = { + "token_ids": recv_obj.output_ids[i], + "meta_info": meta_info, + } else: - raise ValueError(f"Invalid object: {recv_obj=}") + assert isinstance(recv_obj, BatchEmbeddingOut) + out_dict = { + "embedding": recv_obj.embeddings[i], + "meta_info": meta_info, + } + + state.out_list.append(out_dict) + state.finished = recv_obj.finished_reasons[i] is not None + state.event.set() + + if self.enable_metrics and state.obj.log_metrics: + self.collect_metrics(state, recv_obj, i) + if self.dump_requests_folder and state.finished and state.obj.log_metrics: + self.dump_requests(state, out_dict) def convert_logprob_style( self, @@ -953,6 +958,20 @@ def background_task(): # Schedule the task to run in the background without awaiting it asyncio.create_task(asyncio.to_thread(background_task)) + 
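Both the scheduler's `process_input_requests` and this `handle_loop` now route messages through `TypeBasedDispatcher` from `sglang.utils` instead of long isinstance chains. The class itself is not shown in this diff; the sketch below only mirrors how it is used here (built from (type-or-tuple, handler) pairs and called like a function) and may differ from the real implementation:

from typing import Any, Callable, List, Tuple

class TypeBasedDispatcherSketch:
    def __init__(self, mapping: List[Tuple[Any, Callable]]):
        # mapping entries pair a type (or a tuple of types) with a handler.
        self._mapping = mapping

    def __call__(self, obj: Any):
        for types, handler in self._mapping:
            if isinstance(obj, types):
                return handler(obj)
        raise ValueError(f"Invalid object: {obj!r}")

# Usage mirroring the diff: handlers return a response object (or None), and
# the caller forwards any non-None result over ZMQ.
class Ping: ...

dispatcher = TypeBasedDispatcherSketch([(Ping, lambda req: "pong")])
print(dispatcher(Ping()))  # -> pong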
def _handle_open_session_req_output(self, recv_obj): + self.session_futures[recv_obj.session_id].set_result( + recv_obj.session_id if recv_obj.success else None + ) + + def _handle_update_weights_from_disk_req_output(self, recv_obj): + if self.server_args.dp_size == 1: + self.model_update_result.set_result(recv_obj) + else: # self.server_args.dp_size > 1 + self.model_update_tmp.append(recv_obj) + # set future if the all results are recevied + if len(self.model_update_tmp) == self.server_args.dp_size: + self.model_update_result.set_result(self.model_update_tmp) + async def print_exception_wrapper(func): """ diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index ff63a934fad..891cacd34b7 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -87,6 +87,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.tokenizer = self.processor.tokenizer else: @@ -94,6 +95,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.device = self.model_runner.device diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index acdd2898ffa..9386595a8bd 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -41,6 +41,10 @@ def dec_lock_ref(self, node): def evictable_size(self): pass + @abstractmethod + def protected_size(self): + raise NotImplementedError() + def total_size(self): raise NotImplementedError() diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index ab8965a0189..b50199ca28a 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -85,3 +85,6 @@ def dec_lock_ref(self, node): def evictable_size(self): return 0 + + def protected_size(self): + return 0 diff --git a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py index 2950f4fd84e..7adbf0c9ee3 100644 --- a/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py +++ b/python/sglang/srt/mem_cache/hip_offload_kv_pool_mha.py @@ -115,6 +115,7 @@ def __init__( head_dim=head_dim, layer_num=layer_num, device=self.device, + enable_memory_saver=False, ) else: self.validation_cache = None diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index f38258caa92..1e0fb755055 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -27,7 +27,7 @@ import threading from enum import IntEnum from functools import wraps -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import psutil @@ -50,7 +50,6 @@ def __init__( size: int, max_context_len: int, device: str, - use_records: bool, enable_memory_saver: bool, ): memory_saver_adapter = TorchMemorySaverAdapter.create( @@ -65,17 +64,9 @@ def __init__( (size, max_context_len), dtype=torch.int32, device=device ) self.free_slots = list(range(size)) - self.write_records = [] - self.use_records = use_records - - if self.use_records: - self.write = self.write_with_records - else: - self.write = self.write_without_records def write(self, indices, values): - # Keep the 
signature for type checking. It will be assigned during runtime. - raise NotImplementedError() + self.req_to_token[indices] = values def available_size(self): return len(self.free_slots) @@ -97,23 +88,6 @@ def free(self, free_index: Union[int, List[int]]): def clear(self): self.free_slots = list(range(self.size)) - self.write_records = [] - - def write_without_records(self, indices, values): - self.req_to_token[indices] = values - - def write_with_records(self, indices, values): - self.req_to_token[indices] = values - self.write_records.append((indices, values)) - - def get_write_records(self): - ret = self.write_records - self.write_records = [] - return ret - - def apply_write_records(self, write_records: List[Tuple]): - for indices, values in write_records: - self.req_to_token[indices] = values class BaseTokenToKVPool: @@ -309,13 +283,17 @@ def set_kv_buffer( loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, ): layer_id = layer.layer_id if cache_k.dtype != self.dtype: - cache_k = (cache_k / k_scale).to(self.dtype) - cache_v = (cache_v / v_scale).to(self.dtype) + if k_scale is not None: + cache_k.div_(k_scale) + if v_scale is not None: + cache_v.div_(v_scale) + cache_k = cache_k.to(self.dtype) + cache_v = cache_v.to(self.dtype) if self.store_dtype != self.dtype: self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype) self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 1673d4f0c3d..3bf87b54299 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -34,7 +34,10 @@ class TreeNode: - def __init__(self): + + counter = 0 + + def __init__(self, id: Optional[int] = None): self.children = defaultdict(TreeNode) self.parent = None self.key = None @@ -42,6 +45,23 @@ def __init__(self): self.lock_ref = 0 self.last_access_time = time.time() + self.hit_count = 0 + # indicating the node is loading KV cache from host + self.loading = False + # store the host indices of KV cache + self.host_value = None + + self.id = TreeNode.counter if id is None else id + TreeNode.counter += 1 + + @property + def evicted(self): + return self.value is None + + @property + def backuped(self): + return self.host_value is not None + def __lt__(self, other: "TreeNode"): return self.last_access_time < other.last_access_time @@ -75,6 +95,7 @@ def reset(self): self.root_node.value = [] self.root_node.lock_ref = 1 self.evictable_size_ = 0 + self.protected_size_ = 0 def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]: """Find the matching prefix from the radix tree. 
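`set_kv_buffer` above now treats `k_scale`/`v_scale` as optional: when the incoming K/V dtype differs from the pool dtype, each scale is applied in place only if provided, and the tensors are then cast. A standalone sketch of just that conversion step; bfloat16 stands in for the usual fp8 pool dtype so the example runs anywhere:

from typing import Optional
import torch

def quantize_kv_for_pool(
    cache_k: torch.Tensor,
    cache_v: torch.Tensor,
    pool_dtype: torch.dtype,
    k_scale: Optional[float] = None,
    v_scale: Optional[float] = None,
):
    if cache_k.dtype != pool_dtype:
        # Like the diff, the division mutates the incoming tensors in place.
        if k_scale is not None:
            cache_k.div_(k_scale)
        if v_scale is not None:
            cache_v.div_(v_scale)
        cache_k = cache_k.to(pool_dtype)
        cache_v = cache_v.to(pool_dtype)
    return cache_k, cache_v

k = torch.randn(4, 8, dtype=torch.float16)
v = torch.randn(4, 8, dtype=torch.float16)
qk, qv = quantize_kv_for_pool(k, v, torch.bfloat16, k_scale=2.0)
print(qk.dtype, qv.dtype)  # torch.bfloat16 torch.bfloat16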
@@ -203,6 +224,7 @@ def inc_lock_ref(self, node: TreeNode): while node != self.root_node: if node.lock_ref == 0: self.evictable_size_ -= len(node.value) + self.protected_size_ += len(node.value) delta -= len(node.value) node.lock_ref += 1 node = node.parent @@ -216,6 +238,7 @@ def dec_lock_ref(self, node: TreeNode): while node != self.root_node: if node.lock_ref == 1: self.evictable_size_ += len(node.value) + self.protected_size_ -= len(node.value) delta += len(node.value) node.lock_ref -= 1 node = node.parent @@ -224,6 +247,10 @@ def dec_lock_ref(self, node: TreeNode): def evictable_size(self): return self.evictable_size_ + def protected_size(self): + # protected size refers to the size of the cache that is locked + return self.protected_size_ + ##### Internal Helper Functions ##### def _match_prefix_helper( @@ -303,6 +330,8 @@ def _delete_leaf(self, node): self.evictable_size_ -= len(node.key) def _total_size_helper(self, node: TreeNode): + if node.evicted: + return 0 x = len(node.value) for child in node.children.values(): x += self._total_size_helper(child) diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index 070b405be42..26eb2fc27d2 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -25,6 +25,7 @@ class SchedulerStats: gen_throughput: float = 0.0 num_queue_reqs: int = 0 cache_hit_rate: float = 0.0 + spec_accept_length: float = 0.0 class SchedulerMetricsCollector: @@ -37,42 +38,49 @@ def __init__(self, labels: Dict[str, str]) -> None: self.num_running_reqs = Gauge( name="sglang:num_running_reqs", - documentation="The number of running requests", + documentation="The number of running requests.", labelnames=labels.keys(), multiprocess_mode="sum", ) self.num_used_tokens = Gauge( name="sglang:num_used_tokens", - documentation="The number of used tokens", + documentation="The number of used tokens.", labelnames=labels.keys(), multiprocess_mode="sum", ) self.token_usage = Gauge( name="sglang:token_usage", - documentation="The token usage", + documentation="The token usage.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) self.gen_throughput = Gauge( name="sglang:gen_throughput", - documentation="The generate throughput (token/s)", + documentation="The generation throughput (token/s).", labelnames=labels.keys(), multiprocess_mode="sum", ) self.num_queue_reqs = Gauge( name="sglang:num_queue_reqs", - documentation="The number of requests in the waiting queue", + documentation="The number of requests in the waiting queue.", labelnames=labels.keys(), multiprocess_mode="sum", ) self.cache_hit_rate = Gauge( name="sglang:cache_hit_rate", - documentation="The cache hit rate", + documentation="The prefix cache hit rate.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + self.spec_accept_length = Gauge( + name="sglang:spec_accept_length", + documentation="The average acceptance length of speculative decoding.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) @@ -88,6 +96,7 @@ def log_stats(self, stats: SchedulerStats) -> None: self._log_gauge(self.gen_throughput, stats.gen_throughput) self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) + self._log_gauge(self.spec_accept_length, stats.spec_accept_length) class TokenizerMetricsCollector: diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c90719cc168..23b45f9ee19 100644 --- 
a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -25,7 +25,7 @@ from vllm.model_executor.custom_op import CustomOp from sglang.srt.distributed import get_tensor_model_parallel_rank -from sglang.srt.distributed.parallel_state import graph_capture +from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native from sglang.srt.layers.torchao_utils import save_gemlite_cache @@ -34,13 +34,12 @@ ForwardBatch, ForwardMode, ) -from sglang.srt.utils import monkey_patch_vllm_all_gather if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner -def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int): +def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): for sub in model._modules.values(): if isinstance(sub, CustomOp): if reverse: @@ -49,7 +48,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int): else: # NOTE: Temporarily workaround MoE if "FusedMoE" in sub.__class__.__name__: - if batch_size == 1: + if num_tokens == 1: # The performance of torch.compile on this layer is not always good when bs > 1, # so we decide to only use torch.compile when bs =1 sub._forward_method = fused_moe_forward_native @@ -57,23 +56,22 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int): sub._forward_method = sub.forward_native setattr(sub, "is_torch_compile", True) if isinstance(sub, torch.nn.Module): - _to_torch(sub, reverse, batch_size) + _to_torch(sub, reverse, num_tokens) @contextmanager def patch_model( model: torch.nn.Module, enable_compile: bool, - batch_size: int, - tp_group: "GroupCoordinator", + num_tokens: int, + tp_group: GroupCoordinator, ): """Patch the model to make it compatible with with torch.compile""" backup_ca_comm = None try: if enable_compile: - _to_torch(model, reverse=False, batch_size=batch_size) - monkey_patch_vllm_all_gather() + _to_torch(model, reverse=False, num_tokens=num_tokens) backup_ca_comm = tp_group.ca_comm # Use custom-allreduce here. 
# We found the custom allreduce is much faster than the built-in allreduce in torch, @@ -88,8 +86,7 @@ def patch_model( yield model.forward finally: if enable_compile: - _to_torch(model, reverse=True, batch_size=batch_size) - monkey_patch_vllm_all_gather(reverse=True) + _to_torch(model, reverse=True, num_tokens=num_tokens) tp_group.ca_comm = backup_ca_comm @@ -153,9 +150,18 @@ def __init__(self, model_runner: "ModelRunner"): and bs <= model_runner.server_args.cuda_graph_max_bs ] + self.compile_bs = ( + [ + bs + for bs in self.capture_bs + if bs <= self.model_runner.server_args.torch_compile_max_bs + ] + if self.use_torch_compile + else [] + ) + self.capture_forward_mode = ForwardMode.DECODE self.num_tokens_per_bs = 1 - if model_runner.spec_algorithm.is_eagle(): if self.model_runner.is_draft_worker: self.num_tokens_per_bs = ( @@ -167,16 +173,6 @@ def __init__(self, model_runner: "ModelRunner"): self.model_runner.server_args.speculative_num_draft_tokens ) - self.compile_bs = ( - [ - bs - for bs in self.capture_bs - if bs <= self.model_runner.server_args.torch_compile_max_bs - ] - if self.use_torch_compile - else [] - ) - # Attention backend self.max_bs = max(self.capture_bs) self.max_num_token = self.max_bs * self.num_tokens_per_bs @@ -184,7 +180,6 @@ def __init__(self, model_runner: "ModelRunner"): self.seq_len_fill_value = ( self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) - # FIXME(lsyin): leave it here for now, I don't know whether it is necessary self.encoder_len_fill_value = 0 @@ -193,14 +188,14 @@ def __init__(self, model_runner: "ModelRunner"): # Common inputs with torch.device("cuda"): - self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32) + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) self.seq_lens = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 ) - self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32) + self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64) self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) - self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32) + self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) # Speculative_inference if model_runner.spec_algorithm.is_eagle(): @@ -289,8 +284,8 @@ def capture(self): with patch_model( self.model_runner.model, bs in self.compile_bs, - bs, - self.model_runner.tp_group, + num_tokens=bs * self.num_tokens_per_bs, + tp_group=self.model_runner.tp_group, ) as forward: ( graph, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 98f74a6def9..aff37773bb5 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -38,7 +38,7 @@ import triton.language as tl from sglang.srt.layers.rotary_embedding import MRotaryEmbedding -from sglang.srt.utils import maybe_torch_compile +from sglang.srt.utils import get_compiler_backend if TYPE_CHECKING: from sglang.srt.layers.attention import AttentionBackend @@ -288,6 +288,9 @@ def init_new( can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, lora_paths=batch.lora_paths, sampling_info=batch.sampling_info, + req_to_token_pool=model_runner.req_to_token_pool, + token_to_kv_pool=model_runner.token_to_kv_pool, + attn_backend=model_runner.attn_backend, spec_algorithm=batch.spec_algorithm, 
spec_info=batch.spec_info, capture_hidden_mode=batch.capture_hidden_mode, @@ -342,11 +345,6 @@ def init_new( if model_runner.model_is_mrope: ret.compute_mrope_positions(model_runner, batch) - # Init attention information - ret.req_to_token_pool = model_runner.req_to_token_pool - ret.token_to_kv_pool = model_runner.token_to_kv_pool - ret.attn_backend = model_runner.attn_backend - # Init HiP attention information if hasattr(model_runner, "hip_metadata_cache_pool"): ret.hip_metadata_cache_pool = model_runner.hip_metadata_cache_pool @@ -441,6 +439,6 @@ def compute_position_torch( return positions.to(torch.int64), extend_start_loc -@maybe_torch_compile(dynamic=True) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def clamp_position(seq_lens): return torch.clamp((seq_lens - 1), min=0).to(torch.int64) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9c5c0536c9f..dacf51eab55 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -66,8 +66,8 @@ init_custom_process_group, is_cuda, is_hip, + monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, - monkey_patch_vllm_p2p_access_check, set_cpu_offload_max_bytes, ) @@ -188,9 +188,12 @@ def __init__( self.load_model() # Apply torchao quantization - apply_torchao_config_to_model( - self.model, global_server_args_dict["torchao_config"] - ) + torchao_applied = getattr(self.model, "torchao_applied", False) + # In layered loading, torchao may have been applied + if not torchao_applied: + apply_torchao_config_to_model( + self.model, global_server_args_dict["torchao_config"] + ) # Apply torch TP if the model supports it supports_torch_tp = getattr(self.model, "supports_torch_tp", False) @@ -218,7 +221,7 @@ def __init__( def init_torch_distributed(self): logger.info("Init torch distributed begin.") - # Init torch distributed + torch.get_device_module(self.device).set_device(self.gpu_id) if self.device == "cuda": backend = "nccl" @@ -232,7 +235,8 @@ def init_torch_distributed(self): backend = "gloo" if not self.server_args.enable_p2p_check: - monkey_patch_vllm_p2p_access_check(self.gpu_id) + monkey_patch_p2p_access_check() + if self.server_args.dist_init_addr: dist_init_method = f"tcp://{self.server_args.dist_init_addr}" else: @@ -634,7 +638,6 @@ def init_memory_pool( size=max_num_reqs + 1, max_context_len=self.model_config.context_len + 4, device=self.device, - use_records=False, enable_memory_saver=self.server_args.enable_memory_saver, ) if ( diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 677d716d43b..9e6b09488e6 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -374,6 +374,78 @@ def load_model( return model.eval() +class LayeredModelLoader(DefaultModelLoader): + """Model loader that loads weights layer by layer so that one can quantize a + layer before loading another to make the peak memory envelope smaller.""" + + def __init__(self, load_config: LoadConfig): + # Back to the default load format + load_config.load_format = LoadFormat.AUTO + super().__init__(load_config) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model + from sglang.srt.managers.schedule_batch import global_server_args_dict + + torchao_config = global_server_args_dict.get("torchao_config") + 
target_device = torch.device(device_config.device) + + with set_default_torch_dtype(model_config.dtype): + # Create model on meta device + with torch.device("meta"): + model = _initialize_model( + model_config, + self.load_config, + ) + + # Check model's layered load support + if not hasattr(model, "load_weights_to_module"): + raise ValueError( + "LayeredModelLoader requires the model to have a " + "`load_weights_to_module` method. " + f"{model_config.model_path} does not support it." + ) + + # Get all weights from disk + weights = self._get_all_weights(model_config, model) + + # Helper function to recursively fill the weights of a module + def fill_module(module, fqn: List[str], weights): + """ + fqn: list of strings representing the fully qualified name of `module`. + """ + # Layer by layer + for name, submod in module.named_children(): + fill_module(submod, fqn + [name], weights) + + # First materialize on target device + module.to_empty(device=target_device, recurse=False) + fqn_path = ".".join(fqn) + # Fill weights + model.load_weights_to_module( + fqn_path, + weights, + ) + # Quantize weights if applicable + if torchao_config and "proj" in fqn_path: + # Note: `None` here is needed to indicate no filter, see + # `apply_torchao_config_to_model` for details. + apply_torchao_config_to_model(module, torchao_config, None) + + # Start calling on root module + fill_module(model, [], weights) + + if torchao_config: + model.torchao_applied = True + + return model.eval() + + class DummyModelLoader(BaseModelLoader): """Model loader that will set model weights to random values.""" @@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.GGUF: return GGUFModelLoader(load_config) + if load_config.load_format == LoadFormat.LAYERED: + return LayeredModelLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 77c3fcbee74..c07a346f471 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -27,6 +27,7 @@ import numpy as np import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm @@ -403,8 +404,13 @@ def np_cache_weights_iterator( def safetensors_weights_iterator( hf_weights_files: List[str], + is_all_weights_sharded: bool = False, ) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Iterate over the weights in the model safetensor files.""" + """Iterate over the weights in the model safetensor files. + + If is_all_weights_sharded is True, it uses a more optimized read by reading an + entire file at once instead of reading each tensor one by one.
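LayeredModelLoader's fill_module recursion is the core idea of this hunk: build the model on the meta device, then materialize and fill one submodule at a time so a quantization pass can shrink each layer before the next one is loaded. Below is a minimal, self-contained sketch of that pattern in plain PyTorch; the toy model is illustrative and the weight filling uses a plain state-dict copy instead of the loader's load_weights_to_module call.

import torch
import torch.nn as nn

def fill_module(module: nn.Module, fqn, weights, device):
    # Children first, so each leaf is materialized and filled before its parent.
    for name, child in module.named_children():
        fill_module(child, fqn + [name], weights, device)
    # Materialize only this module's own (direct) parameters on the target device.
    module.to_empty(device=device, recurse=False)
    prefix = ".".join(fqn)
    for pname, param in module.named_parameters(recurse=False):
        key = f"{prefix}.{pname}" if prefix else pname
        param.data.copy_(weights[key])
    # A per-layer quantization step could run here, before the next layer is
    # materialized, which is what keeps the peak memory envelope small.

# Toy weights on CPU and a model skeleton on the meta device.
reference = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
weights = {k: v.detach().clone() for k, v in reference.state_dict().items()}
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
fill_module(model, [], weights, torch.device("cpu"))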
+ """ enable_tqdm = ( not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) @@ -414,9 +420,14 @@ def safetensors_weights_iterator( disable=not enable_tqdm, bar_format=_BAR_FORMAT, ): - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) + if not is_all_weights_sharded: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + else: + result = load_file(st_file, device="cpu") + for name, param in result.items(): yield name, param @@ -650,6 +661,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: return name +# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py +class KVCacheQuantSchema(BaseModel): + dtype: str + # Each key is a TP rank. Each value is a dictionary mapping a TP rank's + # layer indices to their per-tensor KV cache scaling factor. + # TODO: Consider pulling this and its validation methods out into its + # own schema class (tricky as its members are variable) + scaling_factor: Dict[int, Dict[int, float]] + + @model_validator(mode="after") + def check_is_fp8(self) -> "KVCacheQuantSchema": + assert self.dtype == "float8_e4m3fn", ( + "Loaded scaling factors intended for KV cache dtype = " + f"{self.dtype} rather than float8_e4m3fn!" + ) + return self + + @model_validator(mode="after") + def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_size = context["tp_size"] + num_hidden_layers = context["num_hidden_layers"] + assert len(self.scaling_factor) == tp_size, ( + f"Loaded dictionary has TP size {len(self.scaling_factor)} " + f"but LLM engine is currently running with TP size {tp_size}." + ) + for tp_rank, layer_maps in self.scaling_factor.items(): + assert len(layer_maps) == num_hidden_layers, ( + f"KV cache scales map for TP rank {tp_rank} is malformed. " + f"Expected {num_hidden_layers} layers, got " + f"{len(layer_maps)}." + ) + for i in range(tp_size): + assert ( + i in self.scaling_factor + ), f"KV cache scales map for TP rank {i} not found." + return self + + @model_validator(mode="after") + def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_rank = context["tp_rank"] + num_hidden_layers = context["num_hidden_layers"] + layer_scales_map = self.scaling_factor[tp_rank] + for i in range(num_hidden_layers): + assert i in layer_scales_map, ( + f"Could not find KV cache scales for layer {i} in " + f"TP rank {tp_rank}." + ) + return self + + +class QuantParamSchema(BaseModel): + # TODO: Generalize and extend with more fields + # (e.g. weights/activations params) once functionality is enabled + model_config = ConfigDict(protected_namespaces=()) + model_type: Optional[str] + kv_cache: KVCacheQuantSchema + + @model_validator(mode="after") + def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": + context = info.context + if context: + model_type = context.get("model_type", None) + if model_type is not None: + assert model_type == self.model_type, ( + f"Model type is {model_type} but loaded " + f"scaling factors belonging to different " + f"model type {self.model_type}!" 
+ ) + return self + + def kv_cache_scales_loader( filename: str, tp_rank: int, @@ -681,7 +767,7 @@ def kv_cache_scales_loader( except json.JSONDecodeError: logger.error("Error decoding JSON in file '%s'.", filename) except Exception: - logger.exception("An error occurred while reading '%s'.", filename) + logger.error("An error occurred while reading '%s'.", filename) # This section is reached if and only if any of the excepts are hit # Return an empty iterable (list) => no KV cache scales are loaded # which ultimately defaults to 1.0 scales diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 151087732f0..e4b291b66cb 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -61,7 +61,10 @@ from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import get_compiler_backend, set_weight_attrs @@ -372,10 +375,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) -EntryClass = CohereForCausalLM +class Cohere2ForCausalLM(CohereForCausalLM): + pass + + +EntryClass = [CohereForCausalLM, Cohere2ForCausalLM] diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index cedc9639220..92fc679391f 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -42,7 +42,10 @@ VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import set_weight_attrs @@ -411,6 +414,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, weight_name) break else: + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0d327c0ca97..4384410476c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -48,7 +48,7 @@ normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -56,12 +56,12 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_flashinfer_available, is_hip +from sglang.srt.utils import is_cuda_available, is_hip is_hip_ = is_hip() -if is_flashinfer_available(): - from flashinfer import bmm_fp8 +if is_cuda_available(): + from sgl_kernel import bmm_fp8 class DeepseekV2MLP(nn.Module): @@ -271,7 +271,7 @@ def __init__( quant_config=quant_config, ) rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope( + self.rotary_emb = get_rope_wrapper( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 4d21901de7c..06a7b030260 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -35,7 +35,10 @@ from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import make_layers @@ -424,6 +427,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index c13d3e25368..0471e37d982 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -133,6 +133,7 @@ def __init__( renormalize=False, quant_config=quant_config, tp_size=tp_size, + activation="gelu", use_presharded_weights=use_presharded_weights, ) diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 118be8ff6c8..31ea7cd9f25 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -40,10 +40,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer import bmm_fp8 +if is_cuda_available(): + from sgl_kernel import bmm_fp8 class MiniCPM3MLP(nn.Module): diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index 5ff941b6c27..7b02b4cedbb 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -1,6 +1,6 @@ # Adapted from # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. +# Copyright 2023 The SGLang team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
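Several load_weights() loops above (commandr, dbrx, gemma2) gain the same guard: remap FP8 kv-scale names before the params_dict lookup and skip the tensor when the model has no matching parameter. A condensed sketch of that guard; maybe_remap_kv_scale_name_stub is a simplified stand-in for the real helper in weight_utils.py, which handles more name variants.

from typing import Dict, Optional
import torch

def maybe_remap_kv_scale_name_stub(
    name: str, params_dict: Dict[str, torch.nn.Parameter]
) -> Optional[str]:
    # Remap the flat kv_scale name to the attention-scoped one if the model has it;
    # return None to tell the caller to skip this tensor entirely.
    if name.endswith(".kv_scale"):
        remapped = name.replace(".kv_scale", ".attn.kv_scale")
        return remapped if remapped in params_dict else None
    return name

def load_one(
    name: str, loaded: torch.Tensor, params_dict: Dict[str, torch.nn.Parameter]
) -> None:
    name = maybe_remap_kv_scale_name_stub(name, params_dict)
    if name is None:
        return  # scale not used by this model; skip instead of raising KeyError
    params_dict[name].data.copy_(loaded)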
"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" -from functools import cached_property, partial +from functools import partial from typing import ( Any, Callable, @@ -33,18 +33,15 @@ Union, ) +import numpy as np import torch import torch.types from PIL import Image from torch import nn from torch.nn.init import trunc_normal_ from transformers import PretrainedConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size -from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata +from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ( @@ -63,6 +60,88 @@ RawImageType = Union[Image.Image, torch.Tensor] +# sin/cos positional embedding helpers are adapted from: +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0) +) -> torch.Tensor: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) / (H, W) + out: (M, D) / (H, W, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if version == (2, 0): + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + else: + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product + emb_sin = np.sin(out) # (H, W, D/2) + emb_cos = np.cos(out) # (H, W, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0) +) -> torch.Tensor: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0], version + ) # (H*W, D/2) or (H, W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1], version + ) # (H*W, D/2) or (H, W, D/2) + + if version == (2, 0): + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + else: + emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +) -> torch.Tensor: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size) + + if version == (2, 0): + grid = 
grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + else: + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + return pos_embed + + class Idefics2VisionMLP(nn.Module): def __init__( @@ -116,6 +195,10 @@ def __init__( projection_size=config.intermediate_size, use_qkv_parallel=True, quant_config=quant_config, + dropout=config.attention_dropout, + use_context_forward=False, + use_full_precision_softmax=True, + flatten_batch=False, prefix=f"{prefix}.self_attn", ) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -126,7 +209,6 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - forward_batch: ForwardBatch, ) -> torch.Tensor: """ Args: @@ -136,11 +218,8 @@ def forward( """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn( - hidden_states, - cu_seqlens=cu_seqlens, - # , forward_batch=forward_batch - ) + hidden_states = self.self_attn(hidden_states, cu_seqlens=cu_seqlens) + hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) @@ -181,7 +260,6 @@ def forward( self, inputs_embeds: torch.Tensor, cu_seqlens: torch.Tensor, - forward_batch: ForwardBatch, ) -> torch.Tensor: r""" Args: @@ -195,7 +273,8 @@ def forward( hidden_states = inputs_embeds for encoder_layer in self.layers: layer_outputs = encoder_layer( - hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch + hidden_states, + cu_seqlens=cu_seqlens, ) hidden_states = layer_outputs return hidden_states @@ -232,19 +311,14 @@ def __init__(self, config: PretrainedConfig): self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - def forward( + def get_position_ids( self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, tgt_sizes: Optional[torch.IntTensor] = None, - ) -> torch.Tensor: + ): batch_size, _, max_im_h, max_im_w = pixel_values.shape - target_dtype = self.patch_embedding.weight.dtype - pixel_values = pixel_values.to( - device=self.patch_embedding.weight.device, dtype=target_dtype - ) - patch_embeds = self.patch_embedding(pixel_values) - embeddings = patch_embeds.flatten(2).transpose(1, 2) + max_nb_patches_h, max_nb_patches_w = ( max_im_h // self.patch_size, max_im_w // self.patch_size, @@ -277,6 +351,24 @@ def forward( ).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) + return position_ids + + def forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + pixel_values = pixel_values.to( + device=self.patch_embedding.weight.device, dtype=target_dtype + ) + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + position_ids = self.get_position_ids( + pixel_values, patch_attention_mask, tgt_sizes + ) + embeddings = embeddings + self.position_embedding(position_ids) return embeddings @@ -287,7 +379,6 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", ) -> None: super().__init__() @@ -302,8 +393,6 @@ def get_input_embeddings(self): def 
compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,) - - # 做 prefix sum 来得到 cu_seqlens,注意在最前面插一个 0 作为 offset cu_seqlens = torch.cat( [ torch.tensor([0], device=patch_len.device, dtype=torch.int32), @@ -316,19 +405,18 @@ def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: def forward( self, pixel_values, - forward_batch: ForwardBatch, patch_attention_mask: Optional[torch.BoolTensor] = None, tgt_sizes: Optional[torch.IntTensor] = None, ) -> torch.Tensor: hidden_states = self.embeddings( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, - # forward_batch=forward_batch, tgt_sizes=tgt_sizes, ) cu_seqlens = self.compute_cu_seqlens(tgt_sizes) encoder_outputs = self.encoder( - hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch + hidden_states, + cu_seqlens=cu_seqlens, ) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state @@ -573,14 +661,12 @@ def __init__( config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, ): - # multimodal_config = config.model_config.multimodal_config super().__init__() # All MiniCPM-V models disable `tie_word_embeddings` but # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot - # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model + # check `tie_word_embeddings` until SGLang integrate MiniCPM-V model # and config class self.config = config - # self.multimodal_config = multimodal_config self.version = get_version_by_config(self.config) self.llm = self.init_llm(config=config, quant_config=quant_config) @@ -598,13 +684,6 @@ def __init__( self.logits_processor = LogitsProcessor(config) - @cached_property - def sampler(self): - if hasattr(self.llm, "sampler"): - return self.llm.sampler - - return get_sampler() - def _get_image_bounds( self, input_ids: torch.Tensor, @@ -666,7 +745,6 @@ def get_embedding( self, input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImageInputs], - forward_batch: ForwardBatch, ) -> Tuple[torch.Tensor, torch.Tensor]: vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) @@ -680,10 +758,7 @@ def get_embedding( .to(vlm_embedding.device) ) else: - vision_hidden_states = self.get_vision_hidden_states( - forward_batch, image_inputs - ) - + vision_hidden_states = self.get_vision_hidden_states(image_inputs) # See NOTE in _parse_and_validate_inputs image_bounds = image_inputs["image_bounds"] if len(image_bounds) > 0: @@ -693,6 +768,7 @@ def get_embedding( for start, end in image_bounds.tolist() ] ).to(vlm_embedding.device) + vlm_embedding.scatter_( 0, image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), @@ -839,7 +915,7 @@ def forward( # There values are useless because their embeddings will be replaced by vision embeddings anyway. 
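The compute_cu_seqlens helper above turns per-image patch counts into cumulative sequence offsets with a leading zero, the layout that variable-length attention kernels expect. A standalone version with a worked example (the sizes are illustrative):

import torch

def compute_cu_seqlens(tgt_sizes: torch.Tensor) -> torch.Tensor:
    patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]  # patches per image, shape (batch_size,)
    zero = torch.tensor([0], device=patch_len.device, dtype=torch.int32)
    # Prefix-sum with a leading 0 so cu_seqlens[i]:cu_seqlens[i+1] indexes image i.
    return torch.cat([zero, torch.cumsum(patch_len, dim=0, dtype=torch.int32)])

tgt_sizes = torch.tensor([[3, 4], [2, 5]])      # two images: 12 and 10 patches
print(compute_cu_seqlens(tgt_sizes))            # tensor([ 0, 12, 22], dtype=torch.int32)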
input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch) + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) # always pass the input via `inputs_embeds` # to make sure the computation graph is consistent @@ -857,29 +933,6 @@ def forward( input_ids, hidden_states, self.llm.lm_head, forward_batch ) - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - return self.llm.compute_logits(hidden_states, sampling_metadata) - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="llm", connector="resampler", tower_model="vpm" - ) - def init_llm( self, config: Qwen2Config, @@ -910,9 +963,7 @@ def get_vision_embedding( ) -> torch.Tensor: raise NotImplementedError - def get_vision_hidden_states( - self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs - ) -> torch.Tensor: + def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: raise NotImplementedError @@ -1019,7 +1070,6 @@ def get_vision_embedding( def get_vision_hidden_states( self, - forward_batch: ForwardBatch, data: MiniCPMVImageInputs, ) -> torch.Tensor: pixel_values = data["data"] @@ -1042,15 +1092,18 @@ def get_vision_hidden_states( patch_attn_mask = torch.zeros( (B, 1, max_patches), dtype=torch.bool, device=device ) - for i in range(B): - patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True + + tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device) + mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1] + patch_attn_mask[:, 0, :] = torch.arange( + patch_attn_mask.size(2), device=patch_attn_mask.device + ).unsqueeze(0) < mask_shapes.unsqueeze(1) + vision_embedding = self.vpm( all_pixel_values.type(dtype), - forward_batch=forward_batch, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes, ) - return self.resampler(vision_embedding, tgt_sizes) def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): @@ -1138,7 +1191,7 @@ class MiniCPMV: """ Different versions of MiniCPMV use different visual encoders and LLMs, which is not conducive to the current integration logic of LoRA and - bitsandbytes in vLLM. Therefore, it is necessary to separate them. + bitsandbytes in SGLang. Therefore, it is necessary to separate them. 
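The patch_attn_mask change in get_vision_hidden_states above replaces a per-image Python loop with a single broadcast comparison against the per-image patch counts. A small equivalence check (shapes are illustrative):

import torch

tgt_sizes = torch.tensor([[3, 4], [2, 5]])        # (height, width) in patches, per image
B, max_patches = tgt_sizes.shape[0], 16

# Old approach: fill the valid prefix of each row in a Python loop.
mask_loop = torch.zeros((B, 1, max_patches), dtype=torch.bool)
for i in range(B):
    mask_loop[i, 0, : int(tgt_sizes[i, 0] * tgt_sizes[i, 1])] = True

# New approach: one broadcast comparison against the per-image patch counts.
lengths = tgt_sizes[:, 0] * tgt_sizes[:, 1]
mask_vec = torch.zeros((B, 1, max_patches), dtype=torch.bool)
mask_vec[:, 0, :] = torch.arange(max_patches).unsqueeze(0) < lengths.unsqueeze(1)

assert torch.equal(mask_loop, mask_vec)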
""" # Ensure that the LoRA support check passes when the class is not diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 43f6793e4ef..05069edb69b 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -17,6 +17,7 @@ import sglang.srt.distributed.parallel_state as ps from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -145,61 +146,6 @@ def forward( return hidden_state -class MllamaVisionSdpaAttention(nn.Module): - def __init__(self, config: config_mllama.MllamaVisionConfig): - super().__init__() - - model_parallel_size = get_tensor_model_parallel_world_size() - self.embed_dim = config.hidden_size - self.num_heads = config.attention_heads - self.head_dim = config.hidden_size // config.attention_heads - self.num_local_heads = self.num_heads // model_parallel_size - self.q_size = self.num_local_heads * self.head_dim - self.kv_size = self.num_local_heads * self.head_dim - - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - bias=False, - ) - self.o_proj = RowParallelLinear( - self.num_heads * self.head_dim, - self.embed_dim, - bias=False, - input_is_parallel=True, - ) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_state) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q.view( - q.shape[0], q.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - k = k.view( - k.shape[0], k.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - v = v.view( - v.shape[0], v.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - - # TODO: remove padding in image encoder - attn_output = F.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask, dropout_p=0.0 - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape( - attn_output.shape[0], attn_output.shape[1], -1 - ) - output, _ = self.o_proj(attn_output) - return output - - class MllamaVisionMLP(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -237,7 +183,17 @@ def __init__( self.is_gated = is_gated self.intermediate_size = config.intermediate_size - self.self_attn = MllamaVisionSdpaAttention(config) + self.self_attn = VisionAttention( + self.hidden_size, + self.num_attention_heads, + self.hidden_size, + use_qkv_parallel=True, + quant_config=None, + dropout=0.0, + use_context_forward=False, + use_full_precision_softmax=False, + flatten_batch=False, + ) self.mlp = MllamaVisionMLP(config) self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) @@ -992,6 +948,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: + if "vision_model" in name: + # adapt to VisionAttention + name = name.replace("self_attn.o_proj", "self_attn.proj") + param = params_dict.pop(name) weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py old mode 100755 new mode 100644 diff --git a/python/sglang/srt/models/qwen2.py 
b/python/sglang/srt/models/qwen2.py index d06b5b5be4b..2cdb3182f2f 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -278,7 +278,10 @@ def __init__( self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) + if hasattr(self.config, "scale_emb"): + return self.embed_tokens(input_ids) * self.config.scale_emb + else: + return self.embed_tokens(input_ids) def forward( self, diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 0fb85679f7a..365891544e0 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -30,12 +30,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange from vllm.model_executor.layers.activation import QuickGELU from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig -from sglang.srt.distributed import parallel_state -from sglang.srt.distributed import utils as dist_utils from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear @@ -118,6 +116,7 @@ def __init__( mlp_ratio: float, act_layer: Type[nn.Module] = QuickGELU, norm_layer: Type[nn.Module] = None, + attn_implementation: Optional[str] = "sdpa", quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -126,12 +125,24 @@ def __init__( self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) + if attn_implementation == "sdpa": + use_context_forward = False + use_full_precision_softmax = False + elif attn_implementation == "flash_attention_2": + use_full_precision_softmax = False + use_context_forward = True + elif attn_implementation == "eager": + use_full_precision_softmax = True + use_context_forward = False self.attn = VisionAttention( embed_dim=dim, num_heads=num_heads, projection_size=dim, use_qkv_parallel=False, + use_context_forward=use_context_forward, + use_full_precision_softmax=use_full_precision_softmax, + flatten_batch=True, quant_config=quant_config, ) self.mlp = Qwen2VisionMLP( @@ -286,7 +297,6 @@ def __init__( norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = embed_dim // num_heads self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList( [ Qwen2VisionBlock( @@ -294,6 +304,7 @@ def __init__( num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, + attn_implementation="sdpa", quant_config=quant_config, ) for _ in range(depth) @@ -482,10 +493,6 @@ def forward( opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). (Use input_metadata.mrope_positions to replace it) - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. 
""" if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": positions = forward_batch.mrope_positions @@ -540,15 +547,18 @@ def forward( num_image_tokens = self.calculate_num_image_tokens( image_grid_thws[idx] ) + left_idx = start_idx + (image_offset - prefix_len) right_idx = ( start_idx + (image_offset - prefix_len) + num_image_tokens ) + inputs_embeds[left_idx:right_idx] = image_embeds[ image_embeds_offset : image_embeds_offset + num_image_tokens ] image_embeds_offset += num_image_tokens + input_ids = None hidden_states = self.model( input_ids=input_ids, positions=positions, diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 024a6f317fa..7b3e5bc5ddd 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -460,7 +460,12 @@ def get_num_params(self): params_dict = dict(self.named_parameters()) return len(params_dict) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights_to_module( + self, + fqn: str, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + """Load weights onto submodule pointed by path `fqn`.""" stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -469,7 +474,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - params_dict = dict(self.named_parameters()) + module = self.get_submodule(fqn) + params_dict = dict(module.named_parameters(prefix=fqn, recurse=False)) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: @@ -486,7 +492,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") or name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader @@ -494,12 +500,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") or name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + """Load weights onto the full model.""" + self.load_weights_to_module("", weights) + class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM): pass diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 5056ba22ef9..6687a4c0133 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -20,7 +20,7 @@ import time import uuid from http import HTTPStatus -from typing import Dict, List +from typing import Dict, List, Optional from fastapi import HTTPException, Request, UploadFile from fastapi.responses import ORJSONResponse, StreamingResponse @@ -40,6 +40,7 @@ generate_chat_conv, register_conv_template, ) +from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput from sglang.srt.openai_api.protocol import ( BatchRequest, @@ -71,7 +72,6 @@ TopLogprob, UsageInfo, ) -from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe ret, to_file=True, cache_report=tokenizer_manager.server_args.enable_cache_report, + tool_call_parser=tokenizer_manager.server_args.tool_call_parser, ) else: responses = v1_generate_response( @@ -877,9 +878,6 @@ def v1_chat_generate_request( tools = None if request.tools and request.tool_choice != "none": request.skip_special_tokens = False - if request.stream: - logger.warning("Streaming is not supported with tools.") - request.stream = False if not isinstance(request.tool_choice, str): tools = [ item.function.model_dump() @@ -908,12 +906,26 @@ def v1_chat_generate_request( openai_compatible_messages = openai_compatible_messages[:-1] else: assistant_prefix = None - prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( - openai_compatible_messages, - tokenize=True, - add_generation_prompt=True, - tools=tools, - ) + + try: + prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, + tools=tools, + ) + except: + # This except branch will be triggered when the chosen model + # has a different tools input format that is not compatiable + # with openAI's apply_chat_template tool_call format, like Mistral. 
+ tools = [t if "function" in t else {"function": t} for t in tools] + prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, + tools=tools, + ) + if assistant_prefix: prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix) stop = request.stop @@ -1005,7 +1017,9 @@ def v1_chat_generate_request( return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] -def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): +def v1_chat_generate_response( + request, ret, to_file=False, cache_report=False, tool_call_parser=None +): choices = [] for idx, ret_item in enumerate(ret): @@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): if finish_reason == "stop": finish_reason = "tool_calls" try: - text, call_info_list = parse_tool_response(text, tools) # noqa + parser = FunctionCallParser(tools, tool_call_parser) + full_normal_text, call_info_list = parser.parse_non_stream(text) tool_calls = [ ToolCall( - id=str(call_info[0]), + id=str(call_info.tool_index), function=FunctionResponse( - name=call_info[1], arguments=call_info[2] + name=call_info.name, arguments=call_info.parameters ), ) for call_info in call_info_list @@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) if adapted_request.stream: + parser_dict = {} async def generate_stream_resp(): is_firsts = {} @@ -1184,6 +1200,7 @@ async def generate_stream_resp(): adapted_request, raw_request ): index = content.get("index", 0) + text = content["text"] is_first = is_firsts.get(index, True) stream_buffer = stream_buffers.get(index, "") @@ -1263,29 +1280,111 @@ async def generate_stream_resp(): text = content["text"] delta = text[len(stream_buffer) :] - stream_buffer = stream_buffer + delta - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(content=delta), - finish_reason=(finish_reason["type"] if finish_reason else ""), - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - logprobs=choice_logprobs, - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - choices=[choice_data], - model=request.model, - ) + new_stream_buffer = stream_buffer + delta - is_firsts[index] = is_first - stream_buffers[index] = stream_buffer - n_prev_tokens[index] = n_prev_token + if request.tool_choice != "none" and request.tools: + if index not in parser_dict: + parser_dict[index] = FunctionCallParser( + tools=request.tools, + tool_call_parser=tokenizer_manager.server_args.tool_call_parser, + ) + parser = parser_dict[index] + + # parse_increment => returns (normal_text, calls) + normal_text, calls = parser.parse_stream_chunk(delta) + + # 1) if there's normal_text, output it as normal content + if normal_text: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=normal_text), + finish_reason=( + finish_reason["type"] if finish_reason else "" + ), + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + # 2) if we found calls, we output them as separate chunk(s) + for call_item in calls: + # transform call_item -> FunctionResponse + ToolCall + + if ( + content["meta_info"]["finish_reason"] + 
and content["meta_info"]["finish_reason"]["type"] + == "stop" + ): + latest_delta_len = 0 + if isinstance(call_item.parameters, str): + latest_delta_len = len(call_item.parameters) + + expected_call = json.dumps( + parser.multi_format_parser.detectors[0] + .prev_tool_call_arr[index] + .get("arguments", {}), + ensure_ascii=False, + ) + actual_call = parser.multi_format_parser.detectors[ + 0 + ].streamed_args_for_tool[index] + if latest_delta_len > 0: + actual_call = actual_call[:-latest_delta_len] + remaining_call = expected_call.replace( + actual_call, "", 1 + ) + call_item.parameters = remaining_call + + tool_call = ToolCall( + id=str(call_item.tool_index), + function=FunctionResponse( + name=call_item.name, + arguments=call_item.parameters, + ), + ) + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage( + role="assistant", tool_calls=[tool_call] + ), + finish_reason="tool_call", + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" - yield f"data: {chunk.model_dump_json()}\n\n" + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first + + else: + # No tool calls => just treat this as normal text + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=delta), + finish_reason=( + finish_reason["type"] if finish_reason else "" + ), + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + logprobs=choice_logprobs, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first if request.stream_options and request.stream_options.include_usage: total_prompt_tokens = sum( tokens @@ -1333,7 +1432,10 @@ async def generate_stream_resp(): ret = [ret] response = v1_chat_generate_response( - request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report + request, + ret, + cache_report=tokenizer_manager.server_args.enable_cache_report, + tool_call_parser=tokenizer_manager.server_args.tool_call_parser, ) return response diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 2ed9006c0ea..95b34527edb 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -262,7 +262,7 @@ class Function(BaseModel): """Function descriptions.""" description: Optional[str] = Field(default=None, examples=[None]) - name: str + name: Optional[str] = None parameters: Optional[object] = None @@ -276,7 +276,7 @@ class Tool(BaseModel): class ToolChoiceFuncName(BaseModel): """The name of tool choice function.""" - name: str + name: Optional[str] = None class ToolChoice(BaseModel): @@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel): class FunctionResponse(BaseModel): """Function response.""" - name: str - arguments: str + name: Optional[str] = None + arguments: Optional[str] = None class ToolCall(BaseModel): @@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel): class DeltaMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) class ChatCompletionResponseStreamChoice(BaseModel): diff --git a/python/sglang/srt/sampling/custom_logit_processor.py 
b/python/sglang/srt/sampling/custom_logit_processor.py new file mode 100644 index 00000000000..a64b2498f23 --- /dev/null +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -0,0 +1,38 @@ +import json +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any, Dict, List, Optional + +import dill +import torch + + +@lru_cache(maxsize=None) +def _cache_from_str(json_str: str): + """Deserialize a json string to a Callable object. + This function is cached to avoid redundant deserialization. + """ + data = json.loads(json_str) + return dill.loads(bytes.fromhex(data["callable"])) + + +class CustomLogitProcessor(ABC): + """Abstract base class for callable functions.""" + + @abstractmethod + def __call__( + self, + logits: torch.Tensor, + custom_param_list: Optional[List[Dict[str, Any]]] = None, + ) -> torch.Tensor: + """Define the callable behavior.""" + raise NotImplementedError + + def to_str(self) -> str: + """Serialize the callable function to a JSON-compatible string.""" + return json.dumps({"callable": dill.dumps(self).hex()}) + + @classmethod + def from_str(cls, json_str: str): + """Deserialize a callable function from a JSON string.""" + return _cache_from_str(json_str) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py index fcd5ff71c23..fe687c569d4 100644 --- a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py @@ -3,11 +3,16 @@ import torch from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs -from sglang.srt.utils import is_cuda_available +from sglang.srt.utils import get_compiler_backend -is_cuda = is_cuda_available() -if is_cuda: - from sgl_kernel import sampling_scaling_penalties + +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def apply_scaling_penalties(logits, scaling_penalties): + logits[:] = torch.where( + logits > 0, + logits / scaling_penalties, + logits * scaling_penalties, + ) class BatchedRepetitionPenalizer(_BatchedPenalizer): @@ -61,16 +66,8 @@ def _cumulate_output_tokens(self, output_ids: _TokenIDs): self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] def _apply(self, logits: torch.Tensor) -> torch.Tensor: - if is_cuda: - return sampling_scaling_penalties( - logits, self.cumulated_repetition_penalties - ) - else: - return torch.where( - logits > 0, - logits / self.cumulated_repetition_penalties, - logits * self.cumulated_repetition_penalties, - ) + apply_scaling_penalties(logits, self.cumulated_repetition_penalties) + return logits def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 6eda63c706a..9521a34f4f6 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -3,17 +3,15 @@ import dataclasses import logging import threading -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch -from sglang.srt.utils import is_cuda_available - -is_cuda = is_cuda_available() -if is_cuda: - from sgl_kernel import sampling_scaling_penalties - import sglang.srt.sampling.penaltylib as 
penaltylib +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor +from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import ( + apply_scaling_penalties, +) logger = logging.getLogger(__name__) @@ -36,6 +34,9 @@ class SamplingBatchInfo: # Dispatch in CUDA graph need_min_p_sampling: bool + # Whether any request has custom logit processor + has_custom_logit_processor: bool + # Bias Tensors vocab_size: int grammars: Optional[List] = None @@ -52,6 +53,14 @@ class SamplingBatchInfo: # Device device: str = "cuda" + # Custom Parameters + custom_params: Optional[List[Optional[Dict[str, Any]]]] = None + + # Custom Logit Processor + custom_logit_processor: Optional[ + Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]] + ] = None + @classmethod def from_schedule_batch( cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool @@ -76,6 +85,39 @@ def from_schedule_batch( [r.sampling_params.min_p for r in reqs], dtype=torch.float ).to(device, non_blocking=True) + # Check if any request has custom logit processor + has_custom_logit_processor = ( + batch.enable_custom_logit_processor # check the flag first. + and any(r.custom_logit_processor for r in reqs) # then check the requests. + ) + + if has_custom_logit_processor: + # Merge the same type of custom logit processors together + processor_dict = {} + for i, r in enumerate(reqs): + if r.custom_logit_processor is None: + continue + processor_str = r.custom_logit_processor + if processor_str not in processor_dict: + processor_dict[processor_str] = [] + processor_dict[processor_str].append(i) + + merged_custom_logit_processor = { + hash(processor_str): ( + # The deserialized custom logit processor object + CustomLogitProcessor.from_str(processor_str), + # The mask tensor for the requests that use this custom logit processor + torch.zeros(len(reqs), dtype=torch.bool) + .scatter_(0, torch.tensor(true_indices), True) + .to(device, non_blocking=True), + ) + for processor_str, true_indices in processor_dict.items() + } + custom_params = [r.sampling_params.custom_params for r in reqs] + else: + merged_custom_logit_processor = None + custom_params = None + ret = cls( temperatures=temperatures, top_ps=top_ps, @@ -83,8 +125,11 @@ def from_schedule_batch( min_ps=min_ps, need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs), is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs), + has_custom_logit_processor=has_custom_logit_processor, vocab_size=vocab_size, device=device, + custom_params=custom_params, + custom_logit_processor=merged_custom_logit_processor, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. 
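The (processor, mask) tuples assembled above pair one deserialized CustomLogitProcessor with a boolean mask over the batch, alongside a per-request custom_params list. A minimal sketch of how such a pair could be applied at sampling time; the class name and the token-banning rule are illustrative, and only the __call__ signature comes from custom_logit_processor.py.

from typing import Any, Dict, List, Optional
import torch

class BanTokenProcessor:  # would subclass sglang's CustomLogitProcessor in practice
    def __call__(
        self,
        logits: torch.Tensor,
        custom_param_list: Optional[List[Dict[str, Any]]] = None,
    ) -> torch.Tensor:
        for row, params in enumerate(custom_param_list or []):
            if params and "banned_token" in params:
                logits[row, params["banned_token"]] = float("-inf")
        return logits

logits = torch.zeros(3, 8)                       # batch of 3 requests, vocab of 8
mask = torch.tensor([True, False, True])         # requests 0 and 2 use this processor
custom_params = [{"banned_token": 5}, None, {"banned_token": 2}]

processor = BanTokenProcessor()
masked_params = [p for p, m in zip(custom_params, mask) if m]
logits[mask] = processor(logits[mask], masked_params)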
@@ -184,6 +229,8 @@ def update_regex_vocab_mask(self): def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): self.penalizer_orchestrator.filter(unfinished_indices, new_indices) + if self.has_custom_logit_processor: + self._filter_batch_custom_logit_processor(unfinished_indices, new_indices) for item in [ "temperatures", @@ -196,6 +243,27 @@ def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor) if value is not None: # logit_bias can be None setattr(self, item, value[new_indices]) + def _filter_batch_custom_logit_processor( + self, unfinished_indices: List[int], new_indices: torch.Tensor + ): + """Filter the custom logit processor and custom params""" + + self.custom_logit_processor = { + k: (p, mask[new_indices]) + for k, (p, mask) in self.custom_logit_processor.items() + if any( + mask[new_indices] + ) # ignore the custom logit processor whose mask is all False + } + self.custom_params = [self.custom_params[i] for i in unfinished_indices] + + # If the custom logit processor is an empty dict, set the flag to False, + # and set the custom logit processor and custom params to None. + if len(self.custom_logit_processor) == 0: + self.custom_logit_processor = None + self.custom_params = None + self.has_custom_logit_processor = False + @staticmethod def merge_bias_tensor( lhs: torch.Tensor, @@ -221,9 +289,76 @@ def merge_bias_tensor( return None + @staticmethod + def merge_custom_logit_processor( + lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]], + rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]], + bs1: int, + bs2: int, + device: str, + ): + if lhs is None and rhs is None: + return None + lhs, rhs = lhs or {}, rhs or {} + + keys = set(lhs.keys()).union(set(rhs.keys())) + merged_dict = {} + + for k in keys: + # Get the logit processor object + processor = lhs[k][0] if k in lhs else rhs[k][0] + # Get and merge the mask tensors from the two dicts + left_mask = ( + lhs[k][1] + if k in lhs + else torch.zeros(bs1, dtype=torch.bool, device=device) + ) + right_mask = ( + rhs[k][1] + if k in rhs + else torch.zeros(bs2, dtype=torch.bool, device=device) + ) + merged_dict[k] = (processor, torch.cat([left_mask, right_mask])) + + assert merged_dict[k][1].shape[0] == bs1 + bs2, ( + f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match " + f"the sum of the batch sizes of the two masks ({bs1 + bs2})" + f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}" + f"\n{lhs=}\n{rhs=}" + ) + + return merged_dict + def merge_batch(self, other: "SamplingBatchInfo"): self.penalizer_orchestrator.merge(other.penalizer_orchestrator) + # Merge the logit bias tensor + self.logit_bias = SamplingBatchInfo.merge_bias_tensor( + self.logit_bias, other.logit_bias, len(self), len(other), self.device + ) + # Merge the custom logit processors and custom params lists + if self.has_custom_logit_processor or other.has_custom_logit_processor: + # Merge the custom logit processors + self.custom_logit_processor = ( + SamplingBatchInfo.merge_custom_logit_processor( + self.custom_logit_processor, + other.custom_logit_processor, + len(self), + len(other), + self.device, + ) + ) + # Merge the custom params lists + self.custom_params = self.custom_params or [None] * len(self) + other.custom_params = other.custom_params or [None] * len(other) + self.custom_params.extend(other.custom_params) + + # Set the flag to True if any of the two has custom logit processor + self.has_custom_logit_processor = True + + # Note: because the __len__()
operator is defined on the temperatures tensor, + # please make sure any merge operation with len(self) or len(other) is done before + # the merge operation of the temperatures tensor below. for item in [ "temperatures", "top_ps", @@ -235,9 +370,6 @@ def merge_batch(self, other: "SamplingBatchInfo"): setattr(self, item, torch.concat([self_val, other_val])) self.is_all_greedy = self.is_all_greedy and other.is_all_greedy - self.logit_bias = SamplingBatchInfo.merge_bias_tensor( - self.logit_bias, other.logit_bias, len(self), len(other), self.device - ) self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling def apply_logits_bias(self, logits: torch.Tensor): @@ -251,14 +383,7 @@ def apply_logits_bias(self, logits: torch.Tensor): # repetition if self.scaling_penalties is not None: - if is_cuda: - logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties) - else: - logits[:] = torch.where( - logits > 0, - logits / self.scaling_penalties, - logits * self.scaling_penalties, - ) + apply_scaling_penalties(logits, self.scaling_penalties) # Apply regex vocab_mask if self.vocab_mask is not None: diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index d1d932693c6..2224fb0919a 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -13,7 +13,7 @@ # ============================================================================== """Sampling parameters for text generation.""" -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union _SAMPLING_EPS = 1e-6 @@ -48,6 +48,7 @@ def __init__( no_stop_trim: bool = False, ignore_eos: bool = False, skip_special_tokens: bool = True, + custom_params: Optional[Dict[str, Any]] = None, ) -> None: self.temperature = temperature self.top_p = top_p @@ -71,6 +72,7 @@ def __init__( self.json_schema = json_schema self.ebnf = ebnf self.no_stop_trim = no_stop_trim + self.custom_params = custom_params # Process some special cases if self.temperature < _SAMPLING_EPS: diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index d381d390fa8..869a984d0cf 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -11,1125 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -""" -The entry point of inference server. -SRT = SGLang Runtime. 
-""" -import asyncio -import atexit -import dataclasses -import json -import logging -import multiprocessing as mp -import os -import signal -import threading -import time -from http import HTTPStatus -from typing import AsyncIterator, Dict, List, Optional, Tuple, Union - -import torch - -from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter - -# Fix a bug of Python threading -setattr(threading, "_register_atexit", lambda *args, **kwargs: None) - -import aiohttp -import orjson -import requests -import uvicorn -import uvloop -from fastapi import FastAPI, File, Form, Request, UploadFile -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import ORJSONResponse, Response, StreamingResponse - -from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.data_parallel_controller import ( - run_data_parallel_controller_process, -) -from sglang.srt.managers.detokenizer_manager import run_detokenizer_process -from sglang.srt.managers.io_struct import ( - CloseSessionReqInput, - ConfigureLoggingReq, - EmbeddingReqInput, - GenerateReqInput, - GetWeightsByNameReqInput, - InitWeightsUpdateGroupReqInput, - OpenSessionReqInput, - ReleaseMemoryOccupationReqInput, - ResumeMemoryOccupationReqInput, - UpdateWeightFromDiskReqInput, - UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromTensorReqInput, -) -from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.managers.tokenizer_manager import TokenizerManager -from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency -from sglang.srt.openai_api.adapter import ( - load_chat_template_for_openai_api, - v1_batches, - v1_cancel_batch, - v1_chat_completions, - v1_completions, - v1_delete_file, - v1_embeddings, - v1_files_create, - v1_retrieve_batch, - v1_retrieve_file, - v1_retrieve_file_content, -) -from sglang.srt.openai_api.protocol import ModelCard, ModelList -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import ( - MultiprocessingSerializer, - add_api_key_middleware, - add_prometheus_middleware, - assert_pkg_version, - configure_logger, - delete_directory, - is_port_available, - kill_process_tree, - maybe_set_triton_cache_manager, - prepare_model_and_tokenizer, - set_prometheus_multiproc_dir, - set_ulimit, - set_uvicorn_logging_configs, -) -from sglang.utils import get_exception_traceback -from sglang.version import __version__ - -logger = logging.getLogger(__name__) - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - -# Fast API -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -tokenizer_manager: TokenizerManager = None -scheduler_info: Dict = None - - -##### Native API endpoints ##### - - -@app.get("/health") -async def health() -> Response: - """Check the health of the http server.""" - return Response(status_code=200) - - -@app.get("/health_generate") -async def health_generate(request: Request) -> Response: - """Check the health of the inference server by generating one token.""" - - sampling_params = {"max_new_tokens": 1, "temperature": 0.7} - - if tokenizer_manager.is_generation: - gri = GenerateReqInput( - input_ids=[0], sampling_params=sampling_params, log_metrics=False - ) - else: - gri = EmbeddingReqInput( - input_ids=[0], sampling_params=sampling_params, log_metrics=False - ) - - try: - async for _ in 
tokenizer_manager.generate_request(gri, request): - break - return Response(status_code=200) - except Exception as e: - logger.exception(e) - return Response(status_code=503) - - -@app.get("/get_model_info") -async def get_model_info(): - """Get the model information.""" - result = { - "model_path": tokenizer_manager.model_path, - "tokenizer_path": tokenizer_manager.server_args.tokenizer_path, - "is_generation": tokenizer_manager.is_generation, - } - return result - - -@app.get("/get_server_info") -async def get_server_info(): - return { - **dataclasses.asdict(tokenizer_manager.server_args), - **scheduler_info, - "version": __version__, - } - - -# fastapi implicitly converts json in the request to obj (dataclass) -@app.api_route("/generate", methods=["POST", "PUT"]) -@time_func_latency -async def generate_request(obj: GenerateReqInput, request: Request): - """Handle a generate request.""" - if obj.stream: - - async def stream_results() -> AsyncIterator[bytes]: - try: - async for out in tokenizer_manager.generate_request(obj, request): - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - except ValueError as e: - out = {"error": {"message": str(e)}} - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - yield b"data: [DONE]\n\n" - - return StreamingResponse( - stream_results(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(obj), - ) - else: - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - logger.error(f"Error: {e}") - return _create_error_response(e) - - -@app.api_route("/encode", methods=["POST", "PUT"]) -@time_func_latency -async def encode_request(obj: EmbeddingReqInput, request: Request): - """Handle an embedding request.""" - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return _create_error_response(e) - - -@app.api_route("/classify", methods=["POST", "PUT"]) -@time_func_latency -async def classify_request(obj: EmbeddingReqInput, request: Request): - """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return _create_error_response(e) - - -@app.post("/flush_cache") -async def flush_cache(): - """Flush the radix cache.""" - tokenizer_manager.flush_cache() - return Response( - content="Cache flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", - status_code=200, - ) - - -@app.api_route("/start_profile", methods=["GET", "POST"]) -async def start_profile_async(): - """Start profiling.""" - tokenizer_manager.start_profile() - return Response( - content="Start profiling.\n", - status_code=200, - ) - - -@app.api_route("/stop_profile", methods=["GET", "POST"]) -async def stop_profile_async(): - """Stop profiling.""" - tokenizer_manager.stop_profile() - return Response( - content="Stop profiling. 
This will take some time.\n", - status_code=200, - ) - - -@app.post("/update_weights_from_disk") -@time_func_latency -async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): - """Update the weights from disk in-place without re-launching the server.""" - success, message = await tokenizer_manager.update_weights_from_disk(obj, request) - content = {"success": success, "message": message} - if success: - return ORJSONResponse( - content, - status_code=HTTPStatus.OK, - ) - else: - return ORJSONResponse( - content, - status_code=HTTPStatus.BAD_REQUEST, - ) - - -@app.post("/init_weights_update_group") -async def init_weights_update_group( - obj: InitWeightsUpdateGroupReqInput, request: Request -): - """Initialize the parameter update group.""" - success, message = await tokenizer_manager.init_weights_update_group(obj, request) - content = {"success": success, "message": message} - if success: - return ORJSONResponse(content, status_code=200) - else: - return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) - - -@app.post("/update_weights_from_distributed") -async def update_weights_from_distributed( - obj: UpdateWeightsFromDistributedReqInput, request: Request -): - """Update model parameter from distributed online.""" - success, message = await tokenizer_manager.update_weights_from_distributed( - obj, request - ) - content = {"success": success, "message": message} - if success: - return ORJSONResponse(content, status_code=200) - else: - return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) - - -@app.api_route("/get_weights_by_name", methods=["GET", "POST"]) -async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): - """Get model parameter by name.""" - try: - ret = await tokenizer_manager.get_weights_by_name(obj, request) - if ret is None: - return _create_error_response("Get parameter by name failed") - else: - return ORJSONResponse(ret, status_code=200) - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/release_memory_occupation", methods=["GET", "POST"]) -async def release_memory_occupation( - obj: ReleaseMemoryOccupationReqInput, request: Request -): - """Release GPU occupation temporarily""" - try: - await tokenizer_manager.release_memory_occupation(obj, request) - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/resume_memory_occupation", methods=["GET", "POST"]) -async def resume_memory_occupation( - obj: ResumeMemoryOccupationReqInput, request: Request -): - """Resume GPU occupation""" - try: - await tokenizer_manager.resume_memory_occupation(obj, request) - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/open_session", methods=["GET", "POST"]) -async def open_session(obj: OpenSessionReqInput, request: Request): - """Open a session, and return its unique session id.""" - try: - session_id = await tokenizer_manager.open_session(obj, request) - if session_id is None: - raise Exception( - "Failed to open the session. Check if a session with the same id is still open." 
- ) - return session_id - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/close_session", methods=["GET", "POST"]) -async def close_session(obj: CloseSessionReqInput, request: Request): - """Close the session""" - try: - await tokenizer_manager.close_session(obj, request) - return Response(status_code=200) - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/configure_logging", methods=["GET", "POST"]) -async def configure_logging(obj: ConfigureLoggingReq, request: Request): - """Close the session""" - tokenizer_manager.configure_logging(obj) - return Response(status_code=200) - - -##### OpenAI-compatible API endpoints ##### - - -@app.post("/v1/completions") -@time_func_latency -async def openai_v1_completions(raw_request: Request): - return await v1_completions(tokenizer_manager, raw_request) - - -@app.post("/v1/chat/completions") -@time_func_latency -async def openai_v1_chat_completions(raw_request: Request): - return await v1_chat_completions(tokenizer_manager, raw_request) - - -@app.post("/v1/embeddings", response_class=ORJSONResponse) -@time_func_latency -async def openai_v1_embeddings(raw_request: Request): - response = await v1_embeddings(tokenizer_manager, raw_request) - return response - - -@app.get("/v1/models", response_class=ORJSONResponse) -def available_models(): - """Show available models.""" - served_model_names = [tokenizer_manager.served_model_name] - model_cards = [] - for served_model_name in served_model_names: - model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) - return ModelList(data=model_cards) - - -@app.post("/v1/files") -async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): - return await v1_files_create( - file, purpose, tokenizer_manager.server_args.file_storage_pth - ) - - -@app.delete("/v1/files/{file_id}") -async def delete_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/delete - return await v1_delete_file(file_id) - - -@app.post("/v1/batches") -async def openai_v1_batches(raw_request: Request): - return await v1_batches(tokenizer_manager, raw_request) - - -@app.post("/v1/batches/{batch_id}/cancel") -async def cancel_batches(batch_id: str): - # https://platform.openai.com/docs/api-reference/batch/cancel - return await v1_cancel_batch(tokenizer_manager, batch_id) - - -@app.get("/v1/batches/{batch_id}") -async def retrieve_batch(batch_id: str): - return await v1_retrieve_batch(batch_id) - - -@app.get("/v1/files/{file_id}") -async def retrieve_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve - return await v1_retrieve_file(file_id) - - -@app.get("/v1/files/{file_id}/content") -async def retrieve_file_content(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve-contents - return await v1_retrieve_file_content(file_id) - - -def _create_error_response(e): - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -def launch_engine( - server_args: ServerArgs, -): - """ - Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. 
- """ - - global tokenizer_manager - global scheduler_info - - # Configure global environment - configure_logger(server_args) - server_args.check_server_args() - _set_envs_and_config(server_args) - - # Allocate ports for inter-process communications - port_args = PortArgs.init_new(server_args) - logger.info(f"{server_args=}") - - # If using model from www.modelscope.cn, first download the model. - server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( - server_args.model_path, server_args.tokenizer_path - ) - - scheduler_procs = [] - if server_args.dp_size == 1: - # Launch tensor parallel scheduler processes - memory_saver_adapter = TorchMemorySaverAdapter.create( - enable=server_args.enable_memory_saver - ) - - scheduler_pipe_readers = [] - tp_size_per_node = server_args.tp_size // server_args.nnodes - tp_rank_range = range( - tp_size_per_node * server_args.node_rank, - tp_size_per_node * (server_args.node_rank + 1), - ) - for tp_rank in tp_rank_range: - reader, writer = mp.Pipe(duplex=False) - gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node - proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, None, writer), - ) - with memory_saver_adapter.configure_subprocess(): - proc.start() - scheduler_procs.append(proc) - scheduler_pipe_readers.append(reader) - else: - # Launch the data parallel controller - reader, writer = mp.Pipe(duplex=False) - scheduler_pipe_readers = [reader] - proc = mp.Process( - target=run_data_parallel_controller_process, - args=(server_args, port_args, writer), - ) - proc.start() - scheduler_procs.append(proc) - - if server_args.node_rank >= 1: - # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, - # so they can just wait here. - - for reader in scheduler_pipe_readers: - data = reader.recv() - assert data["status"] == "ready" - - if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": - # When using `Engine` as a Python API, we don't want to block here. - return - - for proc in scheduler_procs: - proc.join() - logger.error( - f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" - ) - return - - # Launch detokenizer process - detoken_proc = mp.Process( - target=run_detokenizer_process, - args=( - server_args, - port_args, - ), - ) - detoken_proc.start() - - # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args) - if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) - - # Wait for model to finish loading - scheduler_infos = [] - for i in range(len(scheduler_pipe_readers)): - try: - data = scheduler_pipe_readers[i].recv() - except EOFError as e: - logger.exception(e) - logger.error( - f"Rank {i} scheduler is dead. Please check if there are relevant logs." - ) - scheduler_procs[i].join() - logger.error(f"Exit code: {scheduler_procs[i].exitcode}") - raise - - if data["status"] != "ready": - raise RuntimeError( - "Initialization failed. Please see the error messages above." - ) - scheduler_infos.append(data) - - # Assume all schedulers have same scheduler_info - scheduler_info = scheduler_infos[0] - tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] - - -def launch_server( - server_args: ServerArgs, - pipe_finish_writer: Optional[mp.connection.Connection] = None, -): - """ - Launch SRT (SGLang Runtime) Server - - The SRT server consists of an HTTP server and the SRT engine. - - 1. 
HTTP server: A FastAPI server that routes requests to the engine. - 2. SRT engine: - 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. - 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. - 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. - - Note: - 1. The HTTP server and TokenizerManager both run in the main process. - 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. - """ - launch_engine(server_args=server_args) - - # Add api key authorization - if server_args.api_key: - add_api_key_middleware(app, server_args.api_key) - - # Add prometheus middleware - if server_args.enable_metrics: - add_prometheus_middleware(app) - enable_func_timer() - - # Send a warmup request - t = threading.Thread( - target=_wait_and_warmup, - args=( - server_args, - pipe_finish_writer, - tokenizer_manager.image_token_id, - ), - ) - t.start() - - try: - # Update logging configs - set_uvicorn_logging_configs() - - # Listen for HTTP requests - uvicorn.run( - app, - host=server_args.host, - port=server_args.port, - log_level=server_args.log_level_http or server_args.log_level, - timeout_keep_alive=5, - loop="uvloop", - ) - finally: - t.join() - - -def _set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" - - # Set prometheus env vars - if server_args.enable_metrics: - set_prometheus_multiproc_dir() - - # Set ulimit - set_ulimit() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Check flashinfer version - if server_args.attention_backend == "flashinfer": - assert_pkg_version( - "flashinfer", - "0.1.6", - "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.", - ) - - # Register the signal handler. - # The child processes will send SIGQUIT to this process when any error happens - # This process then clean up the whole process tree - def sigquit_handler(signum, frame): - logger.error( - "Received sigquit from a child proces. It usually means the child failed." - ) - kill_process_tree(os.getpid()) - - signal.signal(signal.SIGQUIT, sigquit_handler) - - # Set mp start method - mp.set_start_method("spawn", force=True) - - -def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text): - headers = {} - url = server_args.url() - if server_args.api_key: - headers["Authorization"] = f"Bearer {server_args.api_key}" - - # Wait until the server is launched - success = False - for _ in range(120): - time.sleep(1) - try: - res = requests.get(url + "/get_model_info", timeout=5, headers=headers) - assert res.status_code == 200, f"{res=}, {res.text=}" - success = True - break - except (AssertionError, requests.exceptions.RequestException): - last_traceback = get_exception_traceback() - pass - - if not success: - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. 
warmup error: {last_traceback}") - kill_process_tree(os.getpid()) - return - - model_info = res.json() - - # Send a warmup request - request_name = "/generate" if model_info["is_generation"] else "/encode" - max_new_tokens = 128 if model_info["is_generation"] else 1 - json_data = { - "sampling_params": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - }, - } - if server_args.skip_tokenizer_init: - json_data["input_ids"] = [10, 11, 12] - else: - # json_data["text"] = "The capital city of France is" - target_length = int(os.getenv("SRT_WARMUP_PASSKEY_LENGTH", "35000")) - json_data["text"] = ( - "You need to find the passkey. Read carefully following text, and remember the passkey\n\n" - ) - filler = "Sky is blue, grass is green, sun is red. And here we go again" - json_data["text"] += filler * (target_length // 35) - json_data[ - "text" - ] += "\n\nThe passkey is $000310$. Remember, the passkey is $000310$.\n\n" - json_data[ - "text" - ] += "\n\nThe passkey is $000310$. Remember, the passkey is $000310$.\n\n" - json_data[ - "text" - ] += "\n\nThe passkey is $000310$. Remember, the passkey is $000310$.\n\n" - json_data["text"] += filler * (target_length // 35) - json_data["text"] += "What was the passkey? The passkey is" - - try: - for _ in range(server_args.dp_size): - res = requests.post( - url + request_name, - json=json_data, - headers=headers, - timeout=3600, - ) - assert res.status_code == 200, f"{res}" - logger.info(f"Warmup response: {res.json()}") - if os.getenv("SRT_EXIT_AFTER_WARMUP", "0") == "1": - logger.error(f"Initialization canceled. SRT_EXIT_AFTER_WARMUP") - kill_process_tree(os.getpid()) - except Exception: - last_traceback = get_exception_traceback() - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. warmup error: {last_traceback}") - kill_process_tree(os.getpid()) - return - - # Debug print - # logger.info(f"{res.json()=}") - - logger.info("The server is fired up and ready to roll!") - if pipe_finish_writer is not None: - pipe_finish_writer.send("ready") - - if server_args.delete_ckpt_after_loading: - delete_directory(server_args.model_path) - - -STREAM_END_SYMBOL = b"data: [DONE]" -STREAM_CHUNK_START_SYMBOL = b"data:" - - -class Engine: - """ - SRT Engine without an HTTP server layer. - - This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where - launching the HTTP server adds unnecessary complexity or overhead, - """ - - def __init__(self, log_level: str = "error", *args, **kwargs): - """See the arguments in server_args.py::ServerArgs""" - - # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() - atexit.register(self.shutdown) - - server_args = ServerArgs(*args, log_level=log_level, **kwargs) - launch_engine(server_args=server_args) - - def generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. 
- input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - stream: bool = False, - ): - obj = GenerateReqInput( - text=prompt, - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=return_logprob, - logprob_start_len=logprob_start_len, - top_logprobs_num=top_logprobs_num, - lora_path=lora_path, - stream=stream, - ) - - # get the current event loop - loop = asyncio.get_event_loop() - ret = loop.run_until_complete(generate_request(obj, None)) - - if stream is True: - - def generator_wrapper(): - offset = 0 - loop = asyncio.get_event_loop() - generator = ret.body_iterator - while True: - chunk = loop.run_until_complete(generator.__anext__()) - - if chunk.startswith(STREAM_END_SYMBOL): - break - else: - data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) - data["text"] = data["text"][offset:] - offset += len(data["text"]) - yield data - - # we cannot yield in the scope of generate() because python does not allow yield + return in the same function - # however, it allows to wrap the generator as a subfunction and return - return generator_wrapper() - else: - return ret - - async def async_generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Dict] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - stream: bool = False, - ): - obj = GenerateReqInput( - text=prompt, - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=return_logprob, - logprob_start_len=logprob_start_len, - top_logprobs_num=top_logprobs_num, - lora_path=lora_path, - stream=stream, - ) - - ret = await generate_request(obj, None) - - if stream is True: - generator = ret.body_iterator - - async def generator_wrapper(): - offset = 0 - - while True: - chunk = await generator.__anext__() - - if chunk.startswith(STREAM_END_SYMBOL): - break - else: - data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) - data["text"] = data["text"][offset:] - offset += len(data["text"]) - yield data - - return generator_wrapper() - else: - return ret - - def shutdown(self): - kill_process_tree(os.getpid(), include_parent=False) - - def get_tokenizer(self): - global tokenizer_manager - - if tokenizer_manager is None: - raise ReferenceError("Tokenizer Manager is not initialized.") - else: - return tokenizer_manager.tokenizer - - def encode( - self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - ): - obj = EmbeddingReqInput(text=prompt) - - # get the current event loop - loop = asyncio.get_event_loop() - return loop.run_until_complete(encode_request(obj, None)) - - def start_profile(self): - tokenizer_manager.start_profile() - - def stop_profile(self): - tokenizer_manager.stop_profile() - - def get_server_info(self): - return { - **dataclasses.asdict(tokenizer_manager.server_args), # server args - **scheduler_info, - "version": __version__, - } - - def init_weights_update_group( - self, - master_address: str, - 
master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - backend: str = "nccl", - ): - """Initialize parameter update group.""" - obj = InitWeightsUpdateGroupReqInput( - master_address=master_address, - master_port=master_port, - rank_offset=rank_offset, - world_size=world_size, - group_name=group_name, - backend=backend, - ) - loop = asyncio.get_event_loop() - return loop.run_until_complete( - tokenizer_manager.init_weights_update_group(obj, None) - ) - - def update_weights_from_distributed(self, name, dtype, shape): - """Update weights from distributed source.""" - obj = UpdateWeightsFromDistributedReqInput( - name=name, - dtype=dtype, - shape=shape, - ) - loop = asyncio.get_event_loop() - return loop.run_until_complete( - tokenizer_manager.update_weights_from_distributed(obj, None) - ) - - def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): - """Update weights from distributed source.""" - obj = UpdateWeightsFromTensorReqInput( - serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors) - ) - loop = asyncio.get_event_loop() - return loop.run_until_complete( - tokenizer_manager.update_weights_from_tensor(obj, None) - ) - - def get_weights_by_name(self, name, truncate_size=100): - """Get weights by parameter name.""" - obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) - loop = asyncio.get_event_loop() - return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None)) - - def release_memory_occupation(self): - """Release GPU occupation temporarily""" - obj = ReleaseMemoryOccupationReqInput() - loop = asyncio.get_event_loop() - loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None)) - - def resume_memory_occupation(self): - """Resume GPU occupation""" - obj = ResumeMemoryOccupationReqInput() - loop = asyncio.get_event_loop() - loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None)) - - -class Runtime: - """ - A wrapper for the HTTP server. - This is used for launching the server in a python program without - using the commond line interface. - - It is mainly used for the frontend language. - You should use the Engine class above if you want to do normal offline processing. - """ - - def __init__( - self, - log_level: str = "error", - *args, - **kwargs, - ): - """See the arguments in server_args.py::ServerArgs""" - self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) - - # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() - atexit.register(self.shutdown) - - # Pre-allocate ports - for port in range(self.server_args.port, 40000): - if is_port_available(port): - break - self.server_args.port = port - - self.url = self.server_args.url() - self.generate_url = self.url + "/generate" - - # NOTE: We store pid instead of proc to fix some issues during __delete__ - self.pid = None - pipe_reader, pipe_writer = mp.Pipe(duplex=False) - - proc = mp.Process( - target=launch_server, - args=(self.server_args, pipe_writer), - ) - proc.start() - pipe_writer.close() - self.pid = proc.pid - - try: - init_state = pipe_reader.recv() - except EOFError: - init_state = "" - - if init_state != "ready": - self.shutdown() - raise RuntimeError( - "Initialization failed. Please see the error messages above." 
- ) - - self.endpoint = RuntimeEndpoint(self.url) - - def shutdown(self): - if self.pid is not None: - kill_process_tree(self.pid) - self.pid = None - - def cache_prefix(self, prefix: str): - self.endpoint.cache_prefix(prefix) - - def get_tokenizer(self): - return get_tokenizer( - self.server_args.tokenizer_path, - tokenizer_mode=self.server_args.tokenizer_mode, - trust_remote_code=self.server_args.trust_remote_code, - ) - - async def async_generate( - self, - prompt: str, - sampling_params: Optional[Dict] = None, - ): - if self.server_args.skip_tokenizer_init: - json_data = { - "input_ids": prompt, - "sampling_params": sampling_params, - "stream": True, - } - else: - json_data = { - "text": prompt, - "sampling_params": sampling_params, - "stream": True, - } - pos = 0 - - timeout = aiohttp.ClientTimeout(total=3 * 3600) - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.post(self.generate_url, json=json_data) as response: - async for chunk, _ in response.content.iter_chunks(): - chunk = chunk.decode("utf-8") - if chunk and chunk.startswith("data:"): - if chunk == "data: [DONE]\n\n": - break - data = json.loads(chunk[5:].strip("\n")) - if "text" in data: - cur = data["text"][pos:] - if cur: - yield cur - pos += len(cur) - else: - yield data - - add_request = async_generate - - def generate( - self, - prompt: Union[str, List[str]], - sampling_params: Optional[Dict] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - ): - json_data = { - "text": prompt, - "sampling_params": sampling_params, - "return_logprob": return_logprob, - "logprob_start_len": logprob_start_len, - "top_logprobs_num": top_logprobs_num, - "lora_path": lora_path, - } - assert not isinstance(lora_path, list) or len(lora_path) == len(prompt) - response = requests.post( - self.url + "/generate", - json=json_data, - ) - return json.dumps(response.json()) - - def encode( - self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - ): - json_data = {"text": prompt} - response = requests.post(self.url + "/encode", json=json_data) - return json.dumps(response.json()) - - async def get_server_info(self): - async with aiohttp.ClientSession() as session: - async with session.get(f"{self.url}/get_server_info") as response: - if response.status == 200: - return await response.json() - else: - error_data = await response.json() - raise RuntimeError( - f"Failed to get server info. {error_data['error']['message']}" - ) - - def __del__(self): - self.shutdown() +# Some shortcuts for backward compatibility. +# They will be removed in new versions. 
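With the old implementation deleted, sglang.srt.server is reduced to the thin re-export shim continued just below: per this diff, Engine now lives in sglang.srt.entrypoints.engine, launch_server and kill_process_tree in sglang.srt.entrypoints.http_server, and only those names stay importable from the old path. A hedged sketch of what that means for callers (module paths are taken from this diff; the model path and prompt are illustrative):

# Forward-compatible import (new canonical location per this diff):
from sglang.srt.entrypoints.engine import Engine

# from sglang.srt.server import Engine  # still works through the shim below, slated for removal

if __name__ == "__main__":
    engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model path
    out = engine.generate(
        "The capital of France is",
        sampling_params={"temperature": 0, "max_new_tokens": 8},
    )
    print(out["text"])
    engine.shutdown()

No re-export appears in this diff for the Runtime class; the test-runner changes later in the patch switch those call sites to Engine instead.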
+from sglang.srt.entrypoints.engine import Engine +from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 18f308b59f9..019d784f3cf 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -29,8 +29,8 @@ get_nvgpu_memory_capacity, is_flashinfer_available, is_hip, - is_ipv6, is_port_available, + is_valid_ipv6_address, nullable_str, ) @@ -75,6 +75,7 @@ class ServerArgs: # Other runtime options tp_size: int = 1 stream_interval: int = 1 + stream_output: bool = False random_seed: Optional[int] = None constrained_json_whitespace_pattern: Optional[str] = None watchdog_timeout: float = 300 @@ -168,6 +169,11 @@ class ServerArgs: enable_memory_saver: bool = False allow_auto_truncate: bool = False + # Custom logit processor + enable_custom_logit_processor: bool = False + tool_call_parser: str = None + enable_hierarchical_cache: bool = False + def __post_init__(self): # Set missing default values if self.tokenizer_path is None: @@ -323,6 +329,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "dummy", "gguf", "bitsandbytes", + "layered", ], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' @@ -336,7 +343,10 @@ def add_cli_args(parser: argparse.ArgumentParser): "which is mainly for profiling." '"gguf" will load the weights in the gguf format. ' '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization.", + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -501,6 +511,11 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.stream_interval, help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher", ) + parser.add_argument( + "--stream-output", + action="store_true", + help="Whether to output as a sequence of disjoint segments.", + ) parser.add_argument( "--random-seed", type=int, @@ -912,6 +927,24 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.", ) + parser.add_argument( + "--enable-custom-logit-processor", + action="store_true", + help="Enable users to pass custom logit processors to the server (disabled by default for security)", + ) + # Function Calling + parser.add_argument( + "--tool-call-parser", + type=str, + choices=["qwen25", "mistral", "llama3"], + default=ServerArgs.tool_call_parser, + help="Specify the parser for handling tool-call interactions. 
Options include: 'qwen25', 'mistral', and 'llama3'.", + ) + parser.add_argument( + "--enable-hierarchical-cache", + action="store_true", + help="Enable hierarchical cache", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -922,7 +955,7 @@ def from_cli_args(cls, args: argparse.Namespace): return cls(**{attr: getattr(args, attr) for attr in attrs}) def url(self): - if is_ipv6(self.host): + if is_valid_ipv6_address(self.host): return f"http://[{self.host}]:{self.port}" else: return f"http://{self.host}:{self.port}" diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 1a324000cb2..97cdb264043 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices( class EAGLEDraftInput(SpecInfo): def __init__(self): self.prev_mode = ForwardMode.DECODE - self.sample_output = None self.scores: torch.Tensor = None self.score_list: List[torch.Tensor] = [] @@ -190,12 +189,16 @@ def __init__(self): self.cache_list: List[torch.Tenor] = [] self.iter = 0 + # shape: (b, hidden_size) self.hidden_states: torch.Tensor = None + # shape: (b,) self.verified_id: torch.Tensor = None + # shape: (b, vocab_size) + self.sample_output: torch.Tensor = None + self.positions: torch.Tensor = None self.accept_length: torch.Tensor = None - self.has_finished: bool = False - self.unfinished_index: List[int] = None + self.accept_length_cpu: List[int] = None def load_server_args(self, server_args: ServerArgs): self.topk: int = server_args.speculative_eagle_topk @@ -218,7 +221,7 @@ def prepare_for_extend(self, batch: ScheduleBatch): :pre_len ] = req.prefix_indices - batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( + batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = ( out_cache_loc[pt : pt + req.extend_input_len] ) @@ -228,6 +231,14 @@ def prepare_for_extend(self, batch: ScheduleBatch): assert len(batch.extend_lens) == 1 batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id)) + def filter_batch( + self, + new_indices: torch.Tensor, + ): + self.sample_output = self.sample_output[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] + def prepare_for_decode(self, batch: ScheduleBatch): prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab) top = torch.topk(prob, self.topk, dim=-1) @@ -287,7 +298,9 @@ def prepare_for_decode(self, batch: ScheduleBatch): self.cache_list.append(batch.out_cache_loc) self.positions = ( batch.seq_lens[:, None] - + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter + + torch.full( + [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long + ) ).flatten() bs = len(batch.seq_lens) @@ -304,24 +317,25 @@ def prepare_for_decode(self, batch: ScheduleBatch): def prepare_extend_after_decode(self, batch: ScheduleBatch): batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) - batch.extend_lens = (self.accept_length + 1).tolist() + accept_length_cpu = batch.spec_info.accept_length_cpu + batch.extend_lens = [x + 1 for x in accept_length_cpu] + batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend + seq_lens_cpu = batch.seq_lens.tolist() pt = 0 - seq_lens = batch.seq_lens.tolist() - i = 0 - for req in batch.reqs: if req.finished(): continue # assert seq_len - pre_len == req.extend_input_len - input_len = self.accept_length[i] + 
1 - seq_len = seq_lens[i] + input_len = batch.extend_lens[i] + seq_len = seq_lens_cpu[i] batch.req_to_token_pool.req_to_token[req.req_pool_idx][ seq_len - input_len : seq_len ] = batch.out_cache_loc[pt : pt + input_len] pt += input_len i += 1 + assert pt == batch.out_cache_loc.shape[0] self.positions = torch.empty_like(self.verified_id) new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) @@ -337,7 +351,7 @@ def prepare_extend_after_decode(self, batch: ScheduleBatch): triton.next_power_of_2(self.spec_steps + 1), ) - batch.seq_lens_sum = sum(batch.seq_lens) + batch.seq_lens_sum = sum(seq_lens_cpu) batch.input_ids = self.verified_id self.verified_id = new_verified_id @@ -565,6 +579,8 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten finished_extend_len = {} # {rid:accept_length + 1} accept_index_cpu = accept_index.tolist() predict_cpu = predict.tolist() + has_finished = False + # iterate every accepted token and check if req has finished after append the token # should be checked BEFORE free kv cache slots for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)): @@ -578,7 +594,7 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten finished_extend_len[req.rid] = j + 1 req.check_finished() if req.finished(): - draft_input.has_finished = True + has_finished = True # set all tokens after finished token to -1 and break accept_index[i, j + 1 :] = -1 break @@ -587,12 +603,12 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten if not req.finished(): new_accept_index.extend(new_accept_index_) unfinished_index.append(i) + req.spec_verify_ct += 1 accept_length = (accept_index != -1).sum(dim=1) - 1 accept_index = accept_index[accept_index != -1] accept_length_cpu = accept_length.tolist() verified_id = predict[accept_index] - verified_id_cpu = verified_id.tolist() evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False @@ -614,7 +630,13 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten draft_input.verified_id = predict[new_accept_index] draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index] draft_input.accept_length = accept_length[unfinished_index] - draft_input.unfinished_index = unfinished_index + draft_input.accept_length_cpu = [ + accept_length_cpu[i] for i in unfinished_index + ] + if has_finished: + draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] + else: + draft_input.seq_lens_for_draft_extend = batch.seq_lens logits_output.next_token_logits = logits_output.next_token_logits[accept_index] return ( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 2a6ec96048b..06a4372fce2 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -13,6 +13,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.eagle_utils import EAGLEDraftInput +from sglang.srt.utils import rank0_print class EAGLEWorker(TpModelWorker): @@ -50,18 +51,18 @@ def __init__( def forward_draft_decode(self, batch: ScheduleBatch): batch.spec_info.prepare_for_decode(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - 
forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) self.capture_for_decode(logits_output, forward_batch) def forward_draft_extend(self, batch: ScheduleBatch): self._set_mem_pool(batch, self.model_runner) batch.spec_info.prepare_for_extend(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) self.capture_for_decode(logits_output, forward_batch) self._set_mem_pool(batch, self.target_worker.model_runner) @@ -134,26 +135,23 @@ def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): batch.req_to_token_pool = runner.req_to_token_pool def forward_draft_extend_after_decode(self, batch: ScheduleBatch): + seq_lens_backup = batch.seq_lens + self._set_mem_pool(batch, self.model_runner) batch.forward_mode = ForwardMode.DRAFT_EXTEND - if batch.spec_info.has_finished: - index = batch.spec_info.unfinished_index - seq_lens = batch.seq_lens - batch.seq_lens = batch.seq_lens[index] - batch.spec_info.prepare_extend_after_decode(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) - - batch.spec_info.hidden_states = logits_output.hidden_states self.capture_for_decode(logits_output, forward_batch) - batch.forward_mode = ForwardMode.DECODE - if batch.spec_info.has_finished: - batch.seq_lens = seq_lens self._set_mem_pool(batch, self.target_worker.model_runner) + # Restore backup. 
+ # This is because `seq_lens` can be modified in `prepare_extend_after_decode` + batch.forward_mode = ForwardMode.DECODE + batch.seq_lens = seq_lens_backup + def capture_for_decode( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch ): diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index fb2c699a58e..898f2debb9f 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -14,6 +14,7 @@ """Common utilities.""" import base64 +import ctypes import dataclasses import io import ipaddress @@ -29,6 +30,7 @@ import signal import socket import subprocess +import sys import tempfile import time import warnings @@ -59,7 +61,6 @@ default_dump_dir, default_override_dir, ) -from uvicorn.config import LOGGING_CONFIG logger = logging.getLogger(__name__) @@ -73,7 +74,7 @@ def is_hip() -> bool: def is_cuda(): - return hasattr(torch, "cuda") and torch.cuda.is_available() + return hasattr(torch, "cuda") and torch.version.cuda is not None def is_cuda_alike(): @@ -102,14 +103,6 @@ def is_cuda_available(): return torch.cuda.is_available() and torch.version.cuda -def is_ipv6(address): - try: - ipaddress.IPv6Address(address) - return True - except ipaddress.AddressValueError: - return False - - def enable_show_time_cost(): global show_time_cost show_time_cost = True @@ -475,8 +468,6 @@ def load_image(image_file: Union[str, bytes]): else: raise ValueError(f"Invalid image: {image}") - # if image_size is None: - # image_size = image.size return image, image_size @@ -542,68 +533,24 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N pass -def monkey_patch_vllm_p2p_access_check(gpu_id: int): +def monkey_patch_p2p_access_check(): """ - Monkey patch the slow p2p access check in vllm. + Monkey patch the slow p2p access check. NOTE: We assume the p2p access is always allowed, which can be wrong for some setups. """ - import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt + import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) # Suppress the warnings from this delete function when using sglang.bench_one_batch - from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce + from sglang.srt.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce, + ) setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None) -vllm_all_gather_backup = None - - -def monkey_patch_vllm_all_gather(reverse: bool = False): - """Monkey patch all-gather to remove in-place operations.""" - from torch.distributed import _functional_collectives as funcol - from vllm.distributed.parallel_state import GroupCoordinator - - global vllm_all_gather_backup - if vllm_all_gather_backup is None: - vllm_all_gather_backup = GroupCoordinator.all_gather - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert ( - -input_.dim() <= dim < input_.dim() - ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # Allocate output tensor. 
- output_tensor = torch.empty( - (world_size,) + input_size, dtype=input_.dtype, device=input_.device - ) - - output_tensor = funcol.all_gather_tensor( - input_, gather_dim=0, group=self.device_group - ).view((world_size,) + input_size) - - # Reshape - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape( - input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] - ) - return output_tensor - - if reverse: - setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup) - else: - setattr(GroupCoordinator, "all_gather", all_gather) - - def monkey_patch_vllm_gguf_config(): from vllm.model_executor.layers.quantization.gguf import ( GGUFConfig, @@ -849,7 +796,7 @@ def get_zmq_socket( def dump_to_file(dirpath, name, value): - from vllm.distributed import get_tensor_model_parallel_rank + from sglang.srt.distributed import get_tensor_model_parallel_rank if get_tensor_model_parallel_rank() != 0: return @@ -1286,9 +1233,9 @@ def dataclass_to_string_truncated(data, max_length=2048): if isinstance(data, str): if len(data) > max_length: half_length = max_length // 2 - return f'"{data[:half_length]} ... {data[-half_length:]}"' + return f"{repr(data[:half_length])} ... {repr(data[-half_length:])}" else: - return f'"{data}"' + return f"{repr(data)}" elif isinstance(data, (list, tuple)): if len(data) > max_length: half_length = max_length // 2 @@ -1299,7 +1246,7 @@ def dataclass_to_string_truncated(data, max_length=2048): return ( "{" + ", ".join( - f"{k}: {dataclass_to_string_truncated(v, max_length)}" + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" for k, v in data.items() ) + "}" @@ -1318,68 +1265,6 @@ def dataclass_to_string_truncated(data, max_length=2048): return str(data) -TOOLS_TAG_LIST = ["<|plugin|>", "", "<|python_tag|>"] - - -def parse_tool_response(text, tools, **kwargs): - """Parse model response containing tool information. 
- - Args: - text(str): model response in string format - tools(List): tools from user request - """ - if "<|plugin|>" in text: # internlm2 - text, action = text.split("<|action_start|><|plugin|>") - action = action.split("<|action_end|>".strip())[0] - action = action[action.find("{") :] - action = json.loads(action) - name, parameters = action["name"], json.dumps( - action.get("parameters", action.get("arguments", {})), ensure_ascii=False - ) - call_info_list = [(name, parameters)] - elif "") - parameters = action[action.find("{") :] - name = action.split("{")[0] - call_info_list = [(name, parameters)] - elif "" in text and "" in text: # qwen2.5 - # get tool_call in text - pattern = r"(.*?)" - match_result_list = re.findall(pattern, text, re.DOTALL) - call_info_list = [] - for match_result in match_result_list: - action = json.loads(match_result) - call_info_list.append( - (action["name"], json.dumps(action["arguments"], ensure_ascii=False)) - ) - # get text outside of tags - if not text.startswith(""): - text = text[: text.find("")] - elif not text.endswith(""): - text = text[text.rfind("") + len("") :] - else: - text = "" - elif "<|python_tag|>" in text: # llama3.2 - _, action = text.split("<|python_tag|>") - action = json.loads(action) - name, parameters = action["name"], json.dumps( - action.get("parameters", action.get("arguments", {})), ensure_ascii=False - ) - call_info_list = [(name, parameters)] - else: - raise RuntimeError(f"Unexpected model response: {text}") - - call_info_list = [ - ( - [tool.function.name for tool in tools].index(call_info[0]), - call_info[0], - call_info[1], - ) - for call_info in call_info_list - ] - return text, call_info_list - - def permute_weight(x: torch.Tensor) -> torch.Tensor: b_ = x.shape[0] n_ = x.shape[1] @@ -1442,7 +1327,33 @@ def nullable_str(val: str): return val +def pyspy_dump_schedulers(): + """py-spy dump on all scheduler in a local node.""" + try: + pid = psutil.Process().pid + # Command to run py-spy with the PID + cmd = f"py-spy dump --pid {pid}" + result = subprocess.run( + cmd, shell=True, capture_output=True, text=True, check=True + ) + logger.info(f"Profile for PID {pid}:\n{result.stdout}") + except subprocess.CalledProcessError as e: + logger.info(f"Failed to profile PID {pid}. 
Error: {e.stderr}")
+
+
+def kill_itself_when_parent_died():
+    if sys.platform == "linux":
+        # sigkill this process when parent worker manager dies
+        PR_SET_PDEATHSIG = 1
+        libc = ctypes.CDLL("libc.so.6")
+        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
+    else:
+        logger.warning("kill_itself_when_parent_died is only supported on Linux.")
+
+
 def set_uvicorn_logging_configs():
+    from uvicorn.config import LOGGING_CONFIG
+
     LOGGING_CONFIG["formatters"]["default"][
         "fmt"
     ] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -1451,3 +1362,102 @@
         "fmt"
     ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
     LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+
+def get_ip() -> str:
+    # The SGLANG_HOST_IP / HOST_IP env vars are optional overrides
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+
+    # IP is not set, try to get it from the network interface
+
+    # try ipv4
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    # try ipv6
+    try:
+        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        # Google's public DNS server, see
+        # https://developers.google.com/speed/public-dns/docs/using#addresses
+        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    warnings.warn(
+        "Failed to get the IP address, using 0.0.0.0 by default. "
+        "The value can be set by the environment variable"
+        " SGLANG_HOST_IP or HOST_IP.",
+        stacklevel=2,
+    )
+    return "0.0.0.0"
+
+
+def get_open_port() -> int:
+
+    port = os.getenv("SGLANG_PORT")
+    if port is not None:
+        port = int(port)
+        while True:
+            try:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    s.bind(("", port))
+                    return port
+            except OSError:
+                port += 1  # Increment port number if already in use
+                logger.info("Port %d is already in use, trying port %d", port - 1, port)
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
+def is_valid_ipv6_address(address: str) -> bool:
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ValueError:
+        return False
+
+
+def rank0_print(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        print(msg, flush=True)
+
+
+def launch_dummy_health_check_server(host, port):
+    import uvicorn
+    from fastapi import FastAPI, Response
+
+    app = FastAPI()
+
+    @app.get("/health")
+    async def health():
+        """Check the health of the http server."""
+        return Response(status_code=200)
+
+    @app.get("/health_generate")
+    async def health_generate():
+        """Check the health of the http server."""
+        return Response(status_code=200)
+
+    uvicorn.run(
+        app,
+        host=host,
+        port=port,
+        timeout_keep_alive=5,
+        loop="uvloop",
+    )
diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py
index f22f9cafaf3..bae0fcf2a49 100644
--- a/python/sglang/test/runners.py
+++ b/python/sglang/test/runners.py
@@ -12,7 +12,6 @@
 # limitations under the License.
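The networking helpers added to python/sglang/srt/utils.py above resolve the advertised IP and port from the SGLANG_HOST_IP / HOST_IP and SGLANG_PORT environment variables before falling back to probing the network interfaces or binding an ephemeral port. A small usage sketch, assuming sglang is installed; the address and port values are examples only:

import os

# Env overrides consulted by get_ip() / get_open_port(); the values are illustrative.
os.environ["SGLANG_HOST_IP"] = "10.0.0.5"
os.environ["SGLANG_PORT"] = "30000"

from sglang.srt.utils import get_ip, get_open_port

print(get_ip())         # "10.0.0.5": the env override wins over interface probing
print(get_open_port())  # 30000, or the next free port above it if 30000 is taken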
# ============================================================================== -import json import multiprocessing as mp import os from dataclasses import dataclass @@ -22,8 +21,8 @@ import torch.nn.functional as F from transformers import AutoModelForCausalLM +from sglang.srt.entrypoints.engine import Engine from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.server import Runtime from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER DEFAULT_PROMPTS = [ @@ -278,7 +277,7 @@ def __init__( ): self.model_type = model_type self.is_generation = model_type == "generation" - self.runtime = Runtime( + self.engine = Engine( model_path=model_path, tp_size=tp_size, dtype=get_dtype_str(torch_dtype), @@ -306,7 +305,7 @@ def forward( top_output_logprobs = [] sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} for i, prompt in enumerate(prompts): - response = self.runtime.generate( + response = self.engine.generate( prompt, lora_path=lora_paths[i] if lora_paths else None, sampling_params=sampling_params, @@ -314,7 +313,6 @@ def forward( logprob_start_len=0, top_logprobs_num=NUM_TOP_LOGPROBS, ) - response = json.loads(response) output_strs.append(response["text"]) top_input_logprobs.append( [ @@ -343,8 +341,7 @@ def forward( top_output_logprobs=top_output_logprobs, ) else: - response = self.runtime.encode(prompts) - response = json.loads(response) + response = self.engine.encode(prompts) if self.model_type == "embedding": logits = [x["embedding"] for x in response] return ModelOutput(embed_logits=logits) @@ -366,20 +363,18 @@ def batch_forward( # the return value contains logprobs from prefill output_strs = [] sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} - response = self.runtime.generate( + response = self.engine.generate( prompts, lora_path=lora_paths if lora_paths else None, sampling_params=sampling_params, ) - response = json.loads(response) output_strs = [r["text"] for r in response] return ModelOutput( output_strs=output_strs, ) else: - response = self.runtime.encode(prompts) - response = json.loads(response) + response = self.engine.encode(prompts) if self.model_type == "embedding": logits = [x["embedding"] for x in response] return ModelOutput(embed_logits=logits) @@ -391,8 +386,8 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - self.runtime.shutdown() - del self.runtime + self.engine.shutdown() + del self.engine def monkey_patch_gemma2_sdpa(): diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 219ed3cf6ec..088cb0d0af9 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -535,6 +535,7 @@ def few_shot_hellaswag(s, question, choices): # Compute accuracy accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels)) + print(f"{accuracy=}, {accuracy_gen=}") assert np.abs(accuracy_gen - accuracy) < 0.05 assert np.abs(latency_gen - latency) < 1 diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index c1437074f67..b303f19121d 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -34,7 +34,7 @@ DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct" DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8" -DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 +DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000 
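The switch from `Runtime` to `Engine` in the runners.py hunks above is more than a rename: `Runtime.generate`/`Runtime.encode` returned JSON strings, while `Engine` returns parsed Python objects (one dict for a single prompt, a list of dicts for a batch), which is why every `response = json.loads(response)` line is dropped. A rough sketch of the new calling convention; the model name and prompt are placeholders, and only the keyword arguments visible in this diff are used:

```python
from sglang.srt.entrypoints.engine import Engine

engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

# Old Runtime API (removed here): generate() returned a JSON string.
#     response = json.loads(runtime.generate(prompt, sampling_params=...))
# New Engine API: generate()/encode() return parsed objects directly.
response = engine.generate(
    "The capital of France is",
    sampling_params={"max_new_tokens": 8, "temperature": 0},
)
print(response["text"])  # single prompt -> one dict; a list of prompts -> a list of dicts
```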
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8" @@ -42,6 +42,9 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4" DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct" +DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf" +DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B" + def is_in_ci(): """Return whether it is in CI runner.""" @@ -132,10 +135,6 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None): return pred -def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None): - raise NotImplementedError() - - def call_generate_guidance( prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None ): @@ -527,6 +526,48 @@ def get_similarities(vec1, vec2): return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0) +def get_benchmark_args( + base_url="", + dataset_name="", + dataset_path="", + tokenizer="", + num_prompts=500, + random_input_len=4096, + random_output_len=2048, + request_rate=float("inf"), + disable_stream=False, + disable_ignore_eos=False, +): + return SimpleNamespace( + backend="sglang", + base_url=base_url, + host=None, + port=None, + dataset_name=dataset_name, + dataset_path=dataset_path, + model=None, + tokenizer=tokenizer, + num_prompts=num_prompts, + sharegpt_output_len=None, + sharegpt_context_len=None, + random_input_len=random_input_len, + random_output_len=random_output_len, + random_range_ratio=0.0, + request_rate=request_rate, + multi=None, + output_file=None, + disable_tqdm=False, + disable_stream=disable_stream, + return_logprob=False, + seed=0, + disable_ignore_eos=disable_ignore_eos, + extra_request_body=None, + apply_chat_template=False, + profile=None, + lora_name=None, + ) + + def run_bench_serving( model, num_prompts, @@ -538,6 +579,7 @@ def run_bench_serving( random_input_len=4096, random_output_len=2048, disable_stream=False, + disable_ignore_eos=False, need_warmup=False, ): # Launch the server @@ -550,32 +592,17 @@ def run_bench_serving( ) # Run benchmark - args = SimpleNamespace( - backend="sglang", + args = get_benchmark_args( base_url=base_url, - host=None, - port=None, dataset_name=dataset_name, dataset_path=dataset_path, - model=None, tokenizer=tokenizer, num_prompts=num_prompts, - sharegpt_output_len=None, - sharegpt_context_len=None, random_input_len=random_input_len, random_output_len=random_output_len, - random_range_ratio=0.0, request_rate=request_rate, - multi=None, - seed=0, - output_file=None, - disable_tqdm=False, disable_stream=disable_stream, - disable_ignore_eos=False, - return_logprob=False, - lora_name=None, - extra_request_body=None, - profile=None, + disable_ignore_eos=disable_ignore_eos, ) try: @@ -591,6 +618,38 @@ def run_bench_serving( return res +def run_bench_serving_multi( + model, + base_url, + other_server_args, + benchmark_args, + need_warmup=False, +): + # Launch the server + process = popen_launch_server( + model, + 
base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_server_args, + ) + + # run benchmark for all + res_l = [] + try: + for args in benchmark_args: + if need_warmup: + warmup_args = copy.deepcopy(args) + warmup_args.num_prompts = 16 + run_benchmark(warmup_args) + + res = run_benchmark(args) + res_l.append((args, res)) + finally: + kill_process_tree(process.pid) + + return res_l + + def run_bench_one_batch(model, other_args): command = [ "python3", diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 98e0f3f4f8d..399427ef34c 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -1,7 +1,6 @@ """Common utilities""" import base64 -import gc import importlib import json import logging @@ -15,7 +14,7 @@ from concurrent.futures import ThreadPoolExecutor from io import BytesIO from json import dumps -from typing import Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import numpy as np import requests @@ -363,3 +362,56 @@ def terminate_process(process): def print_highlight(html_content: str): html_content = str(html_content).replace("\n", "
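`run_bench_serving_multi`, completed above, launches the server once and then runs `run_benchmark` for each configuration produced by `get_benchmark_args`, returning one `(args, result)` pair per configuration. A rough sketch of how a test might drive it; the model name, URL, dataset, and request rates are placeholder values, and the imports assume the helpers live in `sglang.test.test_utils` as patched here:

```python
# Placeholder model/URL/rates; any served model and free port would work.
from sglang.test.test_utils import get_benchmark_args, run_bench_serving_multi

base_url = "http://127.0.0.1:30000"
configs = [
    get_benchmark_args(
        base_url=base_url,
        dataset_name="random",
        num_prompts=200,
        request_rate=rate,
    )
    for rate in (1, 4, 16)
]
results = run_bench_serving_multi(
    "meta-llama/Llama-3.1-8B-Instruct",
    base_url,
    other_server_args=[],
    benchmark_args=configs,
    need_warmup=True,  # replays a 16-prompt copy of each config before measuring
)
for args, res in results:
    print(args.request_rate, res)
```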
") display(HTML(f"{html_content}")) + + +class TypeBasedDispatcher: + def __init__(self, mapping: List[Tuple[Type, Callable]]): + self._mapping = mapping + + def __call__(self, obj: Any): + for ty, fn in self._mapping: + if isinstance(obj, ty): + return fn(obj) + raise ValueError(f"Invalid object: {obj}") + + +def trim_overlap(existing_text, new_chunk): + """ + Finds the largest suffix of 'existing_text' that is a prefix of 'new_chunk' + and removes that overlap from the start of 'new_chunk'. + """ + max_overlap = 0 + max_possible = min(len(existing_text), len(new_chunk)) + for i in range(max_possible, 0, -1): + if existing_text.endswith(new_chunk[:i]): + max_overlap = i + break + return new_chunk[max_overlap:] + + +def stream_and_merge(llm, prompt, sampling_params): + """ + 1) Streams the text, + 2) Removes chunk overlaps, + 3) Returns the merged text. + """ + final_text = "" + for chunk in llm.generate(prompt, sampling_params, stream=True): + chunk_text = chunk["text"] + cleaned_chunk = trim_overlap(final_text, chunk_text) + final_text += cleaned_chunk + return final_text + + +async def async_stream_and_merge(llm, prompt, sampling_params): + """ + Streams tokens asynchronously, removes chunk overlaps, + and yields the cleaned chunk in real time for printing. + """ + final_text = "" + generator = await llm.async_generate(prompt, sampling_params, stream=True) + async for chunk in generator: + chunk_text = chunk["text"] + cleaned_chunk = trim_overlap(final_text, chunk_text) + final_text += cleaned_chunk + yield cleaned_chunk # yield the non-overlapping portion diff --git a/python/sglang/version.py b/python/sglang/version.py index 3a906dbcfff..df12433297b 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.4.1.post6" +__version__ = "0.4.2" diff --git a/scripts/ci_install_rust.sh b/scripts/ci_install_rust.sh index 724207fd782..519155dfbe8 100755 --- a/scripts/ci_install_rust.sh +++ b/scripts/ci_install_rust.sh @@ -1,9 +1,14 @@ #!/bin/bash set -euxo pipefail -# these are required for actix -apt-get update -apt-get install -y libssl-dev pkg-config +# Check if sudo is available +if command -v sudo >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y libssl-dev pkg-config +else + apt-get update + apt-get install -y libssl-dev pkg-config +fi # Install rustup (Rust installer and version manager) curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y diff --git a/scripts/deprecated/test_jump_forward.py b/scripts/deprecated/test_jump_forward.py index 60074a04005..315a50b5ba7 100644 --- a/scripts/deprecated/test_jump_forward.py +++ b/scripts/deprecated/test_jump_forward.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, constr import sglang as sgl -from sglang.srt.constrained import build_regex_from_object +from sglang.srt.constrained.outlines_backend import build_regex_from_object from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh index 53d08703e01..163a60f184b 100755 --- a/scripts/killall_sglang.sh +++ b/scripts/killall_sglang.sh @@ -1,5 +1,14 @@ #!/bin/bash +# Check if sudo is available +if command -v sudo >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y lsof +else + apt-get update + apt-get install -y lsof +fi + # Show current GPU status nvidia-smi diff --git a/scripts/update_kernel_whl_index.py b/scripts/update_kernel_whl_index.py new file mode 100644 index 00000000000..a42969641f5 --- 
/dev/null +++ b/scripts/update_kernel_whl_index.py @@ -0,0 +1,16 @@ +# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py + +import hashlib +import pathlib +import re + +for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")): + with open(path, "rb") as f: + sha256 = hashlib.sha256(f.read()).hexdigest() + ver = re.findall(r"sgl_kernel-([0-9.]+(?:\.post[0-9]+)?)-", path.name)[0] + index_dir = pathlib.Path(f"sgl-whl/cu118/sgl-kernel") + index_dir.mkdir(exist_ok=True) + base_url = "https://github.com/sgl-project/whl/releases/download" + full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}" + with (index_dir / "index.html").open("a") as f: + f.write(f'{path.name}
\n') diff --git a/sgl-kernel/3rdparty/cccl b/sgl-kernel/3rdparty/cccl new file mode 160000 index 00000000000..b5fe509fd11 --- /dev/null +++ b/sgl-kernel/3rdparty/cccl @@ -0,0 +1 @@ +Subproject commit b5fe509fd11a925f90d6495176707cc1184eed9d diff --git a/sgl-kernel/3rdparty/cub b/sgl-kernel/3rdparty/cub deleted file mode 160000 index 0fc3c370163..00000000000 --- a/sgl-kernel/3rdparty/cub +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 0fc3c3701632a4be906765b73be20a9ad0da603d diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer new file mode 160000 index 00000000000..4f1f08989c7 --- /dev/null +++ b/sgl-kernel/3rdparty/flashinfer @@ -0,0 +1 @@ +Subproject commit 4f1f08989c71f92df181e346548c2ca48ae6daf5 diff --git a/sgl-kernel/3rdparty/turbomind b/sgl-kernel/3rdparty/turbomind new file mode 160000 index 00000000000..0c9d0c724a9 --- /dev/null +++ b/sgl-kernel/3rdparty/turbomind @@ -0,0 +1 @@ +Subproject commit 0c9d0c724a99974ca3af0c12b24ef8a0444c4fd9 diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index fac4c5c56c8..1384f1bcd81 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -1,7 +1,7 @@ -.PHONY: tree ln submodule install build clean test format +.PHONY: tree ln submodule install build clean rebuild test format tree: - @tree --prune -I "__pycache__|*.egg-info|*.so|build" + @tree --prune -I "__pycache__|*.egg-info|*.so|build|3rdparty|dist" submodule: @git submodule update --init --recursive @@ -13,13 +13,16 @@ install: submodule @pip install -e . build: submodule - @export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel + @rm -rf dist/* || true && export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel && pip3 install dist/*whl --force-reinstall --no-deps clean: @rm -rf build dist *.egg-info +rebuild: clean submodule build + @echo "Succeed to rebuild" + test: - @pytest tests/ + @find tests -name "test_*.py" | xargs -n 1 python3 format: @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index 857cae366d8..0572f9758ab 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -1,5 +1,19 @@ # SGL Kernel -Kernel Library for SGLang +[Kernel Library](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) for SGLang [![PyPI](https://img.shields.io/pypi/v/sgl-kernel)](https://pypi.org/project/sgl-kernel) + +## Installation + +For CUDA 11.8: + +```bash +pip3 install sgl-kernel -i https://docs.sglang.ai/whl/cu118 +``` + +For CUDA 12.1 or CUDA 12.4: + +```bash +pip3 install sgl-kernel +``` diff --git a/sgl-kernel/THIRDPARTYNOTICES.txt b/sgl-kernel/THIRDPARTYNOTICES.txt new file mode 100644 index 00000000000..c930aa5dd3d --- /dev/null +++ b/sgl-kernel/THIRDPARTYNOTICES.txt @@ -0,0 +1,225 @@ +Notice for flashinfer-ai/flashinfer +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +------------------------------------------------------------------------------------------------- +Some of the code in this project are adapted from other open-source projects with different +licenses. This product also bundles some third-party components under other open source licenses. +This section summarizes those components and their licenses. +See licenses/ for text of these licenses. + +BSD 3-Clause License +-------------------- + +include/flashinfer/attention/hopper/epilogue.cuh +include/flashinfer/attention/hopper/mainloop.cuh +include/flashinfer/attention/hopper/kernel_traits.cuh +include/flashinfer/attention/hopper/named_barrier.cuh +include/flashinfer/attention/hopper/tile_scheduler.cuh +include/flashinfer/attention/hopper/utils.cuh + +BSD 3-Clause "New" License +-------------------------- + +3rdparty/cutlass +include/flashinfer/attention/hopper/block_sparse_gather.cuh diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py new file mode 100644 index 00000000000..c3f80475356 --- /dev/null +++ b/sgl-kernel/benchmark/bench_fp8_gemm.py @@ -0,0 +1,164 @@ +import argparse +import copy +import itertools + +import torch +import triton +from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant + +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], + x_log=False, + line_arg="provider", + line_vals=[ + "vllm-fp8-fp16", + "vllm-fp8-bf16", + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ], + line_names=[ + "vllm-fp8-fp16", + "vllm-fp8-bf16", + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ], + styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], + ylabel="GB/s", + plot_name="fp8 scaled matmul", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + # M, N, K = batch_size, 4096, 8192 + M = batch_size + a = torch.ones((M, K), device="cuda") * 5.0 + b = torch.ones((N, K), device="cuda") * 5.0 + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = 
torch.randn((N,), device="cuda", dtype=torch.float32) + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + b_fp8 = b_fp8.t() + quantiles = [0.5, 0.2, 0.8] + + dtype = torch.float16 if "fp16" in provider else torch.bfloat16 + + if "vllm-fp8" in provider: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype), + quantiles=quantiles, + ) + elif "sglang-fp8" in provider: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sgl_scaled_mm( + a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None + ), + quantiles=quantiles, + ) + + gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run( + print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K + ) + + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_int8_gemm.py b/sgl-kernel/benchmark/bench_int8_gemm.py index 2657c616cf3..c5a709393c1 100644 --- a/sgl-kernel/benchmark/bench_int8_gemm.py +++ b/sgl-kernel/benchmark/bench_int8_gemm.py @@ -1,3 +1,7 @@ +import argparse +import copy +import itertools + import torch import triton from sgl_kernel import int8_scaled_mm @@ -8,6 +12,56 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], @@ -22,8 +76,8 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: args={}, ) ) -def benchmark(batch_size, provider): - M, N, K = batch_size, 4096, 8192 +def 
benchmark(batch_size, provider, N, K): + M = batch_size a = to_int8(torch.randn((M, K), device="cuda") * 5) b = to_int8(torch.randn((N, K), device="cuda").t() * 5) scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) @@ -52,4 +106,41 @@ def benchmark(batch_size, provider): return gbps(ms), gbps(max_ms), gbps(min_ms) -benchmark.run(print_data=True, show_plots=True, save_path="bench_int8_res") +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run( + print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K + ) + + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_lightning_attention_decode.py b/sgl-kernel/benchmark/bench_lightning_attention_decode.py new file mode 100644 index 00000000000..24872e61a4d --- /dev/null +++ b/sgl-kernel/benchmark/bench_lightning_attention_decode.py @@ -0,0 +1,299 @@ +import itertools +import math + +import torch +import triton +import triton.language as tl +from sgl_kernel import lightning_attention_decode + + +def next_power_of_2(n): + return 2 ** (int(math.ceil(math.log(n, 2)))) + + +@triton.jit +def _decode_kernel( + Q, + K, + V, + KV, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + d_original: tl.constexpr, + e: tl.constexpr, + e_original: tl.constexpr, +): + off_bh = tl.program_id(0) + off_h = off_bh % h + + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + kv_offset = off_bh * d * e + + s = tl.load(S + off_h) + ratio = tl.exp(-s) + + d_idx = tl.arange(0, d) + e_idx = tl.arange(0, e) + + # Create masks for original dimensions + d_mask = d_idx < d_original + e_mask = e_idx < e_original + + # Load with masking + q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0) + k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0) + v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0) + + # Load KV with 2D masking + kv = tl.load( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + mask=(d_mask[:, None] & e_mask[None, :]), + other=0.0, + ) + + # Compute outer product using element-wise operations + k_v_prod = k[:, None] * v[None, :] + kv = ratio * kv + k_v_prod + + # Store KV with 2D masking + tl.store( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + kv.to(KV.dtype.element_ty), + mask=(d_mask[:, None] & e_mask[None, :]), + ) + + # Compute matrix-vector multiplication using element-wise operations and reduction + o = tl.sum(q[:, None] * kv, axis=0) + + # Store output with masking + tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask) + + +def triton_lightning_attn_decode(q, k, v, kv, s): + """Triton implementation of Lightning Attention decode operation""" + b, h, n, d = q.shape + e = 
v.shape[-1] + assert n == 1, "Sequence length must be 1 in decode mode" + + # Get padded dimensions (power of 2) + d_padded = next_power_of_2(d) + e_padded = next_power_of_2(e) + + # Create output tensor (padded) + o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + + # Create padded tensors without actually padding the data + q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device) + k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device) + v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + kv_padded = torch.empty( + b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device + ) + + # Copy data to padded tensors + q_padded[..., :d] = q + k_padded[..., :d] = k + v_padded[..., :e] = v + kv_padded[..., :d, :e] = kv + + # Launch kernel + grid = (b * h, 1) + _decode_kernel[grid]( + q_padded, + k_padded, + v_padded, + kv_padded, + o_padded, + s, + b=b, + h=h, + n=n, + d=d_padded, + d_original=d, + e=e_padded, + e_original=e, + ) + + # Get unpadded outputs + o = o_padded[..., :e] + kv_out = kv_padded[..., :d, :e] + + return o, kv_out + + +def lightning_attention_decode_naive(q, k, v, past_kv, slope): + """Naive implementation of lightning attention decode""" + original_dtype = q.dtype + ratio = torch.exp(-slope) # [h, 1, 1] + + kv = past_kv + b, h, n, d = q.shape + + output = [] + for i in range(n): + kv = ratio * kv.to(torch.float32) + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... n d", + q[:, :, i : i + 1].to(torch.float32), + kv.to(torch.float32), + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + return output.to(original_dtype), kv + + +def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv): + return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + +def calculate_diff(batch_size): + dtype = torch.bfloat16 + device = torch.device("cuda") + num_heads = 64 + head_dim = 96 + seq_len = 1 + + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + k = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + output_naive, new_kv_naive = lightning_attention_decode_naive( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ) + + output_kernel = torch.empty_like(output_naive) + new_kv_kernel = torch.empty_like(new_kv_naive) + lightning_attention_decode_kernel( + q.clone(), + k.clone(), + v.clone(), + past_kv.clone(), + slope.clone(), + output_kernel, + new_kv_kernel, + ) + + output_triton, new_kv_triton = triton_lightning_attn_decode( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ) + + if ( + torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2) + and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2) + and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2) + and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2) + ): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [i for i in range(1, 65)] # 1 to 128 +configs = [(bs,) for bs in batch_size_range] + + +@triton.testing.perf_report( + 
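# All three providers benchmarked below implement the same decode-step recurrence,
# which lightning_attention_decode_naive above spells out explicitly:
#     kv_t = exp(-slope) * kv_{t-1} + k_t^T @ v_t
#     o_t  = q_t @ kv_t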
triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["naive", "kernel", "triton"], + line_names=["PyTorch Naive", "SGL Kernel", "Triton"], + styles=[("blue", "-"), ("red", "-"), ("green", "-")], + ylabel="us", + plot_name="lightning-attention-decode-performance", + args={}, + ) +) +def benchmark(batch_size, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + num_heads = 64 + head_dim = 96 + seq_len = 1 + + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + k = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "naive": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lightning_attention_decode_naive( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ), + quantiles=quantiles, + ) + elif provider == "kernel": + output = torch.empty( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lightning_attention_decode_kernel( + q.clone(), + k.clone(), + v.clone(), + past_kv.clone(), + slope.clone(), + output, + new_kv, + ), + quantiles=quantiles, + ) + elif provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_lightning_attn_decode( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/lightning_attention_decode_sgl/", + help="Path to save lightning attention decode benchmark results", + ) + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=4) + + # Run performance benchmark + benchmark.run(print_data=True) diff --git a/sgl-kernel/benchmark/bench_sampling_scaling_penalties.py b/sgl-kernel/benchmark/bench_sampling_scaling_penalties.py deleted file mode 100644 index 000dab0d8e9..00000000000 --- a/sgl-kernel/benchmark/bench_sampling_scaling_penalties.py +++ /dev/null @@ -1,159 +0,0 @@ -import itertools - -import torch -import triton -from sgl_kernel import sampling_scaling_penalties - - -def sampling_scaling_penalties_naive(logits, scaling_penalties): - return torch.where( - logits > 0, logits / scaling_penalties, logits * scaling_penalties - ) - - -def sampling_scaling_penalties_kernel(logits, scaling_penalties): - return sampling_scaling_penalties(logits, scaling_penalties) - - -def test_memory(func, _iter): - total_mem = [] - - for _ in range(_iter): - torch.cuda.memory.reset_peak_memory_stats() - func() - mem = torch.cuda.max_memory_allocated() / (2**20) - total_mem.append(mem) - - return sum(total_mem) / len(total_mem) - - -def calculate_diff(batch_size, vocab_size): - dtype = torch.bfloat16 - device = torch.device("cuda") - - logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype) - scaling_penalties = ( - torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5 - ) - - output_naive = 
sampling_scaling_penalties_naive( - logits.clone(), scaling_penalties.clone() - ) - output_kernel = sampling_scaling_penalties_kernel( - logits.clone(), scaling_penalties.clone() - ) - - print(f"Naive output={output_naive}") - print(f"Kernel output={output_kernel}") - - if torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2): - print("✅ Both implementations match") - else: - print("❌ Implementations differ") - - -batch_size_range = [2**i for i in range(0, 12)] -vocab_size_range = [2**i for i in range(10, 17)] -configs = list(itertools.product(batch_size_range, vocab_size_range)) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size", "vocab_size"], - x_vals=[list(_) for _ in configs], - line_arg="provider", - line_vals=["naive", "kernel"], - line_names=["PyTorch Naive", "SGL Kernel"], - styles=[("blue", "-"), ("red", "-")], - ylabel="us", - plot_name="sampling-scaling-penalties-performance", - args={}, - ) -) -def benchmark(batch_size, vocab_size, provider): - dtype = torch.bfloat16 - device = torch.device("cuda") - - logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype) - scaling_penalties = ( - torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5 - ) - - quantiles = [0.5, 0.2, 0.8] - - if provider == "naive": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: sampling_scaling_penalties_naive( - logits.clone(), - scaling_penalties.clone(), - ), - quantiles=quantiles, - ) - else: - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: sampling_scaling_penalties_kernel( - logits.clone(), - scaling_penalties.clone(), - ), - quantiles=quantiles, - ) - - return 1000 * ms, 1000 * max_ms, 1000 * min_ms - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size", "vocab_size"], - x_vals=[list(_) for _ in configs], - line_arg="provider", - line_vals=["naive", "kernel"], - line_names=["PyTorch Naive", "SGL Kernel"], - styles=[("blue", "-"), ("red", "-")], - ylabel="GPU memory usage (MB)", - plot_name="sampling-scaling-penalties-memory", - args={}, - ) -) -def benchmark_memory(batch_size, vocab_size, provider): - dtype = torch.bfloat16 - device = torch.device("cuda") - - print( - f"Running memory benchmark with batch_size={batch_size}, vocab_size={vocab_size}, provider={provider}" - ) - - def run_kernel(): - logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype) - scaling_penalties = ( - torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5 - ) - - if provider == "naive": - return sampling_scaling_penalties_naive(logits, scaling_penalties) - else: - return sampling_scaling_penalties_kernel(logits, scaling_penalties) - - mem = test_memory(run_kernel, _iter=10) - return mem - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--save_path", - type=str, - default="./configs/benchmark_ops/sampling_scaling_penalties/", - help="Path to save sampling_scaling_penalties benchmark results", - ) - args = parser.parse_args() - - # Run correctness test - calculate_diff(batch_size=4, vocab_size=4096) - - # Run performance benchmark - benchmark.run(print_data=True, save_path=args.save_path) - - # Run memory benchmark - benchmark_memory.run(print_data=True, save_path=args.save_path) diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh index 55ce9df7f33..ffa798d145a 100755 --- a/sgl-kernel/build.sh +++ b/sgl-kernel/build.sh @@ -4,13 +4,23 @@ PYTHON_VERSION=$1 CUDA_VERSION=$2 
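The build script continued below derives whether to enable the Hopper-specific sm_90a code path from the CUDA major version before invoking the manylinux container, and setup.py applies the same rule later in this diff. A minimal sketch of that gate, assuming a "major.minor" version string such as "11.8" or "12.4"; the function name is hypothetical:

```python
# Rough Python equivalent of the shell gate added in build.sh just below; illustrative only.
def enable_sm90a(cuda_version: str) -> int:
    major = int(cuda_version.split(".")[0])
    return 0 if major < 12 else 1  # sm_90a (Hopper) kernels require CUDA 12+


assert enable_sm90a("11.8") == 0
assert enable_sm90a("12.4") == 1
```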
PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.} +if (( ${CUDA_VERSION%.*} < 12 )); then + ENABLE_SM90A=0 +else + ENABLE_SM90A=1 +fi + docker run --rm \ -v "$(pwd)":/sgl-kernel \ pytorch/manylinux-builder:cuda${CUDA_VERSION} \ bash -c " - ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.4.0 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \ + ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \ + ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja && \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \ export CUDA_VERSION=${CUDA_VERSION} && \ + export SGL_KERNEL_ENABLE_BF16=1 && \ + export SGL_KERNEL_ENABLE_FP8=1 && \ + export SGL_KERNEL_ENABLE_SM90A=${ENABLE_SM90A} && \ mkdir -p /usr/lib/x86_64-linux-gnu/ && \ ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \ cd /sgl-kernel && \ diff --git a/sgl-kernel/developer_guide.md b/sgl-kernel/developer_guide.md new file mode 100644 index 00000000000..2b9859d948f --- /dev/null +++ b/sgl-kernel/developer_guide.md @@ -0,0 +1,55 @@ +# Developer Guide for sgl-kernel + +## Development Environment Setup + +Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer/development_guide_using_docker.md#setup-docker-container). + +Create and enter development container: +```bash +docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh +docker exec -it sglang_zhyncs /bin/zsh +``` + +## Project Structure + +### Dependencies + +Third-party libraries: + +- [CCCL](https://github.com/NVIDIA/cccl) +- [CUTLASS](https://github.com/NVIDIA/cutlass) +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) +- [TurboMind](https://github.com/InternLM/turbomind) + +### Kernel Development + +Steps to add a new kernel: + +1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc) +2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h) +3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc) +4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) +5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py) +6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source + +### Build & Install + +Development build: + +```bash +make build +``` + +Note: + +The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`. + +### Testing & Benchmarking + +1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests) +2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark) +3. 
Run test suite + +### Release new version + +Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/version.py) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index b0554bd8fed..aca6f045054 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,22 +4,20 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.2.post14" +version = "0.0.3" description = "Kernel Library for SGLang" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Environment :: GPU :: NVIDIA CUDA" ] -dependencies = [ - "torch", -] +dependencies = [] [project.urls] -"Homepage" = "https://github.com/sgl-project/sglang" +"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel" "Bug Tracker" = "https://github.com/sgl-project/sglang/issues" [tool.setuptools] diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 33e4abe1b23..f887f5c19f0 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -1,19 +1,15 @@ +import multiprocessing +import os from pathlib import Path +import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension root = Path(__file__).parent.resolve() -def get_version(): - with open(root / "pyproject.toml") as f: - for line in f: - if line.startswith("version"): - return line.split("=")[1].strip().strip('"') - - -def update_wheel_platform_tag(): +def _update_wheel_platform_tag(): wheel_dir = Path("dist") if wheel_dir.exists() and wheel_dir.is_dir(): old_wheel = next(wheel_dir.glob("*.whl")) @@ -23,13 +19,48 @@ def update_wheel_platform_tag(): old_wheel.rename(new_wheel) -cutlass = root / "3rdparty" / "cutlass" +def _get_cuda_version(): + if torch.version.cuda: + return tuple(map(int, torch.version.cuda.split("."))) + return (0, 0) + + +def _get_device_sm(): + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + return 0 + + +def _get_version(): + with open(root / "pyproject.toml") as f: + for line in f: + if line.startswith("version"): + return line.split("=")[1].strip().strip('"') + + +operator_namespace = "sgl_kernels" +cutlass_default = root / "3rdparty" / "cutlass" +cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) +flashinfer = root / "3rdparty" / "flashinfer" +turbomind = root / "3rdparty" / "turbomind" include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", + root / "src" / "sgl-kernel" / "include", root / "src" / "sgl-kernel" / "csrc", + flashinfer.resolve() / "include", + flashinfer.resolve() / "include" / "gemm", + flashinfer.resolve() / "csrc", + "cublas", + "cublasLt", + turbomind.resolve(), + turbomind.resolve() / "src", ] + nvcc_flags = [ + "-DNDEBUG", + f"-DOPERATOR_NAMESPACE={operator_namespace}", "-O3", "-Xcompiler", "-fPIC", @@ -37,23 +68,76 @@ def update_wheel_platform_tag(): "-gencode=arch=compute_80,code=sm_80", "-gencode=arch=compute_89,code=sm_89", "-gencode=arch=compute_90,code=sm_90", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF2_OPERATORS__", + "-std=c++17", + "-use_fast_math", + "-DFLASHINFER_ENABLE_F16", + "-Xcompiler=-Wconversion", + "-Xcompiler=-fno-strict-aliasing", +] +nvcc_flags_fp8 = [ + 
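# NOTE: these FP8 defines are not applied unconditionally; further down in this file,
# nvcc_flags is extended with nvcc_flags_fp8 only when the detected GPU is SM90+, or
# when building without a GPU and SGL_KERNEL_ENABLE_FP8=1 is set.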
"-DFLASHINFER_ENABLE_FP8", + "-DFLASHINFER_ENABLE_FP8_E4M3", + "-DFLASHINFER_ENABLE_FP8_E5M2", +] + +sources = [ + "src/sgl-kernel/torch_extension.cc", + "src/sgl-kernel/csrc/trt_reduce_internal.cu", + "src/sgl-kernel/csrc/trt_reduce_kernel.cu", + "src/sgl-kernel/csrc/moe_align_kernel.cu", + "src/sgl-kernel/csrc/int8_gemm_kernel.cu", + "src/sgl-kernel/csrc/fp8_gemm_kernel.cu", + "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", + "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu", + "3rdparty/flashinfer/csrc/activation.cu", + "3rdparty/flashinfer/csrc/bmm_fp8.cu", + "3rdparty/flashinfer/csrc/norm.cu", + "3rdparty/flashinfer/csrc/sampling.cu", + "3rdparty/flashinfer/csrc/renorm.cu", + "3rdparty/flashinfer/csrc/rope.cu", ] + +enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1" +enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1" +enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1" +cuda_version = _get_cuda_version() +sm_version = _get_device_sm() + +if torch.cuda.is_available(): + if cuda_version >= (12, 0) and sm_version >= 90: + nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + if sm_version >= 90: + nvcc_flags.extend(nvcc_flags_fp8) + if sm_version >= 80: + nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") +else: + # compilation environment without GPU + if enable_sm90a: + nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + if enable_fp8: + nvcc_flags.extend(nvcc_flags_fp8) + if enable_bf16: + nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") + +for flag in [ + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", +]: + try: + torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag) + except ValueError: + pass + cxx_flags = ["-O3"] -libraries = ["c10", "torch", "torch_python", "cuda"] +libraries = ["c10", "torch", "torch_python", "cuda", "cublas", "cublasLt"] extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] + ext_modules = [ CUDAExtension( name="sgl_kernel.ops._kernels", - sources=[ - "src/sgl-kernel/csrc/trt_reduce_internal.cu", - "src/sgl-kernel/csrc/trt_reduce_kernel.cu", - "src/sgl-kernel/csrc/moe_align_kernel.cu", - "src/sgl-kernel/csrc/int8_gemm_kernel.cu", - "src/sgl-kernel/csrc/sampling_scaling_penalties.cu", - "src/sgl-kernel/csrc/sgl_kernel_ops.cu", - ], + sources=sources, include_dirs=include_dirs, extra_compile_args={ "nvcc": nvcc_flags, @@ -61,17 +145,22 @@ def update_wheel_platform_tag(): }, libraries=libraries, extra_link_args=extra_link_args, + py_limited_api=True, ), ] setup( name="sgl-kernel", - version=get_version(), + version=_get_version(), packages=find_packages(), package_dir={"": "src"}, ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension}, - install_requires=["torch"], + cmdclass={ + "build_ext": BuildExtension.with_options( + use_ninja=True, max_jobs=multiprocessing.cpu_count() + ) + }, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) -update_wheel_platform_tag() +_update_wheel_platform_tag() diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 0c744982dd8..a3d35072d03 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -1,21 +1,51 @@ from sgl_kernel.ops import ( + apply_rope_with_cos_sin_cache_inplace, + bmm_fp8, custom_dispose, custom_reduce, + fp8_scaled_mm, + fused_add_rmsnorm, + gelu_and_mul, + gelu_tanh_and_mul, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, 
get_graph_buffer_ipc_meta, init_custom_reduce, int8_scaled_mm, + lightning_attention_decode, + min_p_sampling_from_probs, moe_align_block_size, register_graph_buffers, + rmsnorm, sampling_scaling_penalties, + silu_and_mul, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, ) __all__ = [ - "moe_align_block_size", - "init_custom_reduce", + "apply_rope_with_cos_sin_cache_inplace", + "bmm_fp8", "custom_dispose", "custom_reduce", - "int8_scaled_mm", - "sampling_scaling_penalties", + "fp8_scaled_mm", + "fused_add_rmsnorm", + "gelu_and_mul", + "gelu_tanh_and_mul", + "gemma_fused_add_rmsnorm", + "gemma_rmsnorm", "get_graph_buffer_ipc_meta", + "init_custom_reduce", + "int8_scaled_mm", + "lightning_attention_decode", + "min_p_sampling_from_probs", + "moe_align_block_size", "register_graph_buffers", + "rmsnorm", + "sampling_scaling_penalties", + "silu_and_mul", + "top_k_renorm_prob", + "top_k_top_p_sampling_from_probs", + "top_p_renorm_prob", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h index a9deeb9a7da..c83cf49ad83 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h @@ -3,11 +3,8 @@ #pragma once -#include "cutlass/arch/memory.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_conversion.h" +#include +#include namespace cutlass { namespace epilogue { diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h index 10be552a8ec..33e82decc2b 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h @@ -2,16 +2,9 @@ // https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h #pragma once -#include "cutlass/arch/arch.h" -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/default_gemm_universal.h" -#include "cutlass/gemm/kernel/gemm_universal.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -#include "cutlass/numeric_types.h" -#include "cutlass/trace.h" +#include +#include +#include //////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h index cf0b9cfa3e9..674e191a077 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h @@ -3,14 +3,11 @@ #pragma once -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" -#include "cutlass/trace.h" -#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h" 
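The expanded `__all__` above is the new public surface of the package. A quick smoke test of that surface (assumes an installed sgl-kernel wheel; op names are taken verbatim from `__all__`):

```python
import sgl_kernel

expected = {
    "fp8_scaled_mm",
    "int8_scaled_mm",
    "fused_add_rmsnorm",
    "lightning_attention_decode",
    "apply_rope_with_cos_sin_cache_inplace",
    "top_k_top_p_sampling_from_probs",
}
missing = expected - set(sgl_kernel.__all__)
assert not missing, f"missing ops: {missing}"
print(f"{len(sgl_kernel.__all__)} ops exported")
```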
+#include +#include +#include +#include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu new file mode 100644 index 00000000000..3e33e143c0c --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu @@ -0,0 +1,624 @@ +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +using namespace cute; + +#if defined CUDA_VERSION && CUDA_VERSION >= 12040 +template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT, + typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>> +struct DeviceGemmFp8RowwiseSm89 { + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + using ElementA = ElementType; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementType; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = OutElementType; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementOutput = OutElementType; + using LayoutOutput = cutlass::layout::RowMajor; + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = AccumElementType; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm89; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + // Number of epilogue stages in EVT + static constexpr int EVTEpilogueStages = 1; + + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout; + + // Definition of EVT + using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>; + using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; + using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeAScale = + cutlass::epilogue::threadblock::VisitorCompute; + using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast>; + using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT; + + // With bias + using biasSrc = + cutlass::epilogue::threadblock::VisitorRowBroadcast>; + using ComputeAScaleWithBias = + cutlass::epilogue::threadblock::VisitorCompute; + using EpilogueAScaleWithBias = + cutlass::epilogue::threadblock::Sm80EVT; + + using dTar = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, 
ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride>; + using EpilogueStore = + typename cutlass::platform::conditional, + cutlass::epilogue::threadblock::Sm80EVT>::type; + + using EpilogueOp = EpilogueStore; + + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB, + cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator, + ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp, + ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); + ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput const* ptr_bias = nullptr; + if constexpr (WithBias) { + TORCH_CHECK(bias.has_value()) + ptr_bias = reinterpret_cast(bias.value().data_ptr()); + } + ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); + ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); + ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); + + typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode + {m, n, k}, // Problem size + 1, // Split-k factor + {}, // Epilogue args + ptr_a, // a pointer + ptr_b, // b pointer + nullptr, // c pointer (unused) + nullptr, // d pointer (unused) + m * k, // batch stride a (unused) + n * k, // batch stride b (unused) + m * n, // batch stride c (unused) + m * n, // batch stride d (unused) + lda, // stride a + ldb, // stride b + ldc, // stride c (unused) + ldc); // stride d (unused) + if constexpr (WithBias) { + args.epilogue = {{ + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; + } else { + args.epilogue = {{ + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; + } + + return args; +} + +template +void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + auto args = prepare_sm89_fp8_args(out, a, b, scales_a, scales_b, bias); + Gemm gemm_op; + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + 
TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + + auto status = gemm_op(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess) +} + +template +void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + if (bias) { + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } else { + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } +} + +template +void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + uint32_t const m = a.size(0); + uint32_t const n = out.size(1); + + if (m == 1) { + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 16) { + // M in (1, 16] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 64) { + // M in (16, 64] + if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 128) { + // M in (64, 128] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 256) { + // M in (128, 256] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 512) { + // M in (256, 512) + if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } + } else { + // M in (512, inf) + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, 
scales_a, scales_b, bias); + } + } +} +#endif + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +template +struct DeviceGemmFp8RowwiseSm90 { + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + // A matrix configuration + using ElementA = ElementType; // Element type for A matrix operand + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = ElementType; // Element type for B matrix operand + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementC = void; // Element type for C matrix operands + using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrices in + // units of elements (up to 16 bytes) + + // Output matrix configuration + using ElementOutput = OutElementType; // Element type for output matrix operands + using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // // Auxiliary matrix configuration and other fusion types + // using ElementBias = float; + + // Multiply-accumulate blocking/pipelining details + using ElementAccumulator = AccumElementType; // Element type for internal accumulation + using ElementCompute = float; // Element type for compute + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + + static constexpr bool PONG = false; + static constexpr bool FAST_ACCUM = true; + static constexpr bool USE_BIAS = false; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default + // setting in the Collective Builder + // Implement rowwise scaling epilogue. 
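Both the SM89 visitor tree above and the SM90 rowwise-scaling epilogue introduced here compute the same thing: the float accumulator of the FP8 GEMM is scaled per output column by `scales_b`, per output row by `scales_a`, and optionally shifted by a per-column bias before the cast to fp16/bf16. A PyTorch reference of that epilogue math (float emulation only, to make the composition explicit; the helper name is illustrative):

```python
import torch

def fp8_rowwise_epilogue_reference(a, b, scales_a, scales_b, bias=None, out_dtype=torch.bfloat16):
    """D = diag(scales_a) @ (A @ B) @ diag(scales_b) (+ bias), then cast to out_dtype.

    a and b stand in for the dequantized FP8 operands; the kernel accumulates
    A @ B in float and applies the scales in the epilogue.
    """
    acc = a @ b                          # float accumulator
    acc = acc * scales_b.view(1, -1)     # per-output-column weight scale (row broadcast)
    acc = acc * scales_a.view(-1, 1)     # per-output-row activation scale (column broadcast)
    if bias is not None:
        acc = acc + bias.view(1, -1)     # per-output-column bias
    return acc.to(out_dtype)

m, k, n = 4, 8, 16
out = fp8_rowwise_epilogue_reference(
    torch.randn(m, k), torch.randn(k, n), torch.rand(m), torch.rand(n), bias=torch.randn(n)
)
print(out.shape, out.dtype)
```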
+ using XScale = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = + cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; + + // With bias + using ComputeWithBias = + cutlass::epilogue::fusion::Sm90Compute; + using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = typename cutlass::platform::conditional::type; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC, + AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + + using SlowAccum = DefaultSchedule; + using FastAccum = FastPongSchedule; // Default apply Pingpong + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduleType>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); + ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput const* ptr_bias = nullptr; + if constexpr (WithBias) { + TORCH_CHECK(bias.has_value()) + ptr_bias = reinterpret_cast(bias.value().data_ptr()); + } + ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); + ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); + ElementComputeEpilogue const* ptr_scales_b = 
reinterpret_cast(scales_b.data_ptr()); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); + StrideC stride_c; + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); + typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {ptr_a, stride_a, ptr_b, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + ptr_d, + stride_d}}; + if constexpr (WithBias) { + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, + {}, // Accumulator + {} // Multiplies + }, + {ptr_bias}, + {}, // Multiplies + }; + } else { + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + return args; +} + +template +void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + auto args = prepare_sm90_fp8_args(out, a, b, scales_a, scales_b, bias); + Gemm gemm_op; + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + + auto status = gemm_op.run(args, workspace.data_ptr(), stream); + + TORCH_CHECK(status == cutlass::Status::kSuccess) +} + +template +void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias, bool fast_accum = true, + bool use_persistent = false) { + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; + + if (bias) { + using Gemm = + typename DeviceGemmFp8RowwiseSm90::Gemm; + return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } else { + using Gemm = + typename DeviceGemmFp8RowwiseSm90::Gemm; + return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } +} + +template +void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + uint32_t const m = a.size(0); + using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using PersistentTileScheduler = cutlass::gemm::PersistentScheduler; + using BasicTileScheduler = void; + if (m <= 1) { + return sm90_fp8_dispatch_bias, Shape<_1, _8, _1>, FastBasicScheduler, + BasicTileScheduler>(out, a, b, scales_a, scales_b, bias); + } + if (m <= 64) { + // m in [1, 64] + return sm90_fp8_dispatch_bias, Shape<_1, _4, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else if (m <= 256) { + // m in (64, 256] + return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else if (m <= 1024) { + // m in (256, 1024] + return sm90_fp8_dispatch_bias, 
Shape<_1, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else { + // m in (1024, inf) + return sm90_fp8_dispatch_bias, Shape<_2, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } +} +#endif + +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias) { + TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); + TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); + TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); + TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); + TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); + TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); + + TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, + "mat_a must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, + "mat_b must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); + TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); + TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); + + TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched"); + TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched"); + TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous"); + TORCH_CHECK(scales_b.is_contiguous(), "scales_b msut be contiguous"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32"); + + if (bias) { + TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched"); + TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous"); + TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype"); + } + + torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype)); + TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment"); + + auto sm_version = getSMVersion(); + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + if (sm_version >= 90) { + if (out_dtype == torch::kBFloat16) { + sm90_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm90_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } + return out; + } +#endif + +#if defined CUDA_VERSION && CUDA_VERSION >= 12040 + if (sm_version == 89) { + if (out_dtype == torch::kBFloat16) { + sm89_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm89_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } + return out; + } +#endif + + TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu new file mode 100644 index 00000000000..4c4ecb966ee --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu @@ -0,0 +1,140 @@ 
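A usage sketch for the `fp8_scaled_mm` entry point above. It assumes the Python binding exported from `sgl_kernel` forwards to this function with the same argument order, and it needs an SM89/SM90 GPU with a matching CUDA toolkit; shapes respect the row-major A / column-major B and 16-byte alignment checks enforced above:

```python
import torch
from sgl_kernel import fp8_scaled_mm

m, k, n = 32, 512, 1024
a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)        # row-major activations
b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()    # (k, n), column-major weights
scales_a = torch.rand(m, device="cuda", dtype=torch.float32)        # one scale per output row
scales_b = torch.rand(n, device="cuda", dtype=torch.float32)        # one scale per output column
bias = torch.randn(n, device="cuda", dtype=torch.bfloat16)          # must match out_dtype

out = fp8_scaled_mm(a, b, scales_a, scales_b, torch.bfloat16, bias) # (m, n) bfloat16
```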
+// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh +// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu +// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0 + +#include + +#include +#include +#include +#include + +#include "utils.h" + +using namespace flashinfer; + +template +__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight, + const uint32_t d, float eps) { + const uint32_t bx = blockIdx.x; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t warp_size = 32; + const uint32_t num_warps = blockDim.y; + const uint32_t thread_id = tx + ty * warp_size; + const uint32_t num_threads = num_warps * warp_size; + const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); + extern __shared__ float smem[]; + + float sum_sq = 0.f; + + for (uint32_t i = 0; i < rounds; i++) { + vec_t input_vec; + input_vec.fill(0.f); + vec_t residual_vec; + residual_vec.fill(0.f); + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + } +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j++) { + float x = float(input_vec[j]); + x += float(residual_vec[j]); + sum_sq += x * x; + residual_vec[j] = (T)x; + } + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + } + } + + // first, warp reduce sum +#pragma unroll + for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { + sum_sq += math::shfl_xor_sync(sum_sq, offset); + } + + smem[ty] = sum_sq; + __syncthreads(); + // then, cross warp reduce sum using only the first warp + if (ty == 0) { + sum_sq = (tx < num_warps) ? 
smem[tx] : 0.f; +#pragma unroll + for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { + sum_sq += math::shfl_xor_sync(sum_sq, offset); + } + smem[0] = sum_sq; + } + __syncthreads(); + + float rms_rcp = math::rsqrt(smem[0] / float(d) + eps); + + for (uint32_t i = 0; i < rounds; i++) { + vec_t input_vec; + vec_t weight_vec; + vec_t residual_vec; + input_vec.fill(0.f); + weight_vec.fill(0.f); + residual_vec.fill(0.f); + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + } +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j++) { + input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]); + } + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + } + } +} + +template +cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5, + cudaStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t block_size = std::min(1024, d / vec_size); + const uint32_t num_warps = ceil_div(block_size, 32); + dim3 nblks(batch_size); + dim3 nthrs(32, num_warps); + const uint32_t smem_size = num_warps * sizeof(float); + void* args[] = {&input, &residual, &weight, &d, &eps}; + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = FusedAddRMSNormKernel; + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + + return cudaSuccess; +} + +void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) { + CHECK_INPUT(input); + CHECK_INPUT(residual); + CHECK_INPUT(weight); + auto device = input.device(); + CHECK_EQ(residual.device(), device); + CHECK_EQ(weight.device(), device); + CHECK_DIM(2, input); // input: (batch_size, hidden_size) + CHECK_DIM(2, residual); // residual: (batch_size, hidden_size) + CHECK_DIM(1, weight); // weight: (hidden_size) + CHECK_EQ(input.size(0), residual.size(0)); + CHECK_EQ(input.size(1), residual.size(1)); + CHECK_EQ(input.size(1), weight.size(0)); + unsigned int batch_size = input.size(0); + unsigned int hidden_size = input.size(1); + + cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream(); + // support float16, bfloat16 and float32 + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + cudaError_t status = + FusedAddRMSNorm(static_cast(input.data_ptr()), static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); + return true; + }); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu index cce32c2d894..c77851c32b6 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu @@ -3,12 +3,22 @@ #include #include #include +#include #include +#include +#include +#include +#include +#include +#include + #include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h" #include "cutlass_extensions/gemm/gemm_universal_base_compat.h" #include 
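`sgl_fused_add_rmsnorm` above updates both buffers in place: `residual` receives `input + residual`, and `input` receives the RMS-normalized residual scaled by `weight`. An out-of-place PyTorch reference that can be used to sanity-check the kernel (test-only sketch; the kernel itself needs a CUDA device):

```python
import torch

def fused_add_rmsnorm_reference(x, residual, weight, eps=1e-6):
    """Returns (normed, new_residual); the CUDA kernel writes these into input/residual in place."""
    residual = x + residual                                            # fused residual add
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(variance + eps)
    return (normed * weight.float()).to(x.dtype), residual

batch, hidden = 4, 4096
x = torch.randn(batch, hidden, dtype=torch.float16)
res = torch.randn(batch, hidden, dtype=torch.float16)
w = torch.randn(hidden, dtype=torch.float16)
out, new_res = fused_add_rmsnorm_reference(x, res, w)
print(out.shape, new_res.shape)
```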
"cutlass_extensions/gemm/gemm_with_epilogue_visitor.h" -#include "utils.hpp" +#include "utils.h" + +using namespace cute; template @@ -166,6 +176,186 @@ void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const t } } +template +void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ArchTag = cutlass::arch::Sm90; + + using ElementAccumulator = int32_t; + using ElementCompute = float; + using ElementInputA = int8_t; + using ElementInputB = int8_t; + + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; + using TileSchedulerType = cutlass::gemm::PersistentScheduler; + + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, + Stride, Int<0>, Int<0>>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, + Stride, Int<1>, Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, + Stride, Int<1>, Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + // Scale + using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; + + // With bias + using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute; + using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = typename cutlass::platform::conditional::type; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, ElementOutput, cutlass::layout::RowMajor, AlignmentC, ElementOutput, + cutlass::layout::RowMajor, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; + + using Stages = cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementInputA, cutlass::layout::RowMajor, AlignmentA, ElementInputB, + cutlass::layout::ColumnMajor, AlignmentB, ElementAccumulator, TileShape, ClusterShape, Stages, + MainloopScheduleType>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + Gemm gemm_op; + + int m = mat_a.size(0); + int k = mat_a.size(1); + int n = mat_b.size(1); + + auto a_ptr = static_cast(mat_a.data_ptr()); + auto b_ptr = static_cast(mat_b.data_ptr()); + auto o_ptr = static_cast(out.data_ptr()); + + auto a_s_ptr = static_cast(scales_a.data_ptr()); + auto b_s_ptr = static_cast(scales_b.data_ptr()); + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using 
StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); + StrideC stride_c; + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); + + typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {a_ptr, stride_a, b_ptr, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + o_ptr, + stride_d}}; + + if constexpr (WithBias) { + ElementOutput* bias_ptr = static_cast(bias->data_ptr()); + args.epilogue.thread = { + {a_s_ptr}, + {{b_s_ptr}, {}, {}}, + {bias_ptr}, + {}, + }; + } else { + args.epilogue.thread = { + {a_s_ptr}, + {{b_s_ptr}, {}, {}}, + {}, + }; + } + + auto workspace = torch::empty(gemm_op.get_workspace_size(args), + torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); + + auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess, + "gemm cannot implement, error: ", cutlassGetStatusString(can_implement)); + + auto status = gemm_op(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status)); +} + +template +void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + if (bias) { + cutlass_int8_scaled_mm_sm90( + out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm_sm90( + out, mat_a, mat_b, scales_a, scales_b, bias); + } +} + +template +void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + int m = mat_a.size(0); + int n = mat_b.size(1); + if (m <= 32) { + if (n < 8192) { + return sm90_dispatch_bias, Shape<_1, _8, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + return sm90_dispatch_bias, Shape<_1, _8, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 64) { + if (n < 8192) { + return sm90_dispatch_bias, Shape<_1, _4, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + return sm90_dispatch_bias, Shape<_1, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 128) { + if (n <= 4096) { + return sm90_dispatch_bias, Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + return sm90_dispatch_bias, Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else { + return sm90_dispatch_bias, Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, + bias); + } +} + torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const c10::optional& bias) { @@ -204,7 +394,24 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& 
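The SM90 INT8 epilogue above composes the same scaling tree as the SM80 path: the int32 accumulator is scaled per output column by `scales_b`, per output row by `scales_a`, and optionally shifted by a per-column bias before the cast to fp16/bf16. A PyTorch reference of `int8_scaled_mm`'s math, handy when validating the CUTLASS paths (test-only sketch; the int64 matmul emulates the int32 accumulator):

```python
import torch

def int8_scaled_mm_reference(a_int8, b_int8, scales_a, scales_b, out_dtype, bias=None):
    """D = scales_a[:, None] * (A_int8 @ B_int8) * scales_b[None, :] (+ bias), cast to out_dtype."""
    acc = torch.matmul(a_int8.long(), b_int8.long()).to(torch.float32)
    out = acc * scales_b.view(1, -1) * scales_a.view(-1, 1)
    if bias is not None:
        out = out + bias.view(1, -1).to(torch.float32)
    return out.to(out_dtype)

m, k, n = 8, 64, 32
a = torch.randint(-128, 128, (m, k), dtype=torch.int8)
b = torch.randint(-128, 128, (k, n), dtype=torch.int8)
ref = int8_scaled_mm_reference(a, b, torch.rand(m), torch.rand(n), torch.bfloat16)
print(ref.shape, ref.dtype)
```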
ma TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75"); sm75_dispatch_shape>( out, mat_a, mat_b, scales_a, scales_b, bias); - } else if (sm_version >= 80 && sm_version <= 90) { + } else if (sm_version >= 80 && sm_version < 90) { + if (out_dtype == torch::kBFloat16) { + sm80_dispatch_shape>( + out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm80_dispatch_shape>( + out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (sm_version == 90) { +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + // cutlass 3.x + if (out_dtype == torch::kBFloat16) { + sm90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } +#else + // fallback to cutlass 2.x if (out_dtype == torch::kBFloat16) { sm80_dispatch_shape>( out, mat_a, mat_b, scales_a, scales_b, bias); @@ -212,6 +419,7 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma sm80_dispatch_shape>( out, mat_a, mat_b, scales_a, scales_b, bias); } +#endif } else { TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability."); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu new file mode 100644 index 00000000000..e62a154cb18 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu @@ -0,0 +1,118 @@ +#include +#include +#include +#include +#include +#include + +#define THREADS_PER_BLOCK 128 + +template +__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d] + const T* __restrict__ k, // [b, h, 1, d] + const T* __restrict__ v, // [b, h, 1, e] + const float* __restrict__ past_kv, // [b, h, d, e] + const float* __restrict__ slope, // [h, 1, 1] + T* __restrict__ output, // [b, h, 1, e] + float* __restrict__ new_kv, // [b, h, d, e] + const int batch_size, const int num_heads, const int qk_dim, + const int v_dim) { + extern __shared__ char smem[]; + T* q_shared = reinterpret_cast(smem); + T* k_shared = reinterpret_cast(smem + qk_dim * sizeof(T)); + T* v_shared = reinterpret_cast(smem + 2 * qk_dim * sizeof(T)); + float* new_kv_shared = reinterpret_cast(smem + (2 * qk_dim + v_dim) * sizeof(T)); + T* output_shared = + reinterpret_cast(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float)); + + const int32_t tid = threadIdx.x; + const int32_t current_head = blockIdx.x; + const int32_t b = current_head / num_heads; + const int32_t h = current_head % num_heads; + + if (b >= batch_size) return; + + const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim; + const int32_t v_offset = b * num_heads * v_dim + h * v_dim; + const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim; + + for (int d = tid; d < qk_dim; d += blockDim.x) { + q_shared[d] = q[qk_offset + d]; + k_shared[d] = k[qk_offset + d]; + } + for (int e = tid; e < v_dim; e += blockDim.x) { + v_shared[e] = v[v_offset + e]; + } + + __syncthreads(); + + const float ratio = expf(-1.0f * slope[h]); + + for (int d = tid; d < qk_dim; d += blockDim.x) { + T k_val = k_shared[d]; + for (int e = 0; e < v_dim; ++e) { + int past_kv_idx = kv_offset + d * v_dim + e; + T v_val = v_shared[e]; + float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val; + int shared_idx = d * (v_dim + 1) + e; + new_kv_shared[shared_idx] = new_val; + } + } + + __syncthreads(); + + for (int idx = tid; idx < qk_dim * v_dim; idx += 
blockDim.x) { + int d = idx / v_dim; + int e = idx % v_dim; + int shared_idx = d * (v_dim + 1) + e; + int global_idx = kv_offset + idx; + new_kv[global_idx] = new_kv_shared[shared_idx]; + } + + __syncthreads(); + + for (int e = tid; e < v_dim; e += blockDim.x) { + float sum = 0.0f; + for (int d = 0; d < qk_dim; ++d) { + int shared_idx = d * (v_dim + 1) + e; + sum += q_shared[d] * new_kv_shared[shared_idx]; + } + output_shared[e] = static_cast(sum); + } + + __syncthreads(); + + if (tid == 0) { + for (int e = 0; e < v_dim; ++e) { + output[v_offset + e] = output_shared[e]; + } + } +} + +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv) { + TORCH_CHECK(q.is_contiguous(), "q must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "k must be contiguous"); + TORCH_CHECK(v.is_contiguous(), "v must be contiguous"); + TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous"); + + auto batch_size = q.size(0); + auto num_heads = q.size(1); + auto qk_dim = q.size(3); + auto v_dim = v.size(3); + + dim3 block(THREADS_PER_BLOCK); + dim3 grid(batch_size * num_heads); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] { + size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float); + lightning_attention_decode_kernel<<>>( + q.data_ptr(), k.data_ptr(), v.data_ptr(), past_kv.data_ptr(), + slope.data_ptr(), output.data_ptr(), new_kv.data_ptr(), batch_size, num_heads, + qk_dim, v_dim); + })); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index c7faf9d3775..19e9850b51a 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -3,28 +3,14 @@ #include #include #include +#include #include -#include "utils.hpp" - -#ifdef USE_ROCM -#include -#endif - -#ifndef USE_ROCM #define WARP_SIZE 32 -#else -#define WARP_SIZE warpSize -#endif -#ifndef USE_ROCM #define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) -#else -#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ - hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) -#endif #define CEILDIV(x, y) (((x) + (y)-1) / (y)) @@ -39,7 +25,6 @@ AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { - // don't worry about overflow because num_experts is relatively small return row * total_col + col; } diff --git a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu deleted file mode 100644 index a61d4b86059..00000000000 --- a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu +++ /dev/null @@ -1,59 +0,0 @@ -#include -#include -#include - -#include - -#include "utils.hpp" -#include "vectorization.cuh" - -template -__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties, - scalar_t* output, const int32_t numel) { - const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const int32_t stride = blockDim.x * 
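The decode kernel above implements a single step of the lightning-attention recurrence per (batch, head): the KV state decays by `exp(-slope)` and is updated with the rank-1 outer product `k^T v`, and the output is `q` applied to the new state. A PyTorch reference of that step (shapes as in the kernel comments):

```python
import torch

def lightning_attention_decode_reference(q, k, v, past_kv, slope):
    """q, k: (b, h, 1, d); v: (b, h, 1, e); past_kv: (b, h, d, e); slope: (h, 1, 1)."""
    ratio = torch.exp(-slope)                              # per-head decay, broadcast over batch
    new_kv = ratio * past_kv + k.transpose(-1, -2) @ v     # decayed state + rank-1 update, (b, h, d, e)
    output = q @ new_kv                                    # (b, h, 1, e)
    return output.to(q.dtype), new_kv

b, h, d, e = 2, 4, 64, 64
q, k = torch.randn(b, h, 1, d), torch.randn(b, h, 1, d)
v = torch.randn(b, h, 1, e)
past_kv = torch.randn(b, h, d, e)
slope = torch.rand(h, 1, 1)
out, new_kv = lightning_attention_decode_reference(q, k, v, past_kv, slope)
print(out.shape, new_kv.shape)
```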
gridDim.x; - - auto const* vectorized_logits = reinterpret_cast const*>(logits); - auto const* vectorized_penalties = reinterpret_cast const*>(scaling_penalties); - auto* vectorized_output = reinterpret_cast*>(output); - - const int32_t num_vec_elems = numel >> 2; - -#pragma unroll 4 - for (int32_t i = tid; i < num_vec_elems; i += stride) { - vec4_t logits_vec = vectorized_logits[i]; - vec4_t penalties_vec = vectorized_penalties[i]; - vec4_t out_vec; - - out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x; - out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y; - out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z; - out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w; - - vectorized_output[i] = out_vec; - } - - const int32_t start_idx = num_vec_elems * 4; - for (int32_t i = start_idx + tid; i < numel; i += stride) { - scalar_t logit = logits[i]; - scalar_t penalty = scaling_penalties[i]; - output[i] = logit > 0 ? logit / penalty : logit * penalty; - } -} - -torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) { - auto output = torch::empty_like(logits); - const auto numel = logits.numel(); - const int threads = 512; - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] { - const int blocks = (numel + threads * 4 - 1) / (threads * 4); - sampling_scaling_penalties_kernel<<>>( - logits.data_ptr(), scaling_penalties.data_ptr(), output.data_ptr(), numel); - })); - - return output; -} diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu deleted file mode 100644 index b9879b114fe..00000000000 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ /dev/null @@ -1,40 +0,0 @@ -#include "utils.hpp" - -// trt_reduce -using fptr_t = int64_t; -fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, - const std::vector& tmp_result_buffers, const std::vector& barrier_in, - const std::vector& barrier_out); -void dispose(fptr_t _fa); -void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); -std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, const std::vector>& handles, - const std::vector>& offsets); - -// moe_align_block_size -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, - torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, - torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer); - -// sampling_scaling_penalties -torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties); - -// int8_scaled_mm -torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // trt_reduce - m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); - m.def("dispose", &dispose, "dispose custom allreduce meta"); - m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)"); - 
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta"); - m.def("register_graph_buffers", ®ister_graph_buffers, "custom all reduce register graph buffers"); - // moe_align_block_size - m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); - // sampling_scaling_penalties - m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)"); - // int8_scaled_mm - m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); -} diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu index 006c3200dd1..2ee0c98c91e 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -26,6 +26,7 @@ #include #include "trt_reduce_internal.cuh" +#include "utils.h" //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -160,7 +161,7 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag } template -static __global__ void oneShotAllReduceKernel(AllReduceParams params) { +static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduceParams params) { // Suppose that two GPUs participate in the AR exchange, and we start four blocks. // The message is partitioned into chunks as detailed below: // message diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu index d80beedec82..fd0483e39ee 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu @@ -3,11 +3,9 @@ #include #include -#include -#include -#include #include "trt_reduce_internal.cuh" +#include "utils.h" using namespace trt_llm; diff --git a/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh b/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh deleted file mode 100644 index 2bfb710189b..00000000000 --- a/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh +++ /dev/null @@ -1,29 +0,0 @@ -// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/quantization/vectorization.cuh -#pragma once -/** - * __device__ datatypes vectorized by 4 - */ - -// Include both AMD and NVIDIA fp8 types to avoid circular import -// TODO(luka/varun) use FP8_TYPE instead after refactoring -#include -#include - -// Vectorization containers -template -struct __align__(8) vec4_t { - scalar_t x; - scalar_t y; - scalar_t z; - scalar_t w; -}; - -template -struct __align__(4) q8x4_t { - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v); - quant_type_t x; - quant_type_t y; - quant_type_t z; - quant_type_t w; -}; diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h new file mode 100644 index 00000000000..c5cc30c1888 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } + +// trt_reduce +using fptr_t = int64_t; +fptr_t 
init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, + const std::vector& tmp_result_buffers, const std::vector& barrier_in, + const std::vector& barrier_out); +void dispose(fptr_t _fa); +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector>& handles, + const std::vector>& offsets); + +// moe_align_block_size +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, + torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer); + +// int8_scaled_mm +torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias); + +// fp8_scaled_mm +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias); + +// lightning_attention_decode +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv); + +// rms norm +void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); + +// fused rms norm +void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps); + +// gemma rms norm +void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); + +// fused gemma rms norm +void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, + int64_t cuda_stream); + +// silu and mul +void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// gelu tanh and mul +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// gelu and mul +void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// bmm fp8 +void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, + at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); + +// min p sampling from probs +void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + std::optional maybe_min_p_arr, double min_p_val, bool deterministic, + int64_t cuda_stream); + +// top k renorm probs +// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. +void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_k_arr, + unsigned int top_k_val, int64_t cuda_stream); + +// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. 
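The top-k/top-p renorm declarations above come from flashinfer; the `_wrapper` below them only exists to cast the `int64_t` that torch extensions require into the `unsigned int` flashinfer expects. As a rough reference for the intended semantics (an assumption based on the op names, not spelled out in this diff), top-k renormalization keeps each row's k largest probabilities, zeroes the rest, and rescales to sum to one:

```python
import torch

def top_k_renorm_reference(probs, top_k):
    """Keep each row's top-k probabilities, zero the rest, renormalize (ties at the threshold are kept)."""
    kth = torch.topk(probs, top_k, dim=-1).values[..., -1:]          # k-th largest per row
    masked = torch.where(probs >= kth, probs, torch.zeros_like(probs))
    return masked / masked.sum(dim=-1, keepdim=True)

probs = torch.softmax(torch.randn(2, 16), dim=-1)
print(top_k_renorm_reference(probs, top_k=4).sum(dim=-1))            # each row sums to 1
```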
+// wrapper for binding +inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_k_arr, int64_t top_k_val, + int64_t cuda_stream) { + top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast(top_k_val), cuda_stream); +} + +// top p renorm probs +void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_p_arr, + double top_p_val, int64_t cuda_stream); + +// top k top p sampling from probs +void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic, + int64_t cuda_stream); + +// top p sampling from probs +void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic, + int64_t cuda_stream); + +void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, + int64_t cuda_stream); diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh similarity index 99% rename from sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh rename to sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh index 9d6f9722eb5..46522348aaf 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh +++ b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh @@ -17,12 +17,11 @@ */ #pragma once + #include #include #include -#include "utils.hpp" - namespace trt_llm { constexpr size_t WARP_SIZE = 32; constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36; diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp b/sgl-kernel/src/sgl-kernel/include/utils.h similarity index 56% rename from sgl-kernel/src/sgl-kernel/csrc/utils.hpp rename to sgl-kernel/src/sgl-kernel/include/utils.h index 2fed2d60c03..55594f7b273 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -1,4 +1,7 @@ #pragma once + +#include +#include #include #include @@ -44,3 +47,20 @@ inline int getSMVersion() { CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); return sm_major * 10 + sm_minor; } + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) 
\ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float: { \ + using c_type = float; \ + return __VA_ARGS__(); \ + } \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 6b35f78a490..5aa484ff54d 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,39 +1,91 @@ -from sgl_kernel.ops._kernels import all_reduce as _all_reduce -from sgl_kernel.ops._kernels import dispose as _dispose -from sgl_kernel.ops._kernels import ( - get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta, -) -from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar -from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm -from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size -from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers -from sgl_kernel.ops._kernels import ( - sampling_scaling_penalties as _sampling_scaling_penalties, +import os +from typing import Optional, Tuple, Union + +import sgl_kernel.ops._kernels +import torch +from sgl_kernel.ops.utils import ( + _get_cache_buf, + _get_cuda_stream, + _to_tensor_scalar_tuple, ) +def apply_rope_with_cos_sin_cache_inplace( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool = True, +) -> None: + r""" + Apply rotary embedding to keys and queries with precomputed cos/sin values. + This is designed to be compatible with the SGL/vLLM implementation. + The result is applied in place to the input tensors. + + Parameters + ---------- + positions : torch.Tensor + Position indices, shape: ``(nnz)``. + query : torch.Tensor + Query tensor, shape: ``(nnz, num_q_heads * head_size)``. + key : torch.Tensor + Key tensor, shape: ``(nnz, num_k_heads * head_size)``. + cos_sin_cache : torch.Tensor + Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``. + Cosine is the first half and Sine is the second half along rotary_dim. + is_neox : bool + Whether to use Neox style RoPE, default: ``True``. + + * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e., + we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])``. + + * If ``False``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + Note + ---- + The rotary dimension is determined by the cosine cache and sine cache. 
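+
+    Examples
+    --------
+    A minimal usage sketch; the sizes and the CUDA device below are arbitrary, and the
+    cache is built with the ``[cos | sin]`` layout described above:
+
+    .. code-block:: python
+
+        import torch
+        from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+
+        nnz, num_heads, head_size, max_seq_len = 4, 8, 64, 128
+        positions = torch.arange(nnz, device="cuda")
+        query = torch.randn(nnz, num_heads * head_size, device="cuda", dtype=torch.bfloat16)
+        key = torch.randn(nnz, num_heads * head_size, device="cuda", dtype=torch.bfloat16)
+
+        # Build the float32 cos/sin cache: cos in the first half, sin in the second half.
+        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
+        t = torch.arange(max_seq_len, dtype=torch.float32)
+        freqs = torch.einsum("i,j->ij", t, inv_freq)
+        cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to("cuda")
+
+        apply_rope_with_cos_sin_cache_inplace(
+            positions=positions,
+            query=query,
+            key=key,
+            head_size=head_size,
+            cos_sin_cache=cos_sin_cache,
+            is_neox=True,
+        )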
+ """ + if cos_sin_cache.dtype != torch.float32: + raise ValueError("cos_sin_cache should be float32") + + with query.device as device: + positions = positions.int() + torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache( + q=query.view(query.shape[0], -1, head_size), + k=key.view(key.shape[0], -1, head_size), + q_rope=query.view(query.shape[0], -1, head_size), + k_rope=key.view(key.shape[0], -1, head_size), + cos_sin_cache=cos_sin_cache, + pos_ids=positions, + interleave=(not is_neox), + cuda_stream=_get_cuda_stream(device), + ) + + def init_custom_reduce( rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out ): - return _init_custom_ar( + return torch.ops.sgl_kernels.init_custom_ar( rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out ) def custom_dispose(fa): - _dispose(fa) + torch.ops.sgl_kernels.dispose(fa) def custom_reduce(fa, inp, out): - _all_reduce(fa, inp, out) + torch.ops.sgl_kernels.all_reduce(fa, inp, out) def get_graph_buffer_ipc_meta(fa): - return _get_graph_buffer_ipc_meta(fa) + return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa) def register_graph_buffers(fa, handles, offsets): - _register_graph_buffers(fa, handles, offsets) + torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets) def moe_align_block_size( @@ -46,7 +98,7 @@ def moe_align_block_size( token_cnts_buffer, cumsum_buffer, ): - _moe_align_block_size( + torch.ops.sgl_kernels.moe_align_block_size( topk_ids, num_experts, block_size, @@ -59,11 +111,22 @@ def moe_align_block_size( def sampling_scaling_penalties(logits, scaling_penalties): - return _sampling_scaling_penalties(logits, scaling_penalties) + return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties) def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): - return _int8_scaled_mm( + return torch.ops.sgl_kernels.int8_scaled_mm( + mat_a, + mat_b, + scales_a, + scales_b, + out_dtype, + bias, + ) + + +def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): + return torch.ops.sgl_kernels.fp8_scaled_mm( mat_a, mat_b, scales_a, @@ -71,3 +134,364 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): out_dtype, bias, ) + + +def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): + torch.ops.sgl_kernels.lightning_attention_decode( + q, k, v, past_kv, slope, output, new_kv + ) + + +# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer +# Kudos to @yzh119 +def rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + with input.device as device: + if out is None: + out = torch.empty_like(input) + torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) + return out + + +def fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> None: + with input.device as device: + torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps) + + +def gemma_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + with input.device as device: + if out is None: + out = torch.empty_like(input) + torch.ops.sgl_kernels.gemma_rmsnorm( + out, input, weight, eps, _get_cuda_stream(device) + ) + return out + + +def gemma_fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: 
torch.Tensor, eps: float = 1e-6 +) -> None: + with input.device as device: + torch.ops.sgl_kernels.gemma_fused_add_rmsnorm( + input, residual, weight, eps, _get_cuda_stream(device) + ) + + +def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: + assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}" + assert ( + input.shape[:-1] == output.shape[:-1] + ), f"{input.shape[:-1]} != {output.shape[:-1]}" + assert ( + input.shape[-1] == 2 * output.shape[-1] + ), f"{input.shape[-1]} != {2 * output.shape[-1]}" + + +def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def _bmm_fp8_internal( + workspace_buffer: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + D: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, +) -> None: + with A.device as device: + cublas_handle = torch.cuda.current_blas_handle() + torch.ops.sgl_kernels.bmm_fp8( + A, + B, + D, + A_scale, + B_scale, + workspace_buffer, + cublas_handle, + _get_cuda_stream(device), + ) + + +def bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if out is None: + out = torch.empty( + (A.shape[0], A.shape[1], B.shape[2]), + device=A.device, + dtype=dtype, + ) + workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device) + _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale) + return out + + +def _top_k_renorm_probs_internal( + probs: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, +) -> torch.Tensor: + with probs.device as device: + probs = probs.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + renorm_probs = torch.empty_like(probs) + torch.ops.sgl_kernels.top_k_renorm_probs_wrapper( + probs, + renorm_probs, + maybe_top_k_arr, + top_k_val, + _get_cuda_stream(device), + ) + return renorm_probs + + +def top_k_renorm_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], +) -> torch.Tensor: + return _top_k_renorm_probs_internal(probs, 
*_to_tensor_scalar_tuple(top_k)) + + +top_k_renorm_prob = top_k_renorm_probs + + +def _top_p_renorm_probs_internal( + probs: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, +) -> torch.Tensor: + with probs.device as device: + probs = probs.float() + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + renorm_probs = torch.empty_like(probs) + torch.ops.sgl_kernels.top_p_renorm_probs( + probs, + renorm_probs, + maybe_top_p_arr, + top_p_val, + _get_cuda_stream(device), + ) + return renorm_probs + + +def top_p_renorm_probs( + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], +) -> torch.Tensor: + return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p)) + + +top_p_renorm_prob = top_p_renorm_probs + + +def _top_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + torch.ops.sgl_kernels.top_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_p_arr, + top_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples, success + + +def top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + top_p: Union[torch.Tensor, float], + deterministic: bool = True, + check_nan: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _top_p_sampling_from_probs_internal( + probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic + ) + + +def _top_k_top_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples, success + + +def top_k_top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + top_k: Union[torch.Tensor, int], + top_p: Union[torch.Tensor, float], + filter_apply_order: str = "top_k_first", + deterministic: bool = True, + check_nan: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if filter_apply_order == "top_k_first": + renorm_probs = top_k_renorm_probs(probs, top_k) + return top_p_sampling_from_probs( + renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan + ) + elif filter_apply_order == "joint": + if check_nan: 
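+            # Probabilities coming from user code may contain NaN; fail fast here before
+            # launching the fused joint top-k/top-p kernel.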
+ if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _top_k_top_p_sampling_from_probs_internal( + probs, + uniform_samples, + *_to_tensor_scalar_tuple(top_k), + *_to_tensor_scalar_tuple(top_p), + deterministic, + ) + else: + raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") + + +def _min_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_min_p_arr: Optional[torch.Tensor], + min_p_val: float, + deterministic: bool, +) -> torch.Tensor: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_min_p_arr = ( + maybe_min_p_arr.float() if maybe_min_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + torch.ops.sgl_kernels.min_p_sampling_from_probs( + probs, + uniform_samples, + samples, + maybe_min_p_arr, + min_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples + + +def min_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + min_p: Union[torch.Tensor, float], + deterministic: bool = True, + check_nan: bool = False, +) -> torch.Tensor: + if uniform_samples.dim() == 2: + # Take the first row (round) of uniform_samples + uniform_samples = uniform_samples[0] + + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _min_p_sampling_from_probs_internal( + probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic + ) diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/src/sgl-kernel/ops/utils.py new file mode 100644 index 00000000000..31a6bbf9919 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/ops/utils.py @@ -0,0 +1,26 @@ +from typing import Dict, Tuple + +import torch + + +def _get_cuda_stream(device: torch.device) -> int: + return torch.cuda.current_stream(device).cuda_stream + + +_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {} + + +def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: + key = (name, device) + buf = _cache_buf.get(key) + if buf is None: + buf = torch.empty(bytes, dtype=torch.uint8, device=device) + _cache_buf[key] = buf + return buf + + +def _to_tensor_scalar_tuple(x): + if isinstance(x, torch.Tensor): + return (x, 0) + else: + return (None, x) diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc new file mode 100644 index 00000000000..01f93199ccb --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -0,0 +1,120 @@ +#include +#include + +#include "sgl_kernels_ops.h" + +TORCH_LIBRARY_EXPAND(sgl_kernels, m) { + // trt_reduce + m.def( + "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] " + "barrier_in, int[] barrier_out) -> int"); + m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + + m.def("dispose", &dispose); + + m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()"); + m.impl("all_reduce", torch::kCUDA, &all_reduce); + + m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])"); + m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta); + + m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()"); + m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers); + + // moe_align_block_size + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! 
" + "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + + // int8_scaled_mm + m.def( + "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " + "bias) -> Tensor"); + m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm); + + // fp8_scaled_mm + m.def( + "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " + "bias) -> Tensor"); + m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm); + + // lightning_attention_decode + m.def( + "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " + "new_kv) -> ()"); + m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); + + // rms norm + m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("rmsnorm", torch::kCUDA, &rmsnorm); + + // fused rms norm + m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()"); + m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm); + + // gemma rms norm + m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); + + // fused gemma rms norm + m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); + + // silu and mul + m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + // gelu tanh and mul + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + // gelu and mul + m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + + // bmm fp8 + m.def( + "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " + "cublas_handle, int cuda_stream) -> ()"); + m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); + + // min p sampling from probs + m.def( + "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float " + "min_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs); + + // top k renorm probs + m.def( + "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int " + "cuda_stream) -> ()"); + m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper); + + // top p renorm probs + m.def( + "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int " + "cuda_stream) -> ()"); + m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs); + + // top k top p sampling from probs + m.def( + "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " + "maybe_top_k_arr, float top_k_val, Tensor? 
maybe_top_p_arr, float top_p_val, bool deterministic, int " + "cuda_stream) -> ()"); + m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs); + + // top p sampling from probs + m.def( + "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " + "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); + + // apply rope with cos sin cache + m.def( + "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, " + "Tensor pos_ids, bool interleave, int cuda_stream) -> ()"); + m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); +} + +REGISTER_EXTENSION(_kernels) diff --git a/sgl-kernel/tests/test_activation.py b/sgl-kernel/tests/test_activation.py new file mode 100644 index 00000000000..43593441e3b --- /dev/null +++ b/sgl-kernel/tests/test_activation.py @@ -0,0 +1,39 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py + +import pytest +import sgl_kernel +import torch + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_silu_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim]) + y = sgl_kernel.silu_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_gelu_tanh_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh") + y = sgl_kernel.gelu_tanh_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_gelu_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none") + y = sgl_kernel.gelu_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_bmm_fp8.py b/sgl-kernel/tests/test_bmm_fp8.py new file mode 100644 index 00000000000..e0be92896f6 --- /dev/null +++ b/sgl-kernel/tests/test_bmm_fp8.py @@ -0,0 +1,43 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py + +import pytest +import torch +import torch.nn.functional as F +from sgl_kernel import bmm_fp8 + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), 
scale.float().reciprocal() + + +@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype): + if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2: + pytest.skip("Invalid combination: both input and mat2 are e5m2") + + input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) + input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) + + # mat2 row major -> column major + mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose( + -2, -1 + ) + mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) + + res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype) + bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res) + + reference = torch.bmm(input, mat2) + cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) + assert cos_sim > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_fp8_gemm.py b/sgl-kernel/tests/test_fp8_gemm.py new file mode 100644 index 00000000000..1a731865944 --- /dev/null +++ b/sgl-kernel/tests/test_fp8_gemm.py @@ -0,0 +1,67 @@ +import unittest + +import torch +from sgl_kernel import fp8_scaled_mm + + +def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): + o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + + o = o.to(torch.float32) + temp1 = o * scale_a.view(-1, 1) + temp2 = temp1 * scale_b.view(1, -1) + final = temp2.to(out_dtype) + if bias is not None: + final = final + bias.view(1, -1) + + return final + + +class TestFp8Gemm(unittest.TestCase): + def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a_fp32 = ( + (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + ) + a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + b_fp32 = ( + (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + ) + b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001 + scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001 + if with_bias: + bias = torch.randn((N,), device=device, dtype=out_dtype) + else: + bias = None + o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16) + b_fp8 = b_fp8.t() + o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + rtol = 0.02 + atol = 1 + torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) + print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") + + def test_accuracy(self): + Ms = [1, 128, 512, 1024, 4096] + Ns = [16, 128, 512, 1024, 4096] + Ks = [512, 1024, 4096, 8192, 16384] + bias_opts = [True, False] + out_dtypes = [torch.bfloat16, torch.float16] + for M in Ms: + for N in Ns: + for K in Ks: + for with_bias in bias_opts: + for out_dtype in out_dtypes: + self._test_accuracy_once( + M, N, K, with_bias, out_dtype, "cuda" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py index 34d17d1c76a..c33a3effcaf 100644 --- a/sgl-kernel/tests/test_int8_gemm.py +++ 
b/sgl-kernel/tests/test_int8_gemm.py @@ -25,7 +25,7 @@ def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) if with_bias: - bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10 + bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10 else: bias = None diff --git a/sgl-kernel/tests/test_lightning_attention_decode.py b/sgl-kernel/tests/test_lightning_attention_decode.py new file mode 100644 index 00000000000..f2cace00157 --- /dev/null +++ b/sgl-kernel/tests/test_lightning_attention_decode.py @@ -0,0 +1,88 @@ +import pytest +import torch +from sgl_kernel import lightning_attention_decode + + +def naive_lightning_attention_decode(q, k, v, past_kv, slope): + """Naive implementation of lightning attention decode""" + original_dtype = q.dtype + ratio = torch.exp(-slope) # [h, 1, 1] + + kv = past_kv + b, h, n, d = q.shape + + output = [] + for i in range(n): + kv = ratio * kv.to(torch.float32) + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... n d", + q[:, :, i : i + 1].to(torch.float32), + kv.to(torch.float32), + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + return output.to(original_dtype), kv + + +configs = [ + # (batch_size, num_heads, dim, embed_dim) + (1, 8, 64, 64), + (2, 8, 64, 64), + (1, 32, 32, 64), + (2, 32, 32, 64), + (4, 32, 64, 64), + (4, 32, 64, 64), + (16, 64, 96, 96), + (64, 64, 96, 96), +] + +dtypes = [torch.float32, torch.float16, torch.bfloat16] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs) +def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim): + device = torch.device("cuda") + + q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype) + k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype) + v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype) + past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope) + + output = torch.empty_like(ref_output) + new_kv = torch.empty_like(ref_new_kv) + lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + rtol = 1e-2 + atol = 1e-2 + + torch.testing.assert_close( + output, + ref_output, + rtol=rtol, + atol=atol, + msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, " + f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", + ) + + torch.testing.assert_close( + new_kv, + ref_new_kv, + rtol=rtol, + atol=atol, + msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, " + f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_norm.py b/sgl-kernel/tests/test_norm.py new file mode 100644 index 00000000000..d22da931f57 --- /dev/null +++ b/sgl-kernel/tests/test_norm.py @@ -0,0 +1,133 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py + +import pytest +import sgl_kernel +import torch + + +def llama_rms_norm(x, w, eps=1e-6): + orig_dtype = x.dtype + x = 
x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * w.float() + x = x.to(orig_dtype) + return x + + +def gemma_rms_norm(x, w, eps=1e-6): + orig_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * (1.0 + w.float()) + x = x.to(orig_dtype) + return x + + +def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6): + orig_dtype = x.dtype + x = x + residual + residual = x + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * (1.0 + w.float()) + x = x.to(orig_dtype) + return x, residual + + +def fused_add_rms_norm(x, residual, weight, eps): + orig_dtype = x.dtype + x = x.to(torch.float32) + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = (x * weight.float()).to(orig_dtype) + return x, residual + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("specify_out", [True, False]) +def test_norm(batch_size, hidden_size, dtype, specify_out): + x = torch.randn(batch_size, hidden_size).to(0).to(dtype) + w = torch.randn(hidden_size).to(0).to(dtype) + + y_ref = llama_rms_norm(x, w) + if specify_out: + y = torch.empty_like(x) + sgl_kernel.rmsnorm(x, w, out=y) + else: + y = sgl_kernel.rmsnorm(x, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_fused_add_rmsnorm(batch_size, hidden_size, dtype): + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) + weight = torch.randn(hidden_size, dtype=dtype, device="cuda") + + x_native, residual_native = fused_add_rms_norm( + x.clone(), residual.clone(), weight, eps + ) + + x_fused = x.clone() + residual_fused = residual.clone() + sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps) + + torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("specify_out", [True, False]) +def test_gemma_norm(batch_size, hidden_size, dtype, specify_out): + x = torch.randn(batch_size, hidden_size).to(0).to(dtype) + w = torch.randn(hidden_size).to(0).to(dtype) + + y_ref = gemma_rms_norm(x, w) + if specify_out: + y = torch.empty_like(x) + sgl_kernel.gemma_rmsnorm(x, w, out=y) + else: + y = sgl_kernel.gemma_rmsnorm(x, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype): + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) + weight = 
torch.randn(hidden_size, dtype=dtype, device="cuda") + + x_native, residual_native = gemma_fused_add_rms_norm( + x.clone(), residual.clone(), weight, eps + ) + + x_fused = x.clone() + residual_fused = residual.clone() + sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps) + + torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_rotary_embedding.py b/sgl-kernel/tests/test_rotary_embedding.py new file mode 100644 index 00000000000..b7a141404e6 --- /dev/null +++ b/sgl-kernel/tests/test_rotary_embedding.py @@ -0,0 +1,202 @@ +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytest +import torch +import torch.nn as nn +from sgl_kernel import apply_rope_with_cos_sin_cache_inplace + + +# vLLM torch native +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(torch.nn.Module): + # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + + # Modification: float32 is required for the rotary embedding to work correctly + query = query.to(torch.float32) + key = key.to(torch.float32) + + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = 
query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + + # Modification: convert to the correct dtype + query = query.to(self.dtype) + key = key.to(self.dtype) + return query, key + + +class FlashInferRotaryEmbedding(RotaryEmbedding): + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=self.head_size, + cos_sin_cache=self.cos_sin_cache, + is_neox=self.is_neox_style, + ) + + return query, key + + +@pytest.mark.parametrize( + "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads", + [ + (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1), + (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2), + (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4), + (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2), + ], +) +def test_correctness( + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + device: str, + batch_size: int, + seq_len: int, + num_q_heads: int, + num_kv_heads: int, +): + rope_ref = RotaryEmbedding( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ).to(device) + rope_flashinfer = FlashInferRotaryEmbedding( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ).to(device) + + pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) + query = torch.randn( + batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device + ) + key = torch.randn( + batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device + ) + + query_ref, key_ref = query.clone(), key.clone() + query_flashinfer, key_flashinfer = query.clone(), key.clone() + + query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref) + query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda( + pos_ids, query_flashinfer, key_flashinfer + ) + + print(query_ref_out) + print(query_flashinfer_out) + + torch.testing.assert_close( + query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_sampling.py b/sgl-kernel/tests/test_sampling.py new file mode 100644 index 00000000000..7d3bc5059ee --- /dev/null +++ b/sgl-kernel/tests/test_sampling.py @@ -0,0 +1,141 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py + +import pytest +import sgl_kernel +import torch + + 
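+# The sampling tests below share one strategy: build a reference mask of the token ids a
+# filter (top-k / top-p / min-p) may keep, draw many samples from the kernel under test,
+# and assert that every sample falls inside the mask. The renorm tests instead compare
+# against a reference renormalization computed in plain PyTorch.
+#
+# Illustrative call sketch (shapes here are assumptions, not taken from the tests):
+#
+#   probs = torch.softmax(torch.randn(4, 128, device="cuda"), dim=-1)
+#   uniform = torch.rand(32, 4, device="cuda")  # (rounds, batch)
+#   samples, success = sgl_kernel.top_k_top_p_sampling_from_probs(
+#       probs, uniform, top_k=20, top_p=0.9, filter_apply_order="joint"
+#   )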
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5]) +def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): + torch.manual_seed(42) + if p == 0.1: + k = int(vocab_size * 0.5) + elif p == 0.5: + k = int(vocab_size * 0.1) + else: + raise ValueError("p not recognized") + max_top_k_trails = 32 + eps = 1e-4 + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + # top-p mask + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) + # top-k mask + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int() + # overall mask + mask = torch.minimum(mask_top_p, mask_top_k) + uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to( + 0 + ) + top_p_tensor = torch.full((batch_size,), p).to(0) + top_k_tensor = torch.full((batch_size,), k).to(0) + + num_trails = 1000 + for _ in range(num_trails): + uniform_samples.uniform_() + samples, success = sgl_kernel.top_k_top_p_sampling_from_probs( + normalized_prob, + uniform_samples, + top_k_tensor, + top_p_tensor, + filter_apply_order="joint", + ) + assert torch.all(success) + assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[ + torch.arange(batch_size), samples + ] + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) +def test_top_p_renorm_probs(batch_size, vocab_size, p): + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (cdf >= (1 - p)).int()) + renorm_prob_ground_truth = normalized_prob + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p) + torch.testing.assert_close( + renorm_prob_ground_truth, + renorm_prob, + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("k", [10, 100, 500]) +def test_top_k_renorm_probs(batch_size, vocab_size, k): + if k > vocab_size: + pytest.skip("k should be less than vocab_size") + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask = (normalized_prob >= pivot.unsqueeze(-1)).int() + renorm_prob_ground_truth = normalized_prob + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + 
renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k) + torch.testing.assert_close( + renorm_prob_ground_truth, + renorm_prob, + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1]) +def test_min_p_sampling(batch_size, vocab_size, p): + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + # scale min-p + top_probs = sorted_prob[:, -1].unsqueeze(-1) + scaled_p = p * top_probs + # min-p mask + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int()) + uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0) + min_p_tensor = torch.full((batch_size,), p).to(0) + + num_trails = 1000 + for _ in range(num_trails): + uniform_samples.uniform_() + samples = sgl_kernel.min_p_sampling_from_probs( + normalized_prob, + uniform_samples, + min_p_tensor, + ) + + assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[ + torch.nonzero(mask[torch.arange(batch_size), samples] == 0) + ] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_sampling_scaling_penalties.py b/sgl-kernel/tests/test_sampling_scaling_penalties.py deleted file mode 100644 index 4b9746fd793..00000000000 --- a/sgl-kernel/tests/test_sampling_scaling_penalties.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch -from sgl_kernel import sampling_scaling_penalties - - -def test_sampling_scaling_penalties(): - batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65] - vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767] - dtypes = [torch.float32, torch.half, torch.bfloat16] - device = torch.device("cuda") - - for dtype in dtypes: - rtol = 1e-3 - atol = 1e-3 - - for bs in batch_sizes: - for vocab_size in vocab_sizes: - logits = torch.randn(bs, vocab_size, device=device, dtype=dtype) - scaling_penalties = ( - torch.rand(bs, vocab_size, device=device, dtype=dtype) + 0.5 - ) - - ref_output = torch.where( - logits > 0, logits / scaling_penalties, logits * scaling_penalties - ) - - kernel_output = sampling_scaling_penalties(logits, scaling_penalties) - - torch.testing.assert_close( - kernel_output, - ref_output, - rtol=rtol, - atol=atol, - msg=f"Failed for batch_size={bs}, vocab_size={vocab_size}, dtype={dtype}", - ) - - -if __name__ == "__main__": - test_sampling_scaling_penalties() - print("All tests passed!") diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py new file mode 100644 index 00000000000..27fdca497c3 --- /dev/null +++ b/sgl-kernel/version.py @@ -0,0 +1 @@ +__version__ = "0.0.3" diff --git a/sgl-router/README.md b/sgl-router/README.md index f39d63625de..61c9e692c92 100644 --- a/sgl-router/README.md +++ b/sgl-router/README.md @@ -67,6 +67,16 @@ $ pip install -e . **Note:** When modifying Rust code, you must rebuild the wheel for changes to take effect. +### Troubleshooting + +1. If rust analyzer is not working in VSCode, set `rust-analyzer.linkedProjects` to the absolute path of `Cargo.toml` in your repo. 
For example: + +```json +{ + "rust-analyzer.linkedProjects": ["/workspaces/sglang/sgl-router/Cargo.toml"] +} +``` + ### CI/CD Setup The continuous integration pipeline consists of three main steps: diff --git a/sgl-router/py_src/sglang_router/__init__.py b/sgl-router/py_src/sglang_router/__init__.py index 285ee173ba9..081740479ca 100644 --- a/sgl-router/py_src/sglang_router/__init__.py +++ b/sgl-router/py_src/sglang_router/__init__.py @@ -1,11 +1,7 @@ # a lightweihgt wrapper on router with argument type and comments -from sglang_router_rs import PolicyType - # no wrapper on policy type => direct export -from .router import Router - -__all__ = ["Router", "PolicyType"] - +from sglang_router.router import Router from sglang_router.version import __version__ +from sglang_router_rs import PolicyType -__all__ += ["__version__"] +__all__ = ["Router", "PolicyType", "__version__"] diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index e4f26a8d4bc..38f1fbba2dc 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -27,12 +27,14 @@ def setup_logger(): @dataclasses.dataclass class RouterArgs: # Worker configuration - worker_urls: List[str] + worker_urls: List[str] = dataclasses.field(default_factory=list) host: str = "127.0.0.1" port: int = 30000 # Routing policy policy: str = "cache_aware" + worker_startup_timeout_secs: int = 300 + worker_startup_check_interval: int = 10 cache_threshold: float = 0.5 balance_abs_threshold: int = 32 balance_rel_threshold: float = 1.0001 @@ -87,6 +89,18 @@ def add_cli_args( choices=["random", "round_robin", "cache_aware"], help="Load balancing policy to use", ) + parser.add_argument( + f"--{prefix}worker-startup-timeout-secs", + type=int, + default=RouterArgs.worker_startup_timeout_secs, + help="Timeout in seconds for worker startup", + ) + parser.add_argument( + f"--{prefix}worker-startup-check-interval", + type=int, + default=RouterArgs.worker_startup_check_interval, + help="Interval in seconds between checks for worker startup", + ) parser.add_argument( f"--{prefix}cache-threshold", type=float, @@ -141,11 +155,18 @@ def from_cli_args( use_router_prefix: If True, look for arguments with 'router-' prefix """ prefix = "router_" if use_router_prefix else "" + worker_urls = args.worker_urls if args.worker_urls is not None else [] return cls( - worker_urls=args.worker_urls, + worker_urls=worker_urls, host=args.host, port=args.port, policy=getattr(args, f"{prefix}policy"), + worker_startup_timeout_secs=getattr( + args, f"{prefix}worker_startup_timeout_secs" + ), + worker_startup_check_interval=getattr( + args, f"{prefix}worker_startup_check_interval" + ), cache_threshold=getattr(args, f"{prefix}cache_threshold"), balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"), balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"), @@ -187,9 +208,11 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: router = Router( worker_urls=router_args.worker_urls, - policy=policy_from_str(router_args.policy), host=router_args.host, port=router_args.port, + policy=policy_from_str(router_args.policy), + worker_startup_timeout_secs=router_args.worker_startup_timeout_secs, + worker_startup_check_interval=router_args.worker_startup_check_interval, cache_threshold=router_args.cache_threshold, balance_abs_threshold=router_args.balance_abs_threshold, balance_rel_threshold=router_args.balance_rel_threshold, @@ -204,7 +227,7 @@ def 
launch_router(args: argparse.Namespace) -> Optional[Router]: except Exception as e: logger.error(f"Error starting router: {e}") - return None + raise e class CustomHelpFormatter( @@ -237,12 +260,8 @@ def parse_router_args(args: List[str]) -> RouterArgs: def main() -> None: - logger = setup_logger() router_args = parse_router_args(sys.argv[1:]) - router = launch_router(router_args) - - if router is None: - sys.exit(1) + launch_router(router_args) if __name__ == "__main__": diff --git a/sgl-router/py_src/sglang_router/launch_server.py b/sgl-router/py_src/sglang_router/launch_server.py index 6ee19241542..74353c21edb 100644 --- a/sgl-router/py_src/sglang_router/launch_server.py +++ b/sgl-router/py_src/sglang_router/launch_server.py @@ -13,7 +13,7 @@ from setproctitle import setproctitle from sglang_router.launch_router import RouterArgs, launch_router -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_port_available @@ -23,7 +23,7 @@ def setup_logger(): logger.setLevel(logging.INFO) formatter = logging.Formatter( - "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s", + "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d", datefmt="%Y-%m-%d %H:%M:%S", ) @@ -68,7 +68,7 @@ def run_server(server_args, dp_rank): # create new process group os.setpgrp() - setproctitle(f"sglang::server") + setproctitle("sglang::server") # Set SGLANG_DP_RANK environment variable os.environ["SGLANG_DP_RANK"] = str(dp_rank) @@ -120,9 +120,26 @@ def find_available_ports(base_port: int, count: int) -> List[int]: def cleanup_processes(processes: List[mp.Process]): for process in processes: - logger.info(f"Terminating process {process.pid}") - process.terminate() - logger.info("All processes terminated") + logger.info(f"Terminating process group {process.pid}") + try: + os.killpg(process.pid, signal.SIGTERM) + except ProcessLookupError: + # Process group may already be terminated + pass + + # Wait for processes to terminate + for process in processes: + process.join(timeout=5) + if process.is_alive(): + logger.warning( + f"Process {process.pid} did not terminate gracefully, forcing kill" + ) + try: + os.killpg(process.pid, signal.SIGKILL) + except ProcessLookupError: + pass + + logger.info("All process groups terminated") def main(): @@ -173,7 +190,12 @@ def main(): ] # Start the router - router = launch_router(router_args) + try: + launch_router(router_args) + except Exception as e: + logger.error(f"Failed to start router: {e}") + cleanup_processes(server_processes) + sys.exit(1) if __name__ == "__main__": diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index 5ce21c3d78e..b8757168b24 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -17,6 +17,8 @@ class Router: - PolicyType.CacheAware: Distribute requests based on cache state and load balance host: Host address to bind the router server. Default: '127.0.0.1' port: Port number to bind the router server. Default: 3001 + worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300 + worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10 cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. 
Default: 0.5 @@ -37,6 +39,8 @@ def __init__( policy: PolicyType = PolicyType.RoundRobin, host: str = "127.0.0.1", port: int = 3001, + worker_startup_timeout_secs: int = 300, + worker_startup_check_interval: int = 10, cache_threshold: float = 0.50, balance_abs_threshold: int = 32, balance_rel_threshold: float = 1.0001, @@ -50,6 +54,8 @@ def __init__( policy=policy, host=host, port=port, + worker_startup_timeout_secs=worker_startup_timeout_secs, + worker_startup_check_interval=worker_startup_check_interval, cache_threshold=cache_threshold, balance_abs_threshold=balance_abs_threshold, balance_rel_threshold=balance_rel_threshold, diff --git a/sgl-router/py_src/sglang_router/version.py b/sgl-router/py_src/sglang_router/version.py index 485f44ac21b..bbab0242f6a 100644 --- a/sgl-router/py_src/sglang_router/version.py +++ b/sgl-router/py_src/sglang_router/version.py @@ -1 +1 @@ -__version__ = "0.1.1" +__version__ = "0.1.4" diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index 1c3700d423b..27ed64d6e66 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -22,14 +22,14 @@ def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> class TestLaunchRouter(unittest.TestCase): - def test_launch_router_no_exception(self): - - # Create SimpleNamespace with default arguments - args = SimpleNamespace( - worker_urls=["http://localhost:8000"], + def setUp(self): + """Set up default arguments for router tests.""" + self.default_args = SimpleNamespace( host="127.0.0.1", port=30000, policy="cache_aware", + worker_startup_timeout_secs=600, + worker_startup_check_interval=10, cache_threshold=0.5, balance_abs_threshold=32, balance_rel_threshold=1.0001, @@ -39,6 +39,15 @@ def test_launch_router_no_exception(self): verbose=False, ) + def create_router_args(self, **kwargs): + """Create router arguments by updating default args with provided kwargs.""" + args_dict = vars(self.default_args).copy() + args_dict.update(kwargs) + return SimpleNamespace(**args_dict) + + def run_router_process(self, args): + """Run router in a separate process and verify it starts successfully.""" + def run_router(): try: from sglang_router.launch_router import launch_router @@ -51,7 +60,6 @@ def run_router(): print(e) return 1 - # Start router in separate process process = multiprocessing.Process(target=run_router) try: process.start() @@ -62,6 +70,14 @@ def run_router(): finally: terminate_process(process) + def test_launch_router_common(self): + args = self.create_router_args(worker_urls=["http://localhost:8000"]) + self.run_router_process(args) + + def test_launch_router_with_empty_worker_urls(self): + args = self.create_router_args(worker_urls=[]) + self.run_router_process(args) + if __name__ == "__main__": unittest.main() diff --git a/sgl-router/py_test/test_launch_server.py b/sgl-router/py_test/test_launch_server.py index e11602933a6..80659fc4f3e 100644 --- a/sgl-router/py_test/test_launch_server.py +++ b/sgl-router/py_test/test_launch_server.py @@ -22,6 +22,7 @@ def popen_launch_router( timeout: float, policy: str = "cache_aware", max_payload_size: int = None, + api_key: str = None, ): """ Launch the router server process. 
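The Python `Router` wrapper above now accepts `worker_startup_timeout_secs` and `worker_startup_check_interval` alongside the existing routing options. A minimal sketch of how an embedder might pass them, assuming the wrapper exposes a blocking `start()` method and using placeholder worker URLs:

```python
from sglang_router import PolicyType, Router

# Placeholder worker URLs; omitted arguments keep the defaults documented above.
router = Router(
    worker_urls=["http://localhost:30001", "http://localhost:30002"],
    policy=PolicyType.CacheAware,
    host="127.0.0.1",
    port=30000,
    worker_startup_timeout_secs=600,   # wait up to 10 minutes for workers to become healthy
    worker_startup_check_interval=10,  # probe each worker's /health every 10 seconds
)
router.start()  # assumed to block and serve until terminated
```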
@@ -33,6 +34,7 @@ def popen_launch_router( timeout: Server launch timeout policy: Router policy, one of "cache_aware", "round_robin", "random" max_payload_size: Maximum payload size in bytes + api_key: API key for the router """ _, host, port = base_url.split(":") host = host[2:] @@ -55,6 +57,9 @@ def popen_launch_router( policy, ] + if api_key is not None: + command.extend(["--api-key", api_key]) + if max_payload_size is not None: command.extend(["--router-max-payload-size", str(max_payload_size)]) @@ -333,6 +338,57 @@ def test_4_payload_size(self): f"1.2MB payload should fail with 413 but got status {response.status_code}", ) + def test_5_api_key(self): + print("Running test_5_api_key...") + + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + api_key="correct_api_key", + ) + + # # Test case 1: request without api key should fail + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is, ", "temperature": 0}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 401, + "Request without api key should fail with 401", + ) + + # Test case 2: request with invalid api key should fail + with requests.Session() as session: + response = requests.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is, ", "temperature": 0}, + headers={"Authorization": "Bearer 123"}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 401, + "Request with invalid api key should fail with 401", + ) + + # Test case 3: request with correct api key should succeed + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is ", "temperature": 0}, + headers={"Authorization": "Bearer correct_api_key"}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, 200, "Request with correct api key should succeed" + ) + if __name__ == "__main__": unittest.main() diff --git a/sgl-router/pyproject.toml b/sgl-router/pyproject.toml index 20096b6b491..da5c44a1196 100644 --- a/sgl-router/pyproject.toml +++ b/sgl-router/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang-router" -version = "0.1.1" +version = "0.1.4" description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances." 
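The new `test_5_api_key` case exercises authentication end to end. A client-side sketch of the same flow, assuming a router stack already launched with `--api-key correct_api_key` at a placeholder address:

```python
import requests

BASE_URL = "http://127.0.0.1:30000"  # placeholder router address

# Without a key the request is rejected.
r = requests.post(f"{BASE_URL}/generate", json={"text": "Kanye west is ", "temperature": 0})
assert r.status_code == 401

# With the matching Bearer token the request is forwarded and served.
r = requests.post(
    f"{BASE_URL}/generate",
    json={"text": "Kanye west is ", "temperature": 0},
    headers={"Authorization": "Bearer correct_api_key"},
)
assert r.status_code == 200
```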
authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}] requires-python = ">=3.8" @@ -20,6 +20,10 @@ classifiers = [ [tool.setuptools.packages] find = { where = ["py_src"] } +# workaround for https://github.com/pypa/twine/issues/1216 +[tool.setuptools] +license-files = [] + [[tool.setuptools-rust.ext-modules]] target = "sglang_router_rs" path = "Cargo.toml" diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 2d8cf4c0c8d..ba9aeac1fef 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -17,6 +17,8 @@ struct Router { port: u16, worker_urls: Vec, policy: PolicyType, + worker_startup_timeout_secs: u64, + worker_startup_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, @@ -34,6 +36,8 @@ impl Router { policy = PolicyType::RoundRobin, host = String::from("127.0.0.1"), port = 3001, + worker_startup_timeout_secs = 300, + worker_startup_check_interval = 10, cache_threshold = 0.50, balance_abs_threshold = 32, balance_rel_threshold = 1.0001, @@ -47,6 +51,8 @@ impl Router { policy: PolicyType, host: String, port: u16, + worker_startup_timeout_secs: u64, + worker_startup_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, @@ -60,6 +66,8 @@ impl Router { port, worker_urls, policy, + worker_startup_timeout_secs, + worker_startup_check_interval, cache_threshold, balance_abs_threshold, balance_rel_threshold, @@ -72,9 +80,17 @@ impl Router { fn start(&self) -> PyResult<()> { let policy_config = match &self.policy { - PolicyType::Random => router::PolicyConfig::RandomConfig, - PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig, + PolicyType::Random => router::PolicyConfig::RandomConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, + }, + PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, + }, PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, cache_threshold: self.cache_threshold, balance_abs_threshold: self.balance_abs_threshold, balance_rel_threshold: self.balance_rel_threshold, @@ -93,10 +109,9 @@ impl Router { max_payload_size: self.max_payload_size, }) .await - .unwrap(); - }); - - Ok(()) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(()) + }) } } diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs index 08f6cdefa75..5ee34c59869 100644 --- a/sgl-router/src/router.rs +++ b/sgl-router/src/router.rs @@ -3,7 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; use bytes::Bytes; use futures_util::{StreamExt, TryStreamExt}; -use log::{debug, info, warn}; +use log::{debug, error, info, warn}; use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; @@ -12,14 +12,30 @@ use std::thread; use std::time::Duration; use tokio; +fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { + req.headers() + .iter() + .filter_map(|(name, value)| { + value + .to_str() + .ok() + .map(|v| (name.to_string(), v.to_string())) + }) + .collect() +} + #[derive(Debug)] pub enum Router { RoundRobin { worker_urls: Arc>>, current_index: AtomicUsize, + timeout_secs: u64, + interval_secs: u64, }, Random { worker_urls: Arc>>, + timeout_secs: u64, + 
interval_secs: u64, }, CacheAware { /* @@ -89,36 +105,73 @@ pub enum Router { cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, + timeout_secs: u64, + interval_secs: u64, _eviction_thread: Option>, }, } #[derive(Debug, Clone)] pub enum PolicyConfig { - RandomConfig, - RoundRobinConfig, + RandomConfig { + timeout_secs: u64, + interval_secs: u64, + }, + RoundRobinConfig { + timeout_secs: u64, + interval_secs: u64, + }, CacheAwareConfig { cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, + timeout_secs: u64, + interval_secs: u64, }, } impl Router { pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Result { + // Get timeout and interval from policy config + let (timeout_secs, interval_secs) = match &policy_config { + PolicyConfig::RandomConfig { + timeout_secs, + interval_secs, + } => (*timeout_secs, *interval_secs), + PolicyConfig::RoundRobinConfig { + timeout_secs, + interval_secs, + } => (*timeout_secs, *interval_secs), + PolicyConfig::CacheAwareConfig { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), + }; + // Wait until all workers are healthy - Self::wait_for_healthy_workers(&worker_urls, 300, 10)?; + Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?; // Create router based on policy... Ok(match policy_config { - PolicyConfig::RandomConfig => Router::Random { + PolicyConfig::RandomConfig { + timeout_secs, + interval_secs, + } => Router::Random { worker_urls: Arc::new(RwLock::new(worker_urls)), + timeout_secs, + interval_secs, }, - PolicyConfig::RoundRobinConfig => Router::RoundRobin { + PolicyConfig::RoundRobinConfig { + timeout_secs, + interval_secs, + } => Router::RoundRobin { worker_urls: Arc::new(RwLock::new(worker_urls)), current_index: std::sync::atomic::AtomicUsize::new(0), + timeout_secs, + interval_secs, }, PolicyConfig::CacheAwareConfig { cache_threshold, @@ -126,6 +179,8 @@ impl Router { balance_rel_threshold, eviction_interval_secs, max_tree_size, + timeout_secs, + interval_secs, } => { let mut running_queue = HashMap::new(); for url in &worker_urls { @@ -176,6 +231,8 @@ impl Router { cache_threshold, balance_abs_threshold, balance_rel_threshold, + timeout_secs, + interval_secs, _eviction_thread: Some(eviction_thread), } } @@ -192,9 +249,13 @@ impl Router { loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { + error!( + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls + ); return Err(format!( - "Timeout {}s waiting for workers to become healthy", - timeout_secs + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls )); } @@ -238,7 +299,7 @@ impl Router { fn select_first_worker(&self) -> Result { match self { Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } + | Router::Random { worker_urls, .. } | Router::CacheAware { worker_urls, .. 
} => { if worker_urls.read().unwrap().is_empty() { Err("No workers are available".to_string()) @@ -254,8 +315,18 @@ impl Router { client: &reqwest::Client, worker_url: &str, route: &str, + req: &HttpRequest, ) -> HttpResponse { - match client.get(format!("{}{}", worker_url, route)).send().await { + let mut request_builder = client.get(format!("{}{}", worker_url, route)); + + // Copy all headers from original request except for /health because it does not need authorization + if route != "/health" { + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); + } + } + + match request_builder.send().await { Ok(res) => { let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); @@ -273,7 +344,12 @@ impl Router { } } - pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse { + pub async fn route_to_first( + &self, + client: &reqwest::Client, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { const MAX_REQUEST_RETRIES: u32 = 3; const MAX_TOTAL_RETRIES: u32 = 6; let mut total_retries = 0; @@ -289,10 +365,17 @@ impl Router { info!("Retrying request after {} failed attempts", total_retries); } - let response = self.send_request(client, &worker_url, route).await; + let response = self.send_request(client, &worker_url, route, req).await; if response.status().is_success() { return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + return response; + } } warn!( @@ -349,6 +432,7 @@ impl Router { Router::RoundRobin { worker_urls, current_index, + .. } => { let idx = current_index .fetch_update( @@ -360,7 +444,7 @@ impl Router { worker_urls.read().unwrap()[idx].clone() } - Router::Random { worker_urls } => worker_urls.read().unwrap() + Router::Random { worker_urls, .. } => worker_urls.read().unwrap() [rand::random::() % worker_urls.read().unwrap().len()] .clone(), @@ -446,19 +530,16 @@ impl Router { .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) .unwrap_or(false); - let res = match client + let mut request_builder = client .post(format!("{}{}", worker_url, route)) - .header( - "Content-Type", - req.headers() - .get("Content-Type") - .and_then(|h| h.to_str().ok()) - .unwrap_or("application/json"), - ) - .body(body.to_vec()) - .send() - .await - { + .body(body.to_vec()); + + // Copy all headers from original request + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); + } + + let res = match request_builder.send().await { Ok(res) => res, Err(_) => return HttpResponse::InternalServerError().finish(), }; @@ -546,6 +627,13 @@ impl Router { if response.status().is_success() { return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + return response; + } } warn!( @@ -570,16 +658,35 @@ impl Router { } pub async fn add_worker(&self, worker_url: &str) -> Result { - let interval_secs = 10; // check every 10 seconds - let timeout_secs = 300; // 5 minutes + let (timeout_secs, interval_secs) = match self { + Router::Random { + timeout_secs, + interval_secs, + .. 
+ } => (*timeout_secs, *interval_secs), + Router::RoundRobin { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), + Router::CacheAware { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), + }; let start_time = std::time::Instant::now(); let client = reqwest::Client::new(); loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { + error!( + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_url + ); return Err(format!( - "Timeout {}s waiting for worker {} to become healthy", + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", timeout_secs, worker_url )); } @@ -589,7 +696,7 @@ impl Router { if res.status().is_success() { match self { Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } + | Router::Random { worker_urls, .. } | Router::CacheAware { worker_urls, .. } => { info!("Worker {} health check passed", worker_url); let mut urls = worker_urls.write().unwrap(); @@ -663,7 +770,7 @@ impl Router { pub fn remove_worker(&self, worker_url: &str) { match self { Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } + | Router::Random { worker_urls, .. } | Router::CacheAware { worker_urls, .. } => { let mut urls = worker_urls.write().unwrap(); if let Some(index) = urls.iter().position(|url| url == &worker_url) { diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 09878f07f8e..0706c57c06c 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -18,45 +18,45 @@ impl AppState { worker_urls: Vec, client: reqwest::Client, policy_config: PolicyConfig, - ) -> Self { + ) -> Result { // Create router based on policy - let router = match Router::new(worker_urls, policy_config) { - Ok(router) => router, - Err(error) => panic!("Failed to create router: {}", error), - }; - - Self { router, client } + let router = Router::new(worker_urls, policy_config)?; + Ok(Self { router, client }) } } #[get("/health")] -async fn health(data: web::Data) -> impl Responder { - data.router.route_to_first(&data.client, "/health").await +async fn health(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/health", &req) + .await } #[get("/health_generate")] -async fn health_generate(data: web::Data) -> impl Responder { +async fn health_generate(req: HttpRequest, data: web::Data) -> impl Responder { data.router - .route_to_first(&data.client, "/health_generate") + .route_to_first(&data.client, "/health_generate", &req) .await } #[get("/get_server_info")] -async fn get_server_info(data: web::Data) -> impl Responder { +async fn get_server_info(req: HttpRequest, data: web::Data) -> impl Responder { data.router - .route_to_first(&data.client, "/get_server_info") + .route_to_first(&data.client, "/get_server_info", &req) .await } #[get("/v1/models")] -async fn v1_models(data: web::Data) -> impl Responder { - data.router.route_to_first(&data.client, "/v1/models").await +async fn v1_models(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/v1/models", &req) + .await } #[get("/get_model_info")] -async fn get_model_info(data: web::Data) -> impl Responder { +async fn 
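Both the initial worker wait and `add_worker` now take their timeout and polling interval from the policy config instead of hard-coded values. The following standalone Python sketch is an illustration only (the real logic lives in `router.rs`) and shows the polling behaviour those two knobs govern:

```python
import time

import requests


def wait_for_healthy_workers(worker_urls, timeout_secs=300, interval_secs=10):
    """Illustrative Python rendering of the startup polling loop in router.rs."""
    start = time.monotonic()
    while True:
        if time.monotonic() - start > timeout_secs:
            raise TimeoutError(
                f"Timeout {timeout_secs}s waiting for workers {worker_urls} to become healthy"
            )
        try:
            if all(
                requests.get(f"{url}/health", timeout=5).status_code == 200
                for url in worker_urls
            ):
                return
        except requests.RequestException:
            pass  # a worker is not reachable yet; keep polling
        time.sleep(interval_secs)


# Example: wait_for_healthy_workers(["http://localhost:30001"], timeout_secs=600)
```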
get_model_info(req: HttpRequest, data: web::Data) -> impl Responder { data.router - .route_to_first(&data.client, "/get_model_info") + .route_to_first(&data.client, "/get_model_info", &req) .await } @@ -131,6 +131,7 @@ pub struct ServerConfig { } pub async fn startup(config: ServerConfig) -> std::io::Result<()> { + // Initialize logger Builder::new() .format(|buf, record| { use chrono::Local; @@ -152,24 +153,30 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ) .init(); + info!("🚧 Initializing router on {}:{}", config.host, config.port); + info!("🚧 Initializing workers on {:?}", config.worker_urls); + info!("🚧 Policy Config: {:?}", config.policy_config); + info!( + "🚧 Max payload size: {} MB", + config.max_payload_size / (1024 * 1024) + ); + let client = reqwest::Client::builder() .build() .expect("Failed to create HTTP client"); - let app_state = web::Data::new(AppState::new( - config.worker_urls.clone(), - client, - config.policy_config.clone(), - )); - - info!("✅ Starting router on {}:{}", config.host, config.port); - info!("✅ Serving Worker URLs: {:?}", config.worker_urls); - info!("✅ Policy Config: {:?}", config.policy_config); - info!( - "✅ Max payload size: {} MB", - config.max_payload_size / (1024 * 1024) + let app_state = web::Data::new( + AppState::new( + config.worker_urls.clone(), + client, + config.policy_config.clone(), + ) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, ); + info!("✅ Serving router on {}:{}", config.host, config.port); + info!("✅ Serving workers on {:?}", config.worker_urls); + HttpServer::new(move || { App::new() .app_data(app_state.clone()) diff --git a/test/lang/run_suite.py b/test/lang/run_suite.py index ebc26e608c1..327d18b3fbd 100644 --- a/test/lang/run_suite.py +++ b/test/lang/run_suite.py @@ -4,7 +4,11 @@ from sglang.test.test_utils import run_unittest_files suites = { - "per-commit": ["test_srt_backend.py", "test_openai_backend.py"], + "per-commit": [ + "test_srt_backend.py", + # Skip this due to some OPENAI_API_KEY issues + # "test_openai_backend.py", + ], } diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index b99606fc1cb..a4b1b88a23d 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -1,6 +1,7 @@ """ Usage: python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens +python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select """ import unittest @@ -73,7 +74,7 @@ def test_hellaswag_select(self): # Run twice to capture more bugs for _ in range(2): accuracy, latency = test_hellaswag_select() - self.assertGreater(accuracy, 0.71) + self.assertGreater(accuracy, 0.70) def test_gen_min_new_tokens(self): test_gen_min_new_tokens() diff --git a/test/srt/models/test_qwen_models.py b/test/srt/models/test_qwen_models.py index 903fd45d550..c7788fa8e50 100644 --- a/test/srt/models/test_qwen_models.py +++ b/test/srt/models/test_qwen_models.py @@ -37,8 +37,7 @@ def test_gsm8k(self): port=int(self.base_url.split(":")[-1]), ) metrics = run_eval(args) - print(metrics) - + print(f"{metrics=}") self.assertGreater(metrics["accuracy"], 0.81) @@ -69,9 +68,8 @@ def test_gsm8k(self): port=int(self.base_url.split(":")[-1]), ) metrics = run_eval(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.8) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.79) if __name__ == "__main__": diff --git a/test/srt/models/test_reward_models.py b/test/srt/models/test_reward_models.py index 0d80a4d0cde..69ad563671b 100644 --- 
a/test/srt/models/test_reward_models.py +++ b/test/srt/models/test_reward_models.py @@ -20,8 +20,8 @@ from sglang.test.runners import HFRunner, SRTRunner MODELS = [ - ("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2), - ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 3e-2), + ("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2), + ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 4e-2), ] TORCH_DTYPES = [torch.float16] diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index fb1c6abf29b..603bab957bd 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -18,7 +18,6 @@ "test_eagle_infer.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", - "test_get_weights_by_name.py", "test_gguf.py", "test_input_embeddings.py", "test_json_constrained.py", @@ -31,6 +30,7 @@ "test_openai_server.py", "test_pytorch_sampling_backend.py", "test_radix_attention.py", + "test_regex_constrained.py", "test_release_memory_occupation.py", "test_request_length_validation.py", "test_retract_decode.py", @@ -41,18 +41,18 @@ "test_srt_endpoint.py", "test_torch_compile.py", "test_torch_compile_moe.py", - # Temporarily disable this because it requires PyTorch >= 2.5 - # "test_torch_native_attention_backend.py", + "test_torch_native_attention_backend.py", "test_torchao.py", "test_triton_attention_kernels.py", "test_triton_attention_backend.py", "test_update_weights_from_disk.py", "test_update_weights_from_tensor.py", "test_vision_chunked_prefill.py", + "test_vision_llm.py", "test_vision_openai_server.py", "test_w8a8_quantization.py", - "test_session_control.py", "test_fp8_kvcache.py", + "test_fp8_kernel.py", ], "nightly": [ "test_nightly_gsm8k_eval.py", @@ -73,7 +73,6 @@ tests.remove(target_suite_name) tests.extend(target_tests) - if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument( diff --git a/test/srt/test_bench_one_batch.py b/test/srt/test_bench_one_batch.py index c1bc98e8e04..c6562170d61 100644 --- a/test/srt/test_bench_one_batch.py +++ b/test/srt/test_bench_one_batch.py @@ -5,24 +5,46 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST, is_in_ci, run_bench_one_batch, + write_github_step_summary, ) class TestBenchOneBatch(unittest.TestCase): - def test_default(self): + def test_bs1(self): output_throughput = run_bench_one_batch(DEFAULT_MODEL_NAME_FOR_TEST, []) if is_in_ci(): + write_github_step_summary( + f"### test_bs1\n" + f"output_throughput : {output_throughput:.2f} token/s\n" + ) self.assertGreater(output_throughput, 135) - def test_moe_default(self): + def test_moe_tp2_bs1(self): output_throughput = run_bench_one_batch( DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"] ) if is_in_ci(): + write_github_step_summary( + f"### test_moe_tp2_bs1\n" + f"output_throughput : {output_throughput:.2f} token/s\n" + ) self.assertGreater(output_throughput, 125) + def test_torch_compile_tp2_bs1(self): + output_throughput = run_bench_one_batch( + DEFAULT_MODEL_NAME_FOR_TEST, + ["--tp", "2", "--enable-torch-compile", "--cuda-graph-max-bs", "2"], + ) + + if is_in_ci(): + write_github_step_summary( + f"### test_torch_compile_tp2_bs1\n" + f"output_throughput : {output_throughput:.2f} token/s\n" + ) + self.assertGreater(output_throughput, 240) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index b882f12f9df..8233438fcaf 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -1,6 +1,8 @@ import unittest from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + 
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_FP8_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST, @@ -47,7 +49,7 @@ def test_offline_throughput_non_stream_small_batch_size(self): ) # There is a regression with torch 2.5 # This number was 950 for torch 2.4 - self.assertGreater(res["output_throughput"], 800) + self.assertGreater(res["output_throughput"], 1000) def test_offline_throughput_without_radix_cache(self): res = run_bench_serving( @@ -112,7 +114,7 @@ def test_offline_throughput_default_fp8(self): f"### test_offline_throughput_default_fp8\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 3850) + self.assertGreater(res["output_throughput"], 3900) def test_online_latency_default(self): res = run_bench_serving( @@ -127,10 +129,40 @@ def test_online_latency_default(self): f"### test_online_latency_default\n" f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' ) - self.assertLess(res["median_e2e_latency_ms"], 12000) + self.assertLess(res["median_e2e_latency_ms"], 11000) self.assertLess(res["median_ttft_ms"], 86) self.assertLess(res["median_itl_ms"], 10) + def test_online_latency_eagle(self): + res = run_bench_serving( + model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + num_prompts=50, + request_rate=1, + disable_ignore_eos=True, + dataset_name="sharegpt", + other_server_args=[ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model-path", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "8", + "--speculative-num-draft-tokens", + "64", + "--mem-fraction-static", + "0.7", + ], + ) + + if is_in_ci(): + write_github_step_summary( + f"### test_online_latency_eagle\n" + f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' + ) + self.assertLess(res["median_e2e_latency_ms"], 450) + def test_moe_offline_throughput_default(self): res = run_bench_serving( model=DEFAULT_MOE_MODEL_NAME_FOR_TEST, @@ -144,7 +176,7 @@ def test_moe_offline_throughput_default(self): f"### test_moe_offline_throughput_default\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 2150) + self.assertGreater(res["output_throughput"], 2200) def test_moe_offline_throughput_without_radix_cache(self): res = run_bench_serving( @@ -159,7 +191,7 @@ def test_moe_offline_throughput_without_radix_cache(self): f"### test_moe_offline_throughput_without_radix_cache\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 2150) + self.assertGreater(res["output_throughput"], 2200) if __name__ == "__main__": diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 92127b8ef59..b01c260496a 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -1,14 +1,18 @@ -import multiprocessing import random +import threading import time import unittest +from types import SimpleNamespace import requests -from transformers import AutoConfig, AutoTokenizer import sglang as sgl +from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -19,60 +23,59 @@ class TestEAGLEEngine(unittest.TestCase): def 
test_eagle_accuracy(self): prompt = "Today is a sunny day and I like" - target_model_path = "meta-llama/Llama-2-7b-chat-hf" - speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B" - sampling_params = {"temperature": 0, "max_new_tokens": 8} + # Get the reference output + ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) + ref_output = ref_engine.generate(prompt, sampling_params)["text"] + ref_engine.shutdown() + + # Launch EAGLE engine engine = sgl.Engine( - model_path=target_model_path, - speculative_draft_model_path=speculative_draft_model_path, + model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, speculative_algorithm="EAGLE", - speculative_num_steps=3, - speculative_eagle_topk=4, - speculative_num_draft_tokens=16, + speculative_num_steps=5, + speculative_eagle_topk=8, + speculative_num_draft_tokens=64, + mem_fraction_static=0.7, ) - out1 = engine.generate(prompt, sampling_params)["text"] - engine.shutdown() - - engine = sgl.Engine(model_path=target_model_path) - out2 = engine.generate(prompt, sampling_params)["text"] - engine.shutdown() - print("==== Answer 1 ====") - print(out1) - - print("==== Answer 2 ====") - print(out2) - self.assertEqual(out1, out2) + # Case 1: Test the output of EAGLE engine is the same as normal engine + out1 = engine.generate(prompt, sampling_params)["text"] + print(f"{out1=}, {ref_output=}") + self.assertEqual(out1, ref_output) - def test_eagle_end_check(self): + # Case 2: Test the output of EAGLE engine does not contain unexpected EOS prompt = "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like [/INST]" - target_model_path = "meta-llama/Llama-2-7b-chat-hf" - tokenizer = AutoTokenizer.from_pretrained(target_model_path) - speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B" - sampling_params = { "temperature": 0, "max_new_tokens": 1024, "skip_special_tokens": False, } - engine = sgl.Engine( - model_path=target_model_path, - speculative_draft_model_path=speculative_draft_model_path, - speculative_algorithm="EAGLE", - speculative_num_steps=3, - speculative_eagle_topk=4, - speculative_num_draft_tokens=16, - ) - out1 = engine.generate(prompt, sampling_params)["text"] - engine.shutdown() - print("==== Answer 1 ====") - print(repr(out1)) - tokens = tokenizer.encode(out1, truncation=False) + tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) + out2 = engine.generate(prompt, sampling_params)["text"] + print(f"{out2=}") + tokens = tokenizer.encode(out2, truncation=False) assert tokenizer.eos_token_id not in tokens + # Case 3: Batched prompts + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = {"temperature": 0, "max_new_tokens": 30} + outputs = engine.generate(prompts, sampling_params) + for prompt, output in zip(prompts, outputs): + print("===============================") + print(f"Prompt: {prompt}\nGenerated text: {output['text']}") + + # Shutdown the engine + engine.shutdown() + prompts = [ "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]" @@ -83,64 +86,27 @@ def test_eagle_end_check(self): ] -def process(server_url: str): - time.sleep(random.uniform(0, 2)) - for prompt in prompts: - url = server_url - data = { - "model": "base", - "text": prompt, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 1024, - }, - } - response = requests.post(url, json=data) - 
assert response.status_code == 200 - - -def abort_process(server_url: str): - for prompt in prompts: - try: - time.sleep(1) - url = server_url - data = { - "model": "base", - "text": prompt, - "sampling_params": { - "temperature": 0, - "max_new_tokens": 1024, - }, - } - # set timeout = 1s,mock disconnected - requests.post(url, json=data, timeout=1) - except: - pass - - -class TestEAGLELaunchServer(unittest.TestCase): +class TestEAGLEServer(unittest.TestCase): @classmethod def setUpClass(cls): - speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B" - cls.model = "meta-llama/Llama-2-7b-chat-hf" cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( - cls.model, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--speculative-algorithm", "EAGLE", "--speculative-draft-model-path", - speculative_draft_model_path, + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, "--speculative-num-steps", - "3", + "5", "--speculative-eagle-topk", - "4", + "8", "--speculative-num-draft-tokens", - "16", - "--served-model-name", - "base", + "64", + "--mem-fraction-static", + "0.7", ], ) @@ -148,40 +114,67 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_eagle_server_concurrency(self): - concurrency = 4 - processes = [ - multiprocessing.Process( - target=process, - kwargs={"server_url": self.base_url + "/generate"}, - ) - for _ in range(concurrency) - ] - for worker in processes: - worker.start() - for p in processes: - p.join() - - def test_eagle_server_request_abort(self): + def send_request(self): + time.sleep(random.uniform(0, 2)) + for prompt in prompts: + url = self.base_url + "/generate" + data = { + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 1024, + }, + } + response = requests.post(url, json=data) + assert response.status_code == 200 + + def send_requests_abort(self): + for prompt in prompts: + try: + time.sleep(random.uniform(0, 2)) + url = self.base_url + "/generate" + data = { + "model": "base", + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 1024, + }, + } + # set timeout = 1s,mock disconnected + requests.post(url, json=data, timeout=1) + except Exception as e: + print(e) + pass + + def test_request_abort(self): concurrency = 4 - processes = [ - multiprocessing.Process( - target=process, - kwargs={"server_url": self.base_url + "/generate"}, - ) - for _ in range(concurrency) + threads = [ + threading.Thread(target=self.send_request) for _ in range(concurrency) ] + [ - multiprocessing.Process( - target=abort_process, - kwargs={"server_url": self.base_url + "/generate"}, - ) + threading.Thread(target=self.send_requests_abort) for _ in range(concurrency) ] - for worker in processes: + for worker in threads: worker.start() - for p in processes: + for p in threads: p.join() + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + + self.assertGreater(metrics["accuracy"], 0.20) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_ebnf_constrained.py b/test/srt/test_ebnf_constrained.py index 97b6f756118..5e852bec6e4 100644 --- a/test/srt/test_ebnf_constrained.py +++ b/test/srt/test_ebnf_constrained.py @@ -236,12 +236,5 @@ def test_ebnf_generate_custom_log_format(self): ) -class 
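The reworked EAGLE tests drive speculative decoding through `sgl.Engine`. A minimal sketch of the same configuration, using the model paths from the tests as placeholders:

```python
import sglang as sgl

# Model paths mirror the test constants above; substitute your own as needed.
engine = sgl.Engine(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    speculative_draft_model_path="lmzheng/sglang-EAGLE-llama2-chat-7B",
    speculative_algorithm="EAGLE",
    speculative_num_steps=5,
    speculative_eagle_topk=8,
    speculative_num_draft_tokens=64,
    mem_fraction_static=0.7,
)
out = engine.generate("Today is a sunny day and I like", {"temperature": 0, "max_new_tokens": 8})
print(out["text"])
engine.shutdown()
```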
TestJumpForward(TestEBNFConstrained): - @classmethod - def setUpClass(cls): - setup_class(cls, disable_overlap=True) - cls.check_jump_forward = True - - if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_fp8_kernel.py b/test/srt/test_fp8_kernel.py new file mode 100644 index 00000000000..fe92bfd0769 --- /dev/null +++ b/test/srt/test_fp8_kernel.py @@ -0,0 +1,127 @@ +import unittest + +import torch + +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_fp8, + w8a8_block_fp8_matmul, +) + + +class TestFP8Base(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.M = 256 + # test non-aligned + cls.N = 1024 + 64 + cls.K = 512 + cls.group_size = 128 + cls.quant_type = torch.float8_e4m3fn + cls.output_type = torch.float16 + + @staticmethod + def _make_A(M, K, group_size, out_dtype): + quant_A = torch.rand( + M, K // group_size, group_size, dtype=torch.float32, device="cuda" + ) + # -1 ~ 1 + quant_A = quant_A * 2 - 1 + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = finfo.max + scaling = fmax / quant_A.abs().amax(-1, keepdim=True) + quant_A *= scaling + quant_A = quant_A.to(out_dtype).to(torch.float32) + + # create scale and A + scale = torch.rand(M, K // group_size, dtype=torch.float32, device="cuda") + scale /= fmax + A = quant_A * scale[..., None] + + A = A.reshape(M, K) + quant_A = quant_A.reshape(M, K).to(out_dtype) + return A, quant_A, scale + + @staticmethod + def _make_B(K, N, group_size, out_dtype): + def _aligned_size(a, b): + return (a + b - 1) // b * b + + K_aligned = _aligned_size(K, group_size) + N_aligned = _aligned_size(N, group_size) + + quant_B = torch.rand( + K_aligned // group_size, + group_size, + N_aligned // group_size, + group_size, + dtype=torch.float32, + device="cuda", + ) + quant_B = quant_B * 2 - 1 + + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = finfo.max + scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True) + quant_B *= scaling + quant_B = quant_B.to(out_dtype).to(torch.float32) + + scale = torch.rand( + K_aligned // group_size, + 1, + N_aligned // group_size, + 1, + dtype=torch.float32, + device="cuda", + ) + scale /= fmax + + B = quant_B * scale + + B = B.reshape(K_aligned, N_aligned)[:K, :N] + quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N] + scale = scale.reshape(K_aligned // group_size, N_aligned // group_size) + return B, quant_B, scale + + +class TestPerTokenGroupQuantFP8(TestFP8Base): + def test_per_token_group_quant_fp8(self): + if torch.cuda.get_device_capability()[0] < 9: + return + A, A_quant_gt, scale_gt = self._make_A( + M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type + ) + A_quant, scale = per_token_group_quant_fp8( + x=A, group_size=self.group_size, dtype=self.quant_type + ) + torch.testing.assert_close(scale, scale_gt) + diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs() + diff_count = (diff > 1e-5).count_nonzero() + assert diff_count / diff.numel() < 1e-4 + + +class TestW8A8BlockFP8Matmul(TestFP8Base): + def test_w8a8_block_fp8_matmul(self): + if torch.cuda.get_device_capability()[0] < 9: + return + A, A_quant_gt, A_scale_gt = self._make_A( + M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type + ) + B, B_quant_gt, B_scale_gt = self._make_B( + K=self.K, N=self.N, group_size=self.group_size, out_dtype=self.quant_type + ) + C_gt = A.to(self.output_type) @ B.to(self.output_type) + C = w8a8_block_fp8_matmul( + A=A_quant_gt, + B=B_quant_gt.T.contiguous(), + 
As=A_scale_gt, + Bs=B_scale_gt.T.contiguous(), + block_size=[128, 128], + output_dtype=self.output_type, + ) + torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_function_calling.py b/test/srt/test_function_calling.py new file mode 100644 index 00000000000..24f341a5e47 --- /dev/null +++ b/test/srt/test_function_calling.py @@ -0,0 +1,249 @@ +import json +import time +import unittest + +import openai + +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestOpenAIServerFunctionCalling(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + # Start the local OpenAI Server. If necessary, you can add other parameters such as --enable-tools. + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + # If your server needs extra parameters to test function calling, please add them here. + "--tool-call-parser", + "llama3", + ], + ) + cls.base_url += "/v1" + cls.tokenizer = get_tokenizer(cls.model) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_function_calling_format(self): + """ + Test: Whether the function call format returned by the AI is correct. + When returning a tool call, message.content should be None, and tool_calls should be a list. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "add", + "description": "Compute the sum of two numbers", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "int", + "description": "A number", + }, + "b": { + "type": "int", + "description": "A number", + }, + }, + "required": ["a", "b"], + }, + }, + } + ] + + messages = [{"role": "user", "content": "Compute (3+5)"}] + response = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=False, + tools=tools, + ) + + content = response.choices[0].message.content + tool_calls = response.choices[0].message.tool_calls + + assert content is None, ( + "When function call is successful, message.content should be None, " + f"but got: {content}" + ) + assert ( + isinstance(tool_calls, list) and len(tool_calls) > 0 + ), "tool_calls should be a non-empty list" + + function_name = tool_calls[0].function.name + assert function_name == "add", "Function name should be 'add'" + + def test_function_calling_streaming_simple(self): + """ + Test: Whether the function name can be correctly recognized in streaming mode. + - Expect a function call to be found, and the function name to be correct. + - Verify that streaming mode returns at least multiple chunks. 
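`test_fp8_kernel.py` above checks `per_token_group_quant_fp8` against a hand-built reference and validates `w8a8_block_fp8_matmul` with 128x128 blocks. A small usage sketch of the quantization entry point, assuming an SM90-class GPU as the tests do:

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

# Requires a CUDA device with compute capability >= 9.0, matching the test guard above.
x = torch.rand(256, 512, dtype=torch.float32, device="cuda") * 2 - 1
x_q, x_scale = per_token_group_quant_fp8(x=x, group_size=128, dtype=torch.float8_e4m3fn)

# Produces an fp8 tensor of the same shape plus one scale per 128-element group per token.
print(x_q.dtype, x_q.shape, x_scale.shape)  # torch.float8_e4m3fn, (256, 512), (256, 4)
```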
+ """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for", + }, + "unit": { + "type": "string", + "description": "Weather unit (celsius or fahrenheit)", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "unit"], + }, + }, + } + ] + + messages = [{"role": "user", "content": "What is the temperature in Paris?"}] + + response_stream = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=True, + tools=tools, + ) + + chunks = list(response_stream) + self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk") + + found_function_name = False + for chunk in chunks: + choice = chunk.choices[0] + # Check whether the current chunk contains tool_calls + if choice.delta.tool_calls: + tool_call = choice.delta.tool_calls[0] + if tool_call.function.name: + self.assertEqual( + tool_call.function.name, + "get_current_weather", + "Function name should be 'get_current_weather'", + ) + found_function_name = True + break + + self.assertTrue( + found_function_name, + "Target function name 'get_current_weather' was not found in the streaming chunks", + ) + + def test_function_calling_streaming_args_parsing(self): + """ + Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON. + - The user request requires multiple parameters. + - AI may return the arguments in chunks that need to be concatenated. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "add", + "description": "Compute the sum of two integers", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "int", + "description": "First integer", + }, + "b": { + "type": "int", + "description": "Second integer", + }, + }, + "required": ["a", "b"], + }, + }, + } + ] + + messages = [ + {"role": "user", "content": "Please sum 5 and 7, just call the function."} + ] + + response_stream = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.9, + top_p=0.9, + stream=True, + tools=tools, + ) + + argument_fragments = [] + function_name = None + for chunk in response_stream: + choice = chunk.choices[0] + if choice.delta.tool_calls: + tool_call = choice.delta.tool_calls[0] + # Record the function name on first occurrence + function_name = tool_call.function.name or function_name + # In case of multiple chunks, JSON fragments may need to be concatenated + if tool_call.function.arguments: + argument_fragments.append(tool_call.function.arguments) + + self.assertEqual(function_name, "add", "Function name should be 'add'") + joined_args = "".join(argument_fragments) + self.assertTrue( + len(joined_args) > 0, + "No parameter fragments were returned in the function call", + ) + + # Check whether the concatenated JSON is valid + try: + args_obj = json.loads(joined_args) + except json.JSONDecodeError: + self.fail( + "The concatenated tool call arguments are not valid JSON, parsing failed" + ) + + self.assertIn("a", args_obj, "Missing parameter 'a'") + self.assertIn("b", args_obj, "Missing parameter 'b'") + self.assertEqual( + args_obj["a"], + 5, + "Parameter a should be 5", 
+ ) + self.assertEqual(args_obj["b"], 7, "Parameter b should be 7") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py index 69babf795f0..2837107a1e6 100644 --- a/test/srt/test_metrics.py +++ b/test/srt/test_metrics.py @@ -56,7 +56,6 @@ def test_metrics_enabled(self): "sglang:gen_throughput", "sglang:num_queue_reqs", "sglang:cache_hit_rate", - "sglang:func_latency_seconds", "sglang:prompt_tokens_total", "sglang:generation_tokens_total", "sglang:num_requests_total", diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py index 2e379c11179..6fe36171504 100644 --- a/test/srt/test_nightly_gsm8k_eval.py +++ b/test/srt/test_nightly_gsm8k_eval.py @@ -27,7 +27,7 @@ "google/gemma-2-27b-it": 0.92, "meta-llama/Llama-3.1-70B-Instruct": 0.95, "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.63, - "Qwen/Qwen2-57B-A14B-Instruct": 0.87, + "Qwen/Qwen2-57B-A14B-Instruct": 0.86, "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.83, "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.54, "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.84, @@ -45,7 +45,7 @@ def parse_models(model_string): return [model.strip() for model in model_string.split(",") if model.strip()] -def launch_server(base_url, model, is_fp8, is_tp2): +def popen_launch_server_wrapper(base_url, model, is_fp8, is_tp2): other_args = ["--log-level-http", "warning", "--trust-remote-code"] if is_fp8: if "Llama-3" in model or "gemma-2" in model: @@ -148,7 +148,9 @@ def test_mgsm_en_all_models(self): for model_group, is_fp8, is_tp2 in self.model_groups: for model in model_group: with self.subTest(model=model): - process = launch_server(self.base_url, model, is_fp8, is_tp2) + process = popen_launch_server_wrapper( + self.base_url, model, is_fp8, is_tp2 + ) args = SimpleNamespace( base_url=self.base_url, diff --git a/test/srt/test_nightly_human_eval.py b/test/srt/test_nightly_human_eval.py index 0b682937a82..6558b9effb9 100644 --- a/test/srt/test_nightly_human_eval.py +++ b/test/srt/test_nightly_human_eval.py @@ -4,7 +4,7 @@ import subprocess import unittest -from test_nightly_gsm8k_eval import launch_server, parse_models +from test_nightly_gsm8k_eval import parse_models, popen_launch_server_wrapper from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -93,7 +93,7 @@ def test_human_eval_all_models(self): # NOTE: only Llama for now if "Llama" in model: with self.subTest(model=model): - self.process = launch_server( + self.process = popen_launch_server_wrapper( self.base_url, model, is_fp8, is_tp2 ) self.run_evalplus(model) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 4bedf743966..23e0287292b 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -623,58 +623,6 @@ def test_ebnf_strict_json(self): text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape" ) - def test_function_calling_format(self): - - client = openai.Client(api_key=self.api_key, base_url=self.base_url) - tools = [ - { - "type": "function", - "function": { - "name": "add", - "description": "Compute the sum of two numbers", - "parameters": { - "type": "object", - "properties": { - "a": { - "type": "int", - "description": "A number", - }, - "b": { - "type": "int", - "description": "A number", - }, - }, - "required": ["a", "b"], - }, - }, - } - ] - - messages = [{"role": "user", "content": "Compute (3+5)"}] - response = client.chat.completions.create( - model=self.model, - messages=messages, - 
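The new `test_function_calling.py` covers the OpenAI-compatible tool-calling path, including reassembling streamed argument fragments. A client-side sketch of the non-streaming flow, assuming a server already launched with `--tool-call-parser llama3` and using a placeholder model name and URL:

```python
import json

import openai

BASE_URL = "http://127.0.0.1:30000/v1"      # placeholder server address
MODEL = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder; the tests use DEFAULT_SMALL_MODEL_NAME_FOR_TEST

client = openai.Client(api_key="sk-123456", base_url=BASE_URL)

tools = [
    {
        "type": "function",
        "function": {
            "name": "add",
            "description": "Compute the sum of two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "int", "description": "A number"},
                    "b": {"type": "int", "description": "A number"},
                },
                "required": ["a", "b"],
            },
        },
    }
]

response = client.chat.completions.create(
    model=MODEL,
    messages=[{"role": "user", "content": "Compute (3+5)"}],
    tools=tools,
)
call = response.choices[0].message.tool_calls[0]
print(call.function.name, json.loads(call.function.arguments))  # e.g. add {'a': 3, 'b': 5}
```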
temperature=0.8, - top_p=0.8, - stream=False, - tools=tools, - ) - - content = response.choices[0].message.content - tool_calls = response.choices[0].message.tool_calls - - assert ( - content is None - ), "When tools provided by the response, content should be None" - assert ( - isinstance(tool_calls, list) and len(tool_calls) > 0 - ), "Format not matched, tool_calls should be a list" - - function_name = tool_calls[0].function.name - assert ( - function_name == "add" - ), "Function name should be add for the above response" - class TestOpenAIEmbedding(unittest.TestCase): @classmethod diff --git a/test/srt/test_regex_constrained.py b/test/srt/test_regex_constrained.py new file mode 100644 index 00000000000..6d5acec15e2 --- /dev/null +++ b/test/srt/test_regex_constrained.py @@ -0,0 +1,186 @@ +""" +python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email +python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting +""" + +import json +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +def setup_class(cls, disable_overlap: bool): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + + other_args = [ + "--max-running-requests", + "10", + "--grammar-backend", + "xgrammar", + ] + + if disable_overlap: + other_args += ["--disable-overlap-schedule"] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + +class TestRegexConstrained(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_class(cls, disable_overlap=False) + cls.check_jump_forward = False + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_decode( + self, + regex, + prompt, + return_logprob=False, + top_logprobs_num=0, + n=1, + ): + response = requests.post( + self.base_url + "/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": 128, + "n": n, + "regex": regex, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": 0, + }, + ) + + ret = response.json() + print(json.dumps(ret, indent=2)) + print("=" * 100) + + if not isinstance(ret, list): + self.fail(f"Expected response to be a list, but got {type(ret)}") + + for item in ret: + text = item.get("text", "").strip() + if not text: + self.fail("Generated text is empty.") + + if not self.regex_match(text, regex): + self.fail(f"Text '{text}' does not match regex pattern.") + + def regex_match(self, text, pattern): + import re + + return re.match(pattern, text) is not None + + def test_regex_generate_email(self): + pattern = r"^user@example\.com$" + prompt = "Generate an email address:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_greeting(self): + pattern = r"^(Hello|Hi|Hey)$" + prompt = "Generate a greeting:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_number(self): + pattern = r"^\d{3}$" + prompt = "Generate a three-digit number:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_phone(self): + pattern = r"^\(\d{3}\) \d{3}-\d{4}$" + prompt = "Generate a phone number:" + + 
self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_date(self): + pattern = r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$" + prompt = "Generate a date in YYYY-MM-DD format:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_hex_color(self): + pattern = r"^#[0-9A-F]{6}$" + prompt = "Generate a hex color code:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_complex_json(self): + pattern = r'^\{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*\}$' + prompt = "Generate a simple JSON with name, age, and city:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_custom_log_format(self): + pattern = r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$" + prompt = "Generate a log entry:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + +class TestJumpForward(TestRegexConstrained): + @classmethod + def setUpClass(cls): + setup_class(cls, disable_overlap=True) + cls.check_jump_forward = True + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_session_control.py b/test/srt/test_session_control.py index 5653e9b69f1..2915133f437 100644 --- a/test/srt/test_session_control.py +++ b/test/srt/test_session_control.py @@ -54,6 +54,7 @@ def test_session_control(self, gen_len=12): chunks_ids[i] = chunks_ids[i][1:] # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -215,7 +216,9 @@ def test_session_control(self, gen_len=12): print(outputs_from_session) print("outputs from normal queries:") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" async def async_generate(self, payload): url = self.base_url + "/generate" @@ -250,6 +253,7 @@ async def run_session_control_backtrack_with_abort(self, replace): chunks_ids[i] = chunks_ids[i][1:] # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -320,6 +324,7 @@ async def run_session_control_backtrack_with_abort(self, replace): assert response["meta_info"]["finish_reason"]["type"] == "abort" else: # 2. not using session control + requests.post(self.base_url + "/flush_cache") output_ids = tokenizer.encode(gen_so_far) if output_ids[0] == tokenizer.bos_token_id: output_ids = output_ids[1:] @@ -342,7 +347,9 @@ async def run_session_control_backtrack_with_abort(self, replace): output_no_session = response["text"] print("second request output without session:") print(output_no_session) - assert second_output == output_no_session + assert ( + second_output == output_no_session + ), f"second_output: {second_output}, output_no_session: {output_no_session}" def test_session_control_backtrack_with_abort(self): asyncio.run(self.run_session_control_backtrack_with_abort(replace=True)) @@ -355,6 +362,7 @@ def run_session_control_with_branching( assert len(x) == len(chunks_per_step[0]) # 1. 
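`test_regex_constrained.py` constrains generation by passing a `regex` field inside `sampling_params`. A minimal request sketch against a locally running server (placeholder URL), mirroring the tests above:

```python
import requests

BASE_URL = "http://127.0.0.1:30000"  # placeholder; the tests launch with --grammar-backend xgrammar

resp = requests.post(
    f"{BASE_URL}/generate",
    json={
        "text": "Generate a three-digit number:",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 128,
            "regex": r"^\d{3}$",  # the decoded text is constrained to match this pattern
        },
    },
)
print(resp.json()["text"])
```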
using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -459,7 +467,9 @@ def run_session_control_with_branching( print(outputs_from_session) print("====== outputs from normal queries: =======") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" def test_session_control_with_branching(self): root_prompt = "First, let me explain in one sentence about AI" @@ -525,6 +535,7 @@ def test_session_control(self): gen_len = 32 # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -691,7 +702,9 @@ def test_session_control(self): print(outputs_from_session) print("outputs from normal queries:") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" if __name__ == "__main__": diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 0fd71efcb0b..68db1d69983 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -4,11 +4,16 @@ """ import json +import random +import time import unittest +from concurrent.futures import ThreadPoolExecutor +from typing import Optional import numpy as np import requests +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, @@ -24,7 +29,14 @@ def setUpClass(cls): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=( + "--enable-custom-logit-processor", + "--mem-fraction-static", + "0.8", + ), ) @classmethod @@ -147,14 +159,26 @@ def test_logprob_with_chunked_prefill(self): }, "return_logprob": True, "logprob_start_len": -1, + "top_logprobs_num": 5, }, ) response_json = response.json() - print(json.dumps(response_json, indent=2)) + # print(json.dumps(response_json, indent=2)) res = response_json self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens) + + # Test the number of tokens are correct self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens) + self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), new_tokens) + + # Test the top-1 tokens are the same as output tokens (because temp = 0.0) + for i in range(new_tokens): + self.assertListEqual( + res["meta_info"]["output_token_logprobs"][i], + res["meta_info"]["output_top_logprobs"][i][0], + ) + self.assertEqual(len(res["meta_info"]["output_top_logprobs"][i]), 5) def test_logprob_match(self): """Test the output logprobs are close to the input logprobs if we run a prefill again.""" @@ -213,6 +237,103 @@ def run_generate( max_diff = np.max(diff) self.assertLess(max_diff, 0.25) + def run_logprob_check(self, arg): + ( + input_len, + output_len, + temperature, + logprob_start_len, + return_logprob, + top_logprobs_num, + ) = arg + input_ids = list(range(input_len)) + + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": 
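# Sketch of the invariant the new assertions rely on (illustrative helper, assuming the
# meta_info layout used above): with greedy decoding (temperature == 0) the sampled token
# must also be the top-1 candidate, so every output_token_logprobs entry should equal the
# first entry of the matching output_top_logprobs list.
def check_greedy_top1(meta_info: dict, top_logprobs_num: int) -> None:
    token_logprobs = meta_info["output_token_logprobs"]
    top_logprobs = meta_info["output_top_logprobs"]
    assert len(token_logprobs) == len(top_logprobs)
    for token_entry, top_entries in zip(token_logprobs, top_logprobs):
        assert len(top_entries) == top_logprobs_num
        assert token_entry == top_entries[0]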
input_ids, + "sampling_params": { + "temperature": temperature, + "max_new_tokens": output_len, + }, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + }, + ) + response_json = response.json() + + res = response_json + self.assertEqual(res["meta_info"]["prompt_tokens"], input_len) + self.assertEqual(res["meta_info"]["completion_tokens"], output_len) + + # Test the number of tokens are correct + if return_logprob: + # This is because if logprob_start_len == 0, we added a padding for the first token. + # In other cases, we do not add the padding + delta = 0 if logprob_start_len == 0 else 1 + + self.assertEqual( + len(res["meta_info"]["input_token_logprobs"]) + + logprob_start_len + + delta, + res["meta_info"]["prompt_tokens"], + ) + self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len) + + if top_logprobs_num: + self.assertEqual( + len(res["meta_info"]["input_top_logprobs"]) + + logprob_start_len + + delta, + res["meta_info"]["prompt_tokens"], + ) + self.assertEqual( + len(res["meta_info"]["output_top_logprobs"]), output_len + ) + + for i in range(output_len): + self.assertEqual( + len(res["meta_info"]["output_top_logprobs"][i]), + top_logprobs_num, + ) + + # Test the top-1 tokens are the same as output tokens if temperature == 0 + if temperature == 0: + self.assertListEqual( + res["meta_info"]["output_token_logprobs"][i], + res["meta_info"]["output_top_logprobs"][i][0], + ) + + def test_logprob_mixed(self): + args = [] + temperature = 0 + # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num + for input_len in [1000, 2000]: + for output_len in [4, 8]: + for logprob_start_len in [0, 500, 1000]: + for return_logprob in [True, False]: + for top_logprobs_num in [0, 5]: + + if logprob_start_len >= input_len: + continue + + args.append( + ( + input_len, + output_len, + temperature, + logprob_start_len, + return_logprob, + top_logprobs_num, + ) + ) + + random.shuffle(args) + + with ThreadPoolExecutor(8) as executor: + list(executor.map(self.run_logprob_check, args)) + def test_logprob_grammar(self): prompts = "Question: Is Paris the Capital of France? Answer:" allowed_tokens = [" Yes", " No"] @@ -248,6 +369,100 @@ def test_logprob_grammar(self): self.assertTrue(all(x is not None for x in logprobs)) + def run_custom_logit_processor(self, target_token_id: Optional[int] = None): + """Test custom logit processor with custom params. + + If target_token_id is None, the custom logit processor won't be passed in. + """ + + custom_params = {"token_id": target_token_id} + + class DeterministicLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits to always + sample the given token id. + """ + + def __call__(self, logits, custom_param_list): + assert logits.shape[0] == len(custom_param_list) + key = "token_id" + + for i, param_dict in enumerate(custom_param_list): + # Mask all other tokens + logits[i, :] = -float("inf") + # Assign highest probability to the specified token + logits[i, param_dict[key]] = 0.0 + return logits + + prompts = "Question: Is Paris the Capital of France? Answer:" + + # Base case json data to be posted to the server. + base_json = { + "text": prompts, + "sampling_params": {"temperature": 0.0}, + "return_logprob": True, + } + + # Custom json data with custom logit processor and params. + custom_json = base_json.copy() + # Only set the custom logit processor if target_token_id is not None. 
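# A second illustrative processor in the same style as DeterministicLogitProcessor:
# instead of forcing a single token, it masks out a per-request list of banned token ids
# read from custom_params (sketch only; "banned_ids" is a made-up parameter name).
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

class BanTokensLogitProcessor(CustomLogitProcessor):
    """Set the logits of every token id listed under "banned_ids" to -inf."""

    def __call__(self, logits, custom_param_list):
        assert logits.shape[0] == len(custom_param_list)
        for i, param_dict in enumerate(custom_param_list):
            for token_id in param_dict.get("banned_ids", []):
                logits[i, token_id] = -float("inf")
        return logits

# It would be wired into a request the same way as below, e.g.
#   custom_json["custom_logit_processor"] = BanTokensLogitProcessor().to_str()
#   custom_json["sampling_params"]["custom_params"] = {"banned_ids": [5, 7]}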
+ if target_token_id is not None: + custom_json["custom_logit_processor"] = ( + DeterministicLogitProcessor().to_str() + ) + custom_json["sampling_params"]["custom_params"] = custom_params + + custom_response = requests.post( + self.base_url + "/generate", + json=custom_json, + ).json() + + output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"] + sampled_tokens = [x[1] for x in output_token_logprobs] + + # The logit processor should always sample the given token as the logits is deterministic. + if target_token_id is not None: + self.assertTrue( + all(x == custom_params["token_id"] for x in sampled_tokens), + # Print the detailed test case info if the test fails. + f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}", + ) + + def test_custom_logit_processor(self): + """Test custom logit processor with a single request.""" + self.run_custom_logit_processor(target_token_id=5) + + def test_custom_logit_processor_batch_mixed(self): + """Test a batch of requests mixed of requests with and without custom logit processor.""" + target_token_ids = list(range(32)) + [None] * 16 + random.shuffle(target_token_ids) + with ThreadPoolExecutor(len(target_token_ids)) as executor: + list(executor.map(self.run_custom_logit_processor, target_token_ids)) + + def test_cache_tokens(self): + for _ in range(2): + time.sleep(1) + response = requests.post(self.base_url + "/flush_cache") + assert response.status_code == 200 + + def send_and_check_cached_tokens(input_ids): + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": list(input_ids), + "sampling_params": { + "max_new_tokens": 1, + }, + }, + ) + response_json = response.json() + return response_json["meta_info"]["cached_tokens"] + + self.assertEqual(send_and_check_cached_tokens(range(0, 100)), 0) + self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 100) + self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 9999) + self.assertEqual(send_and_check_cached_tokens(range(0, 1000)), 999) + self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000) + def test_get_server_info(self): response = requests.get(self.base_url + "/get_server_info") response_json = response.json() diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index 7479b646837..c535d5c0686 100644 --- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -1,6 +1,6 @@ """ Usage: -python3 -m unittest test_srt_engine.TestSRTEngine.test_3_sync_streaming_combination +python3 -m unittest test_srt_engine.TestSRTEngine.test_4_sync_async_stream_combination """ import asyncio @@ -44,64 +44,97 @@ def test_1_engine_runtime_consistency(self): print(out2) self.assertEqual(out1, out2) - def test_2_engine_multiple_generate(self): + def test_2_engine_runtime_encode_consistency(self): + prompt = "Today is a sunny day and I like" + model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST + + engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42) + out1 = torch.tensor(engine.encode(prompt)["embedding"]) + engine.shutdown() + + runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42) + out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"]) + runtime.shutdown() + + self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3)) + + def test_3_engine_token_ids_consistency(self): # just to ensure there is no issue running multiple generate calls prompt = "Today is a sunny day and I like" model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - 
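# Illustrative helper mirroring the cached_tokens checks above: with the radix cache
# enabled, meta_info["cached_tokens"] reports how many prompt tokens were reused from
# earlier requests. At least one prompt token is always recomputed, which appears to be
# why repeating an identical 10000-token prompt reports 9999 rather than 10000.
import requests

def cached_tokens(base_url: str, input_ids) -> int:
    response = requests.post(
        base_url + "/generate",
        json={
            "input_ids": list(input_ids),
            "sampling_params": {"max_new_tokens": 1},
        },
    )
    return response.json()["meta_info"]["cached_tokens"]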
sampling_params = {"temperature": 0, "max_new_tokens": 8} - engine = sgl.Engine(model_path=model_path, random_seed=42) - engine.generate(prompt, sampling_params) - engine.generate(prompt, sampling_params) - engine.shutdown() + engine = sgl.Engine( + model_path=model_path, random_seed=42, disable_radix_cache=True + ) + out1 = engine.generate(prompt, sampling_params)["text"] - def test_3_sync_streaming_combination(self): + tokenizer = get_tokenizer(model_path) + token_ids = tokenizer.encode(prompt) + out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[ + "text" + ] - prompt = "AI safety is..." - sampling_params = {"temperature": 0.8, "top_p": 0.95} + engine.shutdown() - async def async_streaming(engine): + print("==== Answer 1 ====") + print(out1) - generator = await engine.async_generate( - prompt, sampling_params, stream=True - ) + print("==== Answer 2 ====") + print(out2) + self.assertEqual(out1, out2) - async for output in generator: - print(output["text"], end="", flush=True) - print() + def test_4_sync_async_stream_combination(self): + prompt = "AI safety is" + sampling_params = {"temperature": 0.8, "top_p": 0.95} # Create an LLM. llm = sgl.Engine( model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ) - # 1. sync + non streaming - print("\n\n==== 1. sync + non streaming ====") - output = llm.generate(prompt, sampling_params) + if True: + # 1. sync + non streaming + print("\n\n==== 1. sync + non streaming ====") + output = llm.generate(prompt, sampling_params) + print(output["text"]) + + # 2. sync + streaming + print("\n\n==== 2. sync + streaming ====") + output_generator = llm.generate(prompt, sampling_params, stream=True) + offset = 0 + for output in output_generator: + print(output["text"][offset:], end="", flush=True) + offset = len(output["text"]) + print() - print(output["text"]) + if True: + loop = asyncio.get_event_loop() + # 3. async + non_streaming + print("\n\n==== 3. async + non streaming ====") + output = loop.run_until_complete( + llm.async_generate(prompt, sampling_params) + ) + print(output["text"]) - # 2. sync + streaming - print("\n\n==== 2. sync + streaming ====") - output_generator = llm.generate(prompt, sampling_params, stream=True) - for output in output_generator: - print(output["text"], end="", flush=True) - print() + # 4. async + streaming + async def async_streaming(engine): + generator = await engine.async_generate( + prompt, sampling_params, stream=True + ) - loop = asyncio.get_event_loop() - # 3. async + non_streaming - print("\n\n==== 3. async + non streaming ====") - output = loop.run_until_complete(llm.async_generate(prompt, sampling_params)) - print(output["text"]) + offset = 0 + async for output in generator: + print(output["text"][offset:], end="", flush=True) + offset = len(output["text"]) + print() - # 4. async + streaming - print("\n\n==== 4. async + streaming ====") - loop.run_until_complete(async_streaming(llm)) + print("\n\n==== 4. 
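# The offset bookkeeping above exists because each streamed chunk carries the full text
# generated so far rather than just the delta. A small helper capturing that pattern
# (illustrative sketch; `engine` is assumed to be an already constructed sgl.Engine):
def print_stream_incrementally(engine, prompt, sampling_params):
    offset = 0
    for chunk in engine.generate(prompt, sampling_params, stream=True):
        text = chunk["text"]
        print(text[offset:], end="", flush=True)  # print only the newly generated suffix
        offset = len(text)
    print()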
async + streaming ====") + loop.run_until_complete(async_streaming(llm)) llm.shutdown() - def test_4_gsm8k(self): + def test_5_gsm8k(self): args = SimpleNamespace( model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, @@ -113,46 +146,7 @@ def test_4_gsm8k(self): metrics = run_eval(args) self.assertGreater(metrics["accuracy"], 0.3) - def test_5_prompt_input_ids_consistency(self): - prompt = "The capital of UK is" - - model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - engine = sgl.Engine( - model_path=model_path, random_seed=42, disable_radix_cache=True - ) - sampling_params = {"temperature": 0, "max_new_tokens": 8} - out1 = engine.generate(prompt, sampling_params)["text"] - - tokenizer = get_tokenizer(model_path) - token_ids = tokenizer.encode(prompt) - out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[ - "text" - ] - - engine.shutdown() - - print("==== Answer 1 ====") - print(out1) - - print("==== Answer 2 ====") - print(out2) - self.assertEqual(out1, out2) - - def test_6_engine_runtime_encode_consistency(self): - prompt = "Today is a sunny day and I like" - model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST - - engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42) - out1 = torch.tensor(engine.encode(prompt)["embedding"]) - engine.shutdown() - - runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42) - out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"]) - runtime.shutdown() - - self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3)) - - def test_7_engine_cpu_offload(self): + def test_6_engine_cpu_offload(self): prompt = "Today is a sunny day and I like" model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST @@ -182,7 +176,7 @@ def test_7_engine_cpu_offload(self): print(out2) self.assertEqual(out1, out2) - def test_8_engine_offline_throughput(self): + def test_7_engine_offline_throughput(self): server_args = ServerArgs( model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ) diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index 6f3b344b3cc..e71de339117 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -23,7 +23,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-torch-compile"], + other_args=["--enable-torch-compile", "--cuda-graph-max-bs", "4"], ) @classmethod diff --git a/test/srt/test_vision_llm.py b/test/srt/test_vision_llm.py new file mode 100644 index 00000000000..7cda64fc0c7 --- /dev/null +++ b/test/srt/test_vision_llm.py @@ -0,0 +1,210 @@ +""" +""" + +import unittest +from io import BytesIO + +import numpy as np +import requests +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import AutoModel, AutoProcessor, AutoTokenizer + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.conversation import generate_chat_conv +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.openai_api.protocol import ChatCompletionRequest +from sglang.srt.server_args import ServerArgs + +MiniCPMV = "openbmb/MiniCPM-V-2_6" + + +# Test the logits output between HF and SGLang +class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls): + cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + cls.model_path = "" + cls.chat_template = "" + cls.processor = 
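# Sketch (illustrative): when no event loop is running yet, the async variants shown in
# this test can also be driven with asyncio.run instead of loop.run_until_complete;
# `llm` is assumed to be the sgl.Engine created above.
import asyncio

async def generate_async(llm, prompt, sampling_params):
    output = await llm.async_generate(prompt, sampling_params)
    return output["text"]

# e.g.: text = asyncio.run(generate_async(llm, prompt, sampling_params))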
"" + response = requests.get(cls.image_url) + cls.main_image = Image.open(BytesIO(response.content)) + + def compare_outputs(self, sglang_output: torch.Tensor, hf_output: torch.Tensor): + # Convert to float32 for numerical stability if needed + hf = hf_output.float() + sg = sglang_output.float() + + # Basic shape and dtype comparison + print("\n=== Basic Properties ===") + print(f"Shapes match: {hf.shape == sg.shape}") + print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}") + print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}") + + # Move tensors to CPU for numpy operations + hf_np = hf.cpu().numpy() + sg_np = sg.cpu().numpy() + + # Statistical metrics + print("\n=== Statistical Metrics ===") + print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}") + print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}") + print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}") + print( + f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}" + ) + + # Cosine similarity (across feature dimension) + cos_sim = F.cosine_similarity(hf, sg) + print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}") + print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}") + + # Find largest absolute differences + print("\n=== Largest Absolute Differences ===") + diffs = torch.abs(hf - sg) + flat_diffs = diffs.flatten() + + # Get indices of top 10 differences + top_k = 10 + top_values, top_flat_indices = torch.topk(flat_diffs, top_k) + + # Convert flat indices to multidimensional indices + top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape) + + print(f"\nTop {top_k} largest absolute differences:") + print( + "Index".ljust(30) + + "Difference".ljust(15) + + "HF Value".ljust(15) + + "SGLang Value" + ) + print("-" * 75) + + for i in range(top_k): + # Get the index tuple for this difference + idx = tuple(dim[i] for dim in top_indices) + diff_val = top_values[i].item() + hf_val = hf[idx].item() + sg_val = sg[idx].item() + + # Format the index tuple and values + idx_str = str(idx) + print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}") + + np.testing.assert_allclose(hf_np, sg_np) + + def get_processor_output(self): + json_str = f""" + {{ + "model": "{self.model_path}", + "messages": [ + {{ + "role": "user", + "content": [ + {{ + "type": "image_url", + "image_url": {{ + "url": "{self.image_url}" + }} + }}, + {{ + "type": "text", + "text": "Whats in this picture?" 
+ }} + ] + }} + ] +}} + """ + + req = ChatCompletionRequest.model_validate_json(json_str) + + conv = generate_chat_conv(req, template_name=self.chat_template) + + text = conv.get_prompt() + + # Process inputs using processor + # FIXME: the formal arguments may differ + inputs = self.processor( + text=[text], + images=[self.main_image], + return_tensors="pt", + ).to(self.device) + + return inputs + + def get_sglang_model(self): + model_runner = ModelRunner( + model_config=ModelConfig(self.model_path, model_override_args="{}"), + mem_fraction_static=0.8, + gpu_id=0, + tp_rank=0, + tp_size=1, + nccl_port=12435, + server_args=ServerArgs( + model_path=self.model_path, + disable_cuda_graph=True, + ), + ) + return model_runner.model + + +class TestMiniCPMVLogits(VisionLLMLogitsBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_path = MiniCPMV + cls.tokenizer = AutoTokenizer.from_pretrained( + cls.model_path, trust_remote_code=True + ) + cls.processor = AutoProcessor.from_pretrained( + cls.model_path, trust_remote_code=True + ) + cls.chat_template = "minicpmv" + + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + cls.model = AutoModel.from_pretrained( + cls.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True + ).eval() + cls.model.to(cls.device) + + async def test_encode_output(self): + inputs = self.get_processor_output() + + with torch.no_grad(): + model_inputs = { + "input_ids": inputs.input_ids, + "image_bound": inputs.image_bound, + "pixel_values": inputs.pixel_values, + "tgt_sizes": inputs.tgt_sizes, + } + (hf_output, _) = self.model.get_vllm_embedding( + model_inputs, + ) + hf_output = hf_output.squeeze(0) + + with torch.no_grad(): + model = self.get_sglang_model() + input_ids = inputs["input_ids"].to(self.device).flatten() + image_inputs = model._parse_and_validate_inputs( + input_ids=input_ids, + **{ + "pixel_values": [inputs["pixel_values"]], + "tgt_sizes": [inputs["tgt_sizes"]], + "im_start_id": [self.tokenizer.im_start_id], + "im_end_id": [self.tokenizer.im_end_id], + "slice_start_id": [self.tokenizer.slice_start_id], + "slice_end_id": [self.tokenizer.slice_end_id], + }, + ) + (sglang_output, _) = model.get_embedding( + input_ids=input_ids, image_inputs=image_inputs + ) + + self.compare_outputs(sglang_output, hf_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 5be911ab84a..01762202882 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -180,7 +180,9 @@ def test_multi_images_chat_completion(self): assert response.usage.total_tokens > 0 def prepare_video_messages(self, video_path): - max_frames_num = 32 + # the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa + # the size of the video embeds differs from the `modality` argument when preprocessed + max_frames_num = 12 vr = VideoReader(video_path, ctx=cpu(0)) total_frame_num = len(vr) uniform_sampled_frames = np.linspace(
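# A hedged sketch of uniform frame sampling with decord + numpy, in the spirit of
# prepare_video_messages above (illustrative; the exact np.linspace arguments used by
# the test are not shown here):
import numpy as np
from decord import VideoReader, cpu

def sample_frames(video_path: str, max_frames_num: int = 12):
    vr = VideoReader(video_path, ctx=cpu(0))
    total_frame_num = len(vr)
    # Pick max_frames_num indices spread evenly across the whole clip.
    indices = np.linspace(0, total_frame_num - 1, num=max_frames_num, dtype=int)
    return vr.get_batch(indices).asnumpy()  # frames as an array of shape [T, H, W, C]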