diff --git a/.clang-format-ignore b/.clang-format-ignore new file mode 100644 index 00000000000..15c76cc457f --- /dev/null +++ b/.clang-format-ignore @@ -0,0 +1 @@ +sgl-kernel/3rdparty/tensorrt_llm/* diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 5493c4201c4..279994c596f 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -13,4 +13,4 @@ - [ ] Format your code according to the [Code Formatting with Pre-Commit](https://docs.sglang.ai/references/contribution_guide.html#code-formatting-with-pre-commit). - [ ] Add unit tests as outlined in the [Running Unit Tests](https://docs.sglang.ai/references/contribution_guide.html#running-unit-tests-adding-to-ci). - [ ] Update documentation / docstrings / example tutorials as needed, according to [Writing Documentation](https://docs.sglang.ai/references/contribution_guide.html#writing-documentation-running-docs-ci). -- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to [Benchmark and Profiling](https://docs.sglang.ai/references/benchmark_and_profiling.html). +- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to [Benchmark and Profiling](https://docs.sglang.ai/references/benchmark_and_profiling.html) and [Accuracy Results](https://docs.sglang.ai/references/accuracy_evaluation.html). diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 6ed6046ee6a..7fd91a5e9a1 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -37,7 +37,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -60,7 +60,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -84,7 +84,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -121,7 +121,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -165,7 +165,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -196,7 +196,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -234,7 +234,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -258,7 +258,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh diff --git a/.github/workflows/release-docker-amd.yml b/.github/workflows/release-docker-amd.yml index 228eecdb9c5..ffe2843d519 100644 --- a/.github/workflows/release-docker-amd.yml +++ b/.github/workflows/release-docker-amd.yml @@ -14,7 +14,7 @@ jobs: environment: 'prod' strategy: matrix: - rocm_version: ['6.2.0'] + rocm_version: ['6.3.0'] build_type: ['all', 'srt'] steps: - name: Checkout repository @@ -41,8 +41,8 @@ jobs: run: | version=$(cat python/sglang/version.py | cut -d'"' -f2) - if [ "${{ matrix.rocm_version }}" = "6.2.0" ]; then - rocm_tag="rocm620" + if [ "${{ matrix.rocm_version }}" = "6.3.0" ]; then + rocm_tag="rocm630" else echo "Unsupported ROCm version" exit 1 diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index 99ffd7c49cb..d5669886d18 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -14,7 +14,7 @@ jobs: environment: 'prod' strategy: matrix: - cuda_version: ['11.8.0', '12.1.1', '12.4.1'] + cuda_version: ['11.8.0', '12.1.1', '12.4.1', '12.5.1'] build_type: ['all', 'srt'] steps: - name: Delete huge unnecessary tools folder @@ -39,6 +39,8 @@ jobs: cuda_tag="cu121" elif [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then cuda_tag="cu124" + elif [ "${{ matrix.cuda_version }}" = "12.5.1" ]; then + cuda_tag="cu125" else echo "Unsupported CUDA version" exit 1 @@ -58,7 +60,7 @@ jobs: docker build . -f docker/Dockerfile --build-arg CUDA_VERSION=${{ matrix.cuda_version }} --build-arg BUILD_TYPE=${{ matrix.build_type }} -t lmsysorg/sglang:${tag}${tag_suffix} --no-cache docker push lmsysorg/sglang:${tag}${tag_suffix} - if [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then + if [ "${{ matrix.cuda_version }}" = "12.5.1" ]; then docker tag lmsysorg/sglang:${tag}${tag_suffix} lmsysorg/sglang:latest${tag_suffix} docker push lmsysorg/sglang:latest${tag_suffix} fi diff --git a/README.md b/README.md index 0b08c919949..8da8b54a151 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,11 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487) ## Adoption and Sponsorship -The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. +The project is supported by (alphabetically): AMD, Atlas Cloud, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS CORP, Meituan, Nebius, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. + +## Contact Us + +For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at contact@sglang.ai. ## Acknowledgment and Citation We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful. diff --git a/benchmark/kernels/quantization/tuning_block_wise_fp8.py b/benchmark/kernels/quantization/tuning_block_wise_fp8.py new file mode 100644 index 00000000000..07bdb4bf167 --- /dev/null +++ b/benchmark/kernels/quantization/tuning_block_wise_fp8.py @@ -0,0 +1,335 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import argparse +import json +import os +import time +from datetime import datetime +from typing import Any, Dict, List + +import torch +import triton +from tqdm import tqdm + +from sglang.srt.layers.quantization.fp8_kernel import _w8a8_block_fp8_matmul +from sglang.srt.utils import get_device_name + +DTYPE_MAP = { + "float32": torch.float32, + "float16": torch.float16, + "half": torch.half, + "bfloat16": torch.bfloat16, +} + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + config: Dict[str, Any], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise quantization. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C + + +def get_configs_compute_bound(): + configs = [] + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def get_weight_shapes(tp_size): + # NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model. + # cannot TP + total = [ + (512 + 64, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (7168, 16384), + (7168, 18432), + ] + # N can TP + n_tp = [ + (18432 * 2, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (24576, 1536), + (4096, 7168), + ] + # K can TP + k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)] + + weight_shapes = [] + for t in total: + weight_shapes.append(t) + for n_t in n_tp: + new_t = (n_t[0] // tp_size, n_t[1]) + weight_shapes.append(new_t) + for k_t in k_tp: + new_t = (k_t[0], k_t[1] // tp_size) + weight_shapes.append(new_t) + return weight_shapes + + +def benchmark_config( + A_fp8, B_fp8, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10 +): + def run(): + w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, config, out_dtype) + + torch.cuda.synchronize() + # JIT complication & warmup + for _ in range(5): + run() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: List[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + run() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +def tune(M, N, K, block_size, out_dtype, search_space): + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") + * factor_for_scale + ) + + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + try: + kernel_time = benchmark_config( + A_fp8, + B_fp8, + As, + Bs, + block_size, + config, + out_dtype, + num_iters=10, + ) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={M}") + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + block_n, + block_k, + configs, + save_path, +) -> None: + os.makedirs(save_path, exist_ok=True) + device_name = get_device_name().replace(" ", "_") + json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing best config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def main(args): + print(args) + + block_n = args.block_n + block_k = args.block_k + + tp_size = args.tp_size + assert args.out_dtype in ["float32", "float16", "bfloat16", "half"] + out_dtype = DTYPE_MAP[args.out_dtype] + save_path = args.save_path + + search_space = get_configs_compute_bound() + search_space = [ + config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0 + ] + + if args.batch_size is None: + batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + ] + else: + batch_sizes = [args.batch_size] + + print(f"Start tuning over {len(search_space)} configurations...") + + weight_shapes = get_weight_shapes(tp_size) + start = time.time() + for shape in tqdm(weight_shapes): + N, K = shape[0], shape[1] + print(f"Tune for weight shape of `N: {N}, K: {K}`") + benchmark_results = [ + tune(batch_size, N, K, [block_n, block_k], out_dtype, search_space) + for batch_size in batch_sizes + ] + best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)} + save_configs(N, K, block_n, block_k, best_configs, save_path) + + end = time.time() + print(f"Tuning took {end - start:.2f} seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--tp-size", "-tp", type=int, default=8) + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="float16", + ) + parser.add_argument("--block-n", type=int, default=128) + parser.add_argument("--block-k", type=int, default=128) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument( + "--save-path", type=str, default="python/sglang/srt/layers/quantization/configs" + ) + args = parser.parse_args() + + main(args) diff --git a/benchmark/lora/launch_server.py b/benchmark/lora/launch_server.py index f139f0df6fe..418155dbf5b 100644 --- a/benchmark/lora/launch_server.py +++ b/benchmark/lora/launch_server.py @@ -1,10 +1,10 @@ import argparse import os -NUM_LORAS = 8 +NUM_LORAS = 4 LORA_PATH = { - "base": "mistralai/Mistral-7B-Instruct-v0.3", - "lora": "/home/ying/test_lora", + "base": "meta-llama/Llama-2-7b-hf", + "lora": "winddude/wizardLM-LlaMA-LoRA-7B", } @@ -21,7 +21,8 @@ def launch_server(args): cmd += f"{lora_name}={lora_path} " cmd += f"--disable-radix --disable-cuda-graph " cmd += f"--max-loras-per-batch {args.max_loras_per_batch} " - cmd += f"--max-running-requests {args.max_running_requests}" + cmd += f"--max-running-requests {args.max_running_requests} " + cmd += f"--lora-backend {args.lora_backend}" print(cmd) os.system(cmd) @@ -42,6 +43,11 @@ def launch_server(args): type=int, default=8, ) + parser.add_argument( + "--lora-backend", + type=str, + default="triton", + ) args = parser.parse_args() launch_server(args) diff --git a/benchmark/lora/lora_bench.py b/benchmark/lora/lora_bench.py index 713cbbf76ca..b5af65a7dd7 100644 --- a/benchmark/lora/lora_bench.py +++ b/benchmark/lora/lora_bench.py @@ -183,6 +183,7 @@ async def benchmark( api_url=api_url, prompt_len=test_prompt_len, output_len=test_output_len, + lora_name="dummy", # the lora_name argument will not be used extra_request_body=extra_request_body, ) test_output = await request_func(request_func_input=test_input) @@ -206,6 +207,7 @@ async def benchmark( api_url=api_url, prompt_len=prompt_len, output_len=output_len, + lora_name="dummy", extra_request_body=extra_request_body, ) tasks.append( @@ -255,6 +257,9 @@ async def benchmark( "Output token throughput (tok/s):", metrics.output_throughput ) ) + print( + "{:<40} {:<10.2f}".format("Total throughput (tok/s):", metrics.total_throughput) + ) print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) print( "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) diff --git a/docker/Dockerfile b/docker/Dockerfile index 1fe702d4014..264397f851b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -30,6 +30,8 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu121; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ @@ -39,22 +41,26 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ && cd sglang \ && if [ "$BUILD_TYPE" = "srt" ]; then \ if [ "$CUDA_VERSION" = "12.1.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ - python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer/; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ else \ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ fi; \ else \ if [ "$CUDA_VERSION" = "12.1.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer/; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ else \ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index f04254e54c9..01bc0137c2d 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,8 +1,8 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.4.2 -t v0.4.2-rocm620 -f Dockerfile.rocm . +# docker build --build-arg SGL_BRANCH=v0.4.2.post2 -t v0.4.2.post2-rocm630 -f Dockerfile.rocm . # default base image -ARG BASE_IMAGE="rocmshared/vllm-rocm:20250114-tuned-elementwise-layernorm" +ARG BASE_IMAGE="rocm/vllm-dev:20250114" FROM $BASE_IMAGE AS base USER root @@ -28,6 +28,9 @@ RUN git clone ${SGL_REPO} \ echo "Using ${SGL_BRANCH} branch."; \ git checkout ${SGL_BRANCH}; \ fi \ + && cd sgl-kernel \ + && python setup_rocm.py install \ + && cd .. \ && if [ "$BUILD_TYPE" = "srt" ]; then \ python -m pip --no-cache-dir install -e "python[srt_hip]"; \ else \ @@ -58,6 +61,7 @@ RUN git clone ${ATER_REPO} \ # Performance environment variable. ENV HIP_FORCE_DEV_KERNARG=1 +ENV HSA_NO_SCRATCH_RECLAIM=1 ENV SGLANG_SET_CPU_AFFINITY=1 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 ENV NCCL_MIN_NCHANNELS=112 diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 7e8f4ca0a54..d6b12b10569 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -124,6 +124,7 @@ Please consult the documentation below to learn more about the parameters you ma * `lora_paths`: You may provide a list of adapters to your model as a list. Each batch element will get model response with the corresponding lora adapter applied. Currently `cuda_graph` and `radix_attention` are not supportet with this option so you need to disable them manually. We are still working on through these [issues](https://github.com/sgl-project/sglang/issues/2929). * `max_loras_per_batch`: Maximum number of LoRAs in a running batch including base model. +* `lora_backend`: The backend of running GEMM kernels for Lora modules, can be one of `triton` or `flashinfer`. Defaults to be `triton`. ## Kernel backend @@ -159,7 +160,7 @@ Please consult the documentation below to learn more about the parameters you ma * `disable_radix_cache`: Disable [Radix](https://lmsys.org/blog/2024-01-17-sglang/) backend for prefix caching. * `disable_jump_forward`: Disable [jump-forward](https://lmsys.org/blog/2024-02-05-compressed-fsm/#our-method-jump-forward-decoding-with-a-compressed-finite-state-machine) for outlines grammar backend. -* `disable_cuda_graph`: Disable [cuda graph](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for model forward. +* `disable_cuda_graph`: Disable [cuda graph](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for model forward. Use if encountering uncorrectable CUDA ECC errors. * `disable_cuda_graph_padding`: Disable cuda graph when padding is needed. In other case still use cuda graph. * `disable_outlines_disk_cache`: Disable disk cache for outlines grammar backend. * `disable_custom_all_reduce`: Disable usage of custom all reduce kernel. diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index d69436eed17..273d943d120 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -8,10 +8,11 @@ "\n", "SGLang now provides an EAGLE-based speculative decoding option. The implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n", "\n", + "**Note:** Currently, Speculative Decoding in SGLang does not support radix cache.\n", + "\n", "To run the following tests or benchmarks, you also need to install [**cutex**](https://pypi.org/project/cutex/): \n", - "> ```bash\n", - "> pip install cutex\n", - "> ```\n", + "\n", + "`pip install cutex`\n", "\n", "### Performance Highlights\n", "\n", diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index 779c413977c..16b442554eb 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post2-rocm630 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post2-rocm630 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/index.rst b/docs/index.rst index aaa46384490..f6f14725fd2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -57,11 +57,13 @@ The core features include: references/sampling_params.md references/hyperparameter_tuning.md references/benchmark_and_profiling.md + references/accuracy_evaluation.md references/custom_chat_template.md references/deepseek.md references/llama_405B.md references/modelscope.md references/contribution_guide.md references/troubleshooting.md + references/nvidia_jetson.md references/faq.md references/learn_more.md diff --git a/docs/references/accuracy_evaluation.md b/docs/references/accuracy_evaluation.md new file mode 100644 index 00000000000..123d1cab08b --- /dev/null +++ b/docs/references/accuracy_evaluation.md @@ -0,0 +1,60 @@ +# Measuring Model Accuracy in SGLang + +This guide shows how to evaluate model accuracy using SGLang's [built-in benchmarks](https://github.com/sgl-project/sglang/tree/b045841baeff37a5601fcde23fa98bd09d942c36/benchmark). Please include accuracy on crucial benchmarks in your PR if you make modifications on the model side, like the kernel and model architecture. + +## Benchmarking Model Accuracy + +This is a reference workflow for the [MMLU benchmark](https://github.com/sgl-project/sglang/tree/main/benchmark/mmlu). For more details or other benchmarks, please refer to the README in each specific benchmark folder under [sglang/benchmark](https://github.com/sgl-project/sglang/tree/b045841baeff37a5601fcde23fa98bd09d942c36/benchmark). + +```bash +# Step 1: Download the dataset +bash download_data.sh + +# Step 2: Launch the server +python3 -m sglang.launch_server \ + --model-path Qwen/Qwen2.5-Math-1.5B-Instruct \ # Model selection + --port 30000 \ # Network configuration + --mem-fraction-static 0.8 # Memory optimization + +# Step 3: Run the benchmark script +python3 bench_sglang.py --nsub 10 # Test 10 subjects + +# Step 4: Extract the accuracy +cat result.jsonl | grep -oP '"accuracy": \K\d+\.\d+' +``` + +## Customizing Benchmark Scripts + +Some benchmark implementations may differ from ours, causing accuracy discrepancies. To match [[Qwen2.5-Math]](https://github.com/QwenLM/Qwen2.5-Math)'s reported 76.8% GSM8K accuracy, customization is required. + +```python +# The GSM8K benchmark script includes few shot examples for evaluation by default. +# Here we exclude them. +for i in range(len(lines[num_shots:num_questions])): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) +``` + +```python +@sgl.function +def few_shot_gsm8k(s, question): + # System prompt given in https://github.com/QwenLM/Qwen2.5-Math + s += sgl.system("Please reason step by step, and put your final answer within \\boxed{}.") # Include system prompt + s += few_shot_examples + question + # Stopwords given in evaluation/math_eval.py of the Qwen2.5-Math repo + s += sgl.gen( + "answer", max_tokens=2048, stop=["Question", "Assistant:", "", "<|im_end|>", "<|endoftext|>"] + ) +``` + +These adjustments should return the desired accuracy. + +## Extending Evaluation Capabilities + +1. **Contribute New Benchmarks** + * Follow our [contribution guidelines](https://docs.sglang.ai/references/contribution_guide.html) to add new test scripts +2. **Request Implementations** + * Feel free to open an issue describing your evaluation needs +3. **Use Alternative Tools** + * [OpenCompass](https://opencompass.org.cn) + * [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) diff --git a/docs/references/benchmark_and_profiling.md b/docs/references/benchmark_and_profiling.md index 0600b192b4f..762cae27671 100644 --- a/docs/references/benchmark_and_profiling.md +++ b/docs/references/benchmark_and_profiling.md @@ -15,8 +15,46 @@ python3 -m sglang.bench_serving --backend sglang --num-prompt 10 ``` +## Profile with PyTorch Profiler +Pytorch Profiler is a convenient basic tool to inspect kernel execution time, call stack, and kernel overlap and occupancy. +- To profile a server +```bash +# set trace path +export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log + +# start server +python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct + +# send profiling request from client +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile +``` +Please make sure that the `SGLANG_TORCH_PROFILER_DIR` should be set at both server and client side, otherwise the trace file cannot be generated correctly . A secure way will be setting `SGLANG_TORCH_PROFILER_DIR` in the `.*rc` file of shell (e.g. `~/.bashrc` for bash shells). + +- To profile offline +```bash +export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log +python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8 +``` + +- View Traces + +Trace files can be loaded and visualized from: +1. https://ui.perfetto.dev/ (any browser) +2. chrome://tracing (Chrome browser only) + +If browser cannot open trace file due to its large size, +client can generate a small trace file (<100MB) by controlling number of prompts and lengths of prompt outputs. +For example, when profiling a server, +```bash +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile +``` +sets the number of prompts to 2 with `--num-prompts` argument and limits the length of output sequences to 100 with `--sharegpt-output-len` argument, which can generate a small trace file for browser to open smoothly. + ## Profile with Nsight -0. Prerequisite +Nsight systems is an advanced tool that exposes more profiling details, such as register and shared memory usage, annotated code regions and low-level CUDA APIs and events. + +0. Prerequisite: install using apt, or run inside a [NVIDIA Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) or [SGLang Docker container](https://github.com/sgl-project/sglang/tree/main/docker). + ```bash # install nsys # https://docs.nvidia.com/nsight-systems/InstallationGuide/index.html @@ -41,12 +79,13 @@ nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node -o sglang.out python3 -m sglang.bench_serving --backend sglang --num-prompts 1000 --dataset-name random --random-input 1024 --random-output 512 ``` -3. Use NVTX, e.g. +3. Use NVTX to annotate code regions, e.g. to see their execution time. ```bash # install nvtx pip install nvtx - +``` +``` python # code snippets import nvtx with nvtx.annotate("description", color="color"): @@ -54,41 +93,7 @@ with nvtx.annotate("description", color="color"): ``` ## Other tips - 1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder. 2. You can benchmark a model with modified configs (e.g., less layers) by using `--json-model-override-args`. For example, you can benchmark a model with only 2 layers and 2 kv heads using `python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch 32 --input-len 256 --output-len 32 --load-format dummy --json-model-override-args '{"num_hidden_layers": 1, "num_key_value_heads": 1}'` - - -## Profile with PyTorch Profiler -- To profile a server -```bash -# set trace path -export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log - -# start server -python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct - -# send profiling request from client -python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile -``` -Please make sure that the `SGLANG_TORCH_PROFILER_DIR` should be set at both server and client side, otherwise the trace file cannot be generated correctly . A secure way will be setting `SGLANG_TORCH_PROFILER_DIR` in the `.*rc` file of shell (e.g. `~/.bashrc` for bash shells). - -- To profile offline -```bash -export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log -python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8 -``` - -- View Traces - -Trace files can be loaded and visualized from: -1. https://ui.perfetto.dev/ (any browser) -2. chrome://tracing (Chrome browser only) - -If browser cannot open trace file due to its large size, -client can generate a small trace file (<100MB) by controlling number of prompts and lengths of prompt outputs. -For example, when profiling a server, -```bash -python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile -``` -sets the number of prompts to 2 with `--num-prompts` argument and limits the length of output sequences to 100 with `--sharegpt-output-len` argument, which can generate a small trace file for browser to open smoothly. +3. You can use `--python-backtrace=cuda` to see python call stack for all CUDA kernels, as in PyTorch Profiler. (Caveat: this can cause inaccurately long kernel runtimes for CUDA event based timing) +4. For more args please see https://docs.nvidia.com/nsight-systems/UserGuide/index.html diff --git a/docs/references/nvidia_jetson.md b/docs/references/nvidia_jetson.md new file mode 100644 index 00000000000..a36a42ba490 --- /dev/null +++ b/docs/references/nvidia_jetson.md @@ -0,0 +1,67 @@ +# Apply SGLang on NVIDIA Jetson Orin + +## Prerequisites + +Before starting, ensure the following: + +- [**NVIDIA Jetson AGX Orin Devkit**](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/jetson-orin/) is set up with **JetPack 6.1** or later. +- **CUDA Toolkit** and **cuDNN** are installed. +- Verify that the Jetson AGX Orin is in **high-performance mode**: + ```bash + sudo nvpmodel -m 0 + ``` +- A custom PyPI index hosted at https://pypi.jetson-ai-lab.dev/jp6/cu126, tailored for NVIDIA Jetson Orin platforms and CUDA 12.6. + +To install torch from this index: + ```bash +pip install torch --index-url https://pypi.jetson-ai-lab.dev/jp6/cu126 + ``` +* * * * * +## Installation +Please refer to [Installation Guide](https://docs.sglang.ai/start/install.html) to install FlashInfer and SGLang. +* * * * * + +Running Inference +----------------------------------------- + +Launch the server: +```bash +python -m sglang.launch_server \ + --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ + --device cuda \ + --dtype half \ + --attention-backend flashinfer \ + --mem-fraction-static 0.8 \ + --context-length 8192 +``` +The quantization and limited context length (`--dtype half --context-length 8192`) are due to the limited computational resources in [Nvidia jetson kit](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/jetson-orin/). A detailed explanation can be found in [Server Arguments](https://docs.sglang.ai/backend/server_arguments.html). + +After launching the engine, refer to [Chat completions](https://docs.sglang.ai/backend/openai_api_completions.html#Usage) to test the usability. +* * * * * +Running quantization with TorchAO +------------------------------------- +TorchAO is suggested to NVIDIA Jetson Orin. +```bash +python -m sglang.launch_server \ + --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \ + --device cuda \ + --dtype bfloat16 \ + --attention-backend flashinfer \ + --mem-fraction-static 0.8 \ + --context-length 8192 \ + --torchao-config int4wo-128 +``` +This enables TorchAO's int4 weight-only quantization with a 128-group size. The usage of `--torchao-config int4wo-128` is also for memory efficiency. + + +* * * * * +Structured output with XGrammar +------------------------------- +Please refer to [SGLang doc structured output](https://docs.sglang.ai/backend/structured_outputs.html). +* * * * * + +Thanks to the support from [shahizat](https://github.com/shahizat). + +References +---------- +- [NVIDIA Jetson AGX Orin Documentation](https://developer.nvidia.com/embedded/jetson-agx-orin) diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 93c4273765d..85de12f9f47 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -2,7 +2,7 @@ ## Generative Models - Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 -- Mistral / Mixtral / Mistral NeMo +- Mistral / Mixtral / Mistral NeMo / Mistral Small 3 - Gemma / Gemma 2 - Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL - DeepSeek / DeepSeek 2 / [DeepSeek 3](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3) diff --git a/docs/start/install.md b/docs/start/install.md index 90964ac6b6c..19b4ab56acb 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -6,7 +6,7 @@ You can install SGLang using any of the methods below. ``` pip install --upgrade pip pip install sgl-kernel --force-reinstall --no-deps -pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ +pip install "sglang[all]>=0.4.2.post2" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ ``` Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. @@ -14,12 +14,12 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/ ## Method 2: From source ``` # Use the last release branch -git clone -b v0.4.2 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2.post2 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip pip install sgl-kernel --force-reinstall --no-deps -pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ +pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ ``` Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. If you meet with issue like **ImportError: cannot import name `_grouped_size_compiled_for_decode_kernels`**, installing FlashInfer with some older version like 0.1.6 instead of the latest version could solve it. @@ -28,11 +28,13 @@ Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: ``` # Use the last release branch -git clone -b v0.4.2 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2.post2 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip -pip install sgl-kernel --force-reinstall --no-deps +cd sgl-kernel +python setup_rocm.py install +cd .. pip install -e "python[all_hip]" ``` @@ -54,7 +56,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.4.2 -t v0.4.2-rocm620 -f Dockerfile.rocm . +docker build --build-arg SGL_BRANCH=v0.4.2.post2 -t v0.4.2.post2-rocm630 -f Dockerfile.rocm . alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -63,11 +65,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.4.2-rocm620 \ + v0.4.2.post2-rocm630 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.4.2-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.2.post2-rocm630 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/examples/runtime/engine/EAGLE_offline_batch_inference.py b/examples/runtime/engine/EAGLE_offline_batch_inference.py index 0885959b3fc..897d50ae2d3 100644 --- a/examples/runtime/engine/EAGLE_offline_batch_inference.py +++ b/examples/runtime/engine/EAGLE_offline_batch_inference.py @@ -21,6 +21,7 @@ def main(): speculative_num_steps=3, speculative_eagle_topk=4, speculative_num_draft_tokens=16, + cuda_graph_max_bs=8, ) outputs = llm.generate(prompts, sampling_params) diff --git a/python/pyproject.toml b/python/pyproject.toml index 11c984f82d7..d71cf0153e3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.4.2" +version = "0.4.2.post2" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" @@ -19,31 +19,29 @@ dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"] runtime_common = [ "aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "modelscope", - "orjson", "outlines>=0.0.44,<0.1.0", - "packaging", "pillow", "prometheus-client>=0.20.0", - "psutil", "pydantic", "python-multipart", - "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop", - "xgrammar>=0.1.10" + "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", + "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2", + "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10" ] srt = [ "sglang[runtime_common]", "cuda-python", - "sgl-kernel>=0.0.3", "torch", "vllm==0.6.4.post1", - "flashinfer==0.1.6" + "sgl-kernel>=0.0.3.post1", "torch", "vllm==0.6.4.post1", + "flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<0.1.0" ] # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20241022, not from public vllm whl -srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post2.dev1"] +srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11", "sgl-kernel>=0.0.3.post1"] # xpu is not enabled in public vllm and torch whl, # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm -srt_xpu = ["sglang[runtime_common]"] +srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<0.1.0"] #For Intel Gaudi(device : hpu) follow the installation guide #https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html -srt_hpu = ["sglang[runtime_common]"] +srt_hpu = ["sglang[runtime_common]", "outlines>=0.0.44,<0.1.0"] # CPU: currently, there are no pre-built vllm wheels for CPU. # To install vllm for CPU, please follow the instruction here: # https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html -srt_cpu = ["sglang[runtime_common]", "torch"] +srt_cpu = ["sglang[runtime_common]", "torch", "outlines>=0.0.44,<0.1.0"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py index 4820d473959..91dbcba24f9 100644 --- a/python/sglang/srt/constrained/outlines_backend.py +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -20,7 +20,6 @@ import interegular import torch from outlines.fsm.guide import RegexGuide -from outlines.fsm.json_schema import build_regex_from_schema from outlines.models.transformers import TransformerTokenizer from pydantic import BaseModel @@ -29,6 +28,15 @@ BaseGrammarObject, ) from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap +from sglang.srt.utils import is_hip + +is_hip_ = is_hip() + +if is_hip_: + from outlines_core.fsm.json_schema import build_regex_from_schema +else: + from outlines.fsm.json_schema import build_regex_from_schema + logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py new file mode 100644 index 00000000000..d770e9c085e --- /dev/null +++ b/python/sglang/srt/custom_op.py @@ -0,0 +1,40 @@ +import torch +from torch import nn + +_is_cuda = torch.cuda.is_available() and torch.version.cuda +_is_rocm = torch.cuda.is_available() and torch.version.hip + + +class CustomOp(nn.Module): + def __init__(self): + super().__init__() + self._forward_method = self.dispatch_forward() + + def forward(self, *args, **kwargs): + return self._forward_method(*args, **kwargs) + + def forward_native(self, *args, **kwargs): + raise NotImplementedError + + def forward_cuda(self, *args, **kwargs): + raise NotImplementedError + + def forward_hip(self, *args, **kwargs): + return self.forward_cuda(*args, **kwargs) + + def forward_xpu(self, *args, **kwargs): + return self.forward_native(*args, **kwargs) + + def forward_hpu(self, *args, **kwargs): + return self.forward_native(*args, **kwargs) + + def forward_cpu(self, *args, **kwargs): + return self.forward_native(*args, **kwargs) + + def dispatch_forward(self): + if _is_cuda: + return self.forward_cuda + elif _is_rocm: + return self.forward_hip + else: + return self.forward_native diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 098a3d1e325..7f01e312cdc 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -316,8 +316,8 @@ def _set_envs_and_config(server_args: ServerArgs): # Check flashinfer version if server_args.attention_backend == "flashinfer": assert_pkg_version( - "flashinfer", - "0.1.6", + "flashinfer_python", + "0.2.0.post2", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index d69d854ab2e..82c39c2acbc 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -25,21 +25,18 @@ if is_cuda_available(): from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul -from vllm.model_executor.custom_op import CustomOp - +from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import set_weight_attrs logger = logging.getLogger(__name__) -@register_custom_op("sglang_silu_and_mul") class SiluAndMul(CustomOp): def forward_native(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 @@ -53,7 +50,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: return out -@register_custom_op("sglang_gelu_and_mul") class GeluAndMul(CustomOp): def __init__(self, approximate="tanh"): super().__init__() @@ -76,6 +72,15 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: return out +class QuickGELU(CustomOp): + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(1.702 * x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel + return self.forward_native(x) + + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index a5e54f32d51..c807e8753fe 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): # Lazy import to avoid the initialization of cuda context from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import ( + extend_attention_fwd, flash_decode_attention_fwd, flash_decode_sparse_attention_fwd, ) - from sglang.srt.layers.attention.triton_ops.extend_attention import ( - extend_attention_fwd, - ) super().__init__() diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 7540515c5fd..1f701f9464f 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -10,6 +10,7 @@ import os from dataclasses import dataclass from enum import Enum, auto +from functools import partial from typing import TYPE_CHECKING, List, Optional, Union import torch @@ -34,6 +35,7 @@ BatchPrefillWithRaggedKVCacheWrapper, ) from flashinfer.cascade import merge_state + from flashinfer.decode import PosEncodingMode class WrapperDispatch(Enum): @@ -53,10 +55,19 @@ class PrefillMetadata: extend_no_prefix: bool +# Reuse this workspace buffer across all flashinfer wrappers +global_workspace_buffer = None + + class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" - def __init__(self, model_runner: ModelRunner): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): super().__init__() # Parse constants @@ -69,6 +80,7 @@ def __init__(self, model_runner: ModelRunner): ), ) self.max_context_len = model_runner.model_config.context_len + self.skip_prefill = skip_prefill assert not ( model_runner.sliding_window_size is not None @@ -90,16 +102,26 @@ def __init__(self, model_runner: ModelRunner): global_config.flashinfer_workspace_size = 512 * 1024 * 1024 # Allocate buffers - self.workspace_buffer = torch.empty( - global_config.flashinfer_workspace_size, - dtype=torch.uint8, - device=model_runner.device, - ) + global global_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.empty( + global_config.flashinfer_workspace_size, + dtype=torch.uint8, + device=model_runner.device, + ) + self.workspace_buffer = global_workspace_buffer max_bs = model_runner.req_to_token_pool.size - self.kv_indptr = [ - torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) - for _ in range(self.num_wrappers) - ] + if kv_indptr_buf is None: + self.kv_indptr = [ + torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + for _ in range(self.num_wrappers) + ] + else: + assert self.num_wrappers == 1 + self.kv_indptr = [kv_indptr_buf] + self.kv_last_page_len = torch.ones( (max_bs,), dtype=torch.int32, device=model_runner.device ) @@ -122,12 +144,17 @@ def __init__(self, model_runner: ModelRunner): self.prefill_wrappers_verify = [] self.decode_wrappers = [] for _ in range(self.num_wrappers): - self.prefill_wrappers_paged.append( - BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") - ) - self.prefill_wrappers_verify.append( - BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") - ) + if not skip_prefill: + self.prefill_wrappers_paged.append( + BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + backend="fa2", + ) + ) + self.prefill_wrappers_verify.append( + BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") + ) self.decode_wrappers.append( BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, @@ -137,10 +164,11 @@ def __init__(self, model_runner: ModelRunner): ) # Create indices updater + if not skip_prefill: + self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill( + model_runner, self + ) self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self) - self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill( - model_runner, self - ) # Other metadata self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None @@ -211,23 +239,30 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.prefill_wrappers_paged, use_ragged, extend_no_prefix ) - def init_cuda_graph_state(self, max_bs: int): - cuda_graph_kv_indices = torch.zeros( - (max_bs * self.max_context_len,), - dtype=torch.int32, - device="cuda", - ) + def init_cuda_graph_state( + self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None + ): + if kv_indices_buf is None: + cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len,), + dtype=torch.int32, + device="cuda", + ) + else: + cuda_graph_kv_indices = kv_indices_buf + self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [ cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1) ] - self.cuda_graph_custom_mask = torch.zeros( - (max_bs * self.max_context_len), - dtype=torch.uint8, - device="cuda", - ) - self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] - self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] + if not self.skip_prefill: + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.uint8, + device="cuda", + ) + self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] + self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] def init_forward_metadata_capture_cuda_graph( self, @@ -279,7 +314,7 @@ def init_forward_metadata_capture_cuda_graph( paged_kv_indices_buf=self.cuda_graph_kv_indices[i], paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], custom_mask_buf=self.cuda_graph_custom_mask, - qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], + mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], ) ) seq_lens_sum = seq_lens.sum().item() @@ -602,11 +637,8 @@ def call_begin_forward( self.req_to_token.shape[1], ) else: - bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode( - req_pool_indices, - paged_kernel_lens, - self.req_to_token, - ) + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 wrapper.end_forward() wrapper.begin_forward( @@ -800,7 +832,9 @@ def call_begin_forward( kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + paged_kernel_lens_sum + 256, + dtype=torch.int32, + device=req_pool_indices.device, ) create_flashinfer_kv_indices_triton[(bs,)]( self.req_to_token, @@ -852,6 +886,132 @@ def call_begin_forward( ) +class FlashInferMultiStepDraftBackend: + """ + Wrap multiple flashinfer attention backends as one for multiple consecutive + draft decoding steps. + """ + + def __init__( + self, + model_runner: ModelRunner, + topk: int, + speculative_num_steps: int, + ): + from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices + + self.topk = topk + self.speculative_num_steps = speculative_num_steps + self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices + max_bs = model_runner.req_to_token_pool.size + self.kv_indptr = torch.zeros( + ( + self.speculative_num_steps, + max_bs + 1, + ), + dtype=torch.int32, + device=model_runner.device, + ) + self.attn_backends = [] + for i in range(self.speculative_num_steps): + self.attn_backends.append( + FlashInferAttnBackend( + model_runner, + skip_prefill=True, + kv_indptr_buf=self.kv_indptr[i], + ) + ) + self.max_context_len = self.attn_backends[0].max_context_len + # Cached variables for generate_draft_decode_kv_indices + self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1] + self.kv_indptr_stride = self.kv_indptr.shape[1] + + def common_template(self, forward_batch: ForwardBatch, call_fn: int): + num_seqs = forward_batch.batch_size + bs = self.topk * num_seqs + seq_lens_sum = forward_batch.seq_lens_sum + self.generate_draft_decode_kv_indices[ + (self.speculative_num_steps, num_seqs, self.topk) + ]( + forward_batch.req_pool_indices, + forward_batch.req_to_token_pool.req_to_token, + forward_batch.seq_lens, + self.cuda_graph_kv_indices, + self.kv_indptr, + forward_batch.positions, + num_seqs, + self.topk, + self.pool_len, + self.kv_indptr_stride, + self.kv_indptr.shape[1], + triton.next_power_of_2(num_seqs), + triton.next_power_of_2(self.speculative_num_steps), + triton.next_power_of_2(bs), + ) + for i in range(self.speculative_num_steps): + forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1] + forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][ + : seq_lens_sum * self.topk + bs * (i + 1) + ] + call_fn(i, forward_batch) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + def call_fn(i, forward_batch): + forward_batch.spec_info.kv_indptr = ( + forward_batch.spec_info.kv_indptr.clone() + ) + forward_batch.spec_info.kv_indices = ( + forward_batch.spec_info.kv_indices.clone() + ) + self.attn_backends[i].init_forward_metadata(forward_batch) + + self.common_template(forward_batch, call_fn) + + def init_cuda_graph_state(self, max_bs: int): + self.cuda_graph_kv_indices = torch.zeros( + (self.speculative_num_steps, max_bs * self.max_context_len), + dtype=torch.int32, + device="cuda", + ) + self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1] + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_cuda_graph_state( + max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i] + ) + + def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.batch_size * self.topk, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[ + forward_batch.batch_size + ][0] + decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper) + + self.common_template(forward_batch, call_fn) + + def init_forward_metadata_replay_cuda_graph(self, forward_batch): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_replay_cuda_graph( + forward_batch.batch_size, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + seq_lens_sum=-1, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + + self.common_template(forward_batch, call_fn) + + @triton.jit def create_flashinfer_kv_indices_triton( req_to_token_ptr, # [max_batch, max_context_len] @@ -935,3 +1095,88 @@ def should_use_tensor_core( return gqa_group_size > 4 else: return False + + +def fast_decode_plan( + self, + indptr: torch.Tensor, + indices: torch.Tensor, + last_page_len: torch.Tensor, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + pos_encoding_mode: str = "NONE", + window_left: int = -1, + logits_soft_cap: Optional[float] = None, + data_type: Union[str, torch.dtype] = "float16", + q_data_type: Optional[Union[str, torch.dtype]] = None, + sm_scale: Optional[float] = None, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +) -> None: + """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.""" + batch_size = len(last_page_len) + if logits_soft_cap is None: + logits_soft_cap = 0.0 + if self.is_cuda_graph_enabled: + if batch_size != self._fixed_batch_size: + raise ValueError( + "The batch size should be fixed in cudagraph mode, the runtime batch size {} " + " mismatches the batch size set during initialization {}".format( + batch_size, self._fixed_batch_size + ) + ) + if len(indices) > len(self._paged_kv_indices_buf): + raise ValueError( + "The size of indices should be less than or equal to the allocated buffer" + ) + else: + self._paged_kv_indptr_buf = indptr + self._paged_kv_indices_buf = indices + self._paged_kv_last_page_len_buf = last_page_len + # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info + if not q_data_type: + q_data_type = data_type + if not hasattr(self, "empty_q_data"): + self.empty_q_data = torch.empty( + 0, + dtype=( + getattr(torch, q_data_type) + if isinstance(q_data_type, str) + else q_data_type + ), + ) + self.empty_kv_cache = torch.empty( + 0, + dtype=( + getattr(torch, data_type) if isinstance(data_type, str) else data_type + ), + ) + self.last_page_len = torch.ones(32768, dtype=torch.int32) + empty_q_data = self.empty_q_data + empty_kv_cache = self.empty_kv_cache + stream = torch.cuda.current_stream() + self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + indptr.to("cpu"), + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + window_left, + logits_soft_cap, + head_dim, + empty_q_data, + empty_kv_cache, + stream.cuda_stream, + ) + self._pos_encoding_mode = pos_encoding_mode + self._window_left = window_left + self._logits_soft_cap = logits_soft_cap + self._sm_scale = sm_scale + self._rope_scale = rope_scale + self._rope_theta = rope_theta diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index fade8ed292d..3475df72192 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -5,6 +5,9 @@ import torch from sglang.srt.layers.attention import AttentionBackend +from sglang.srt.layers.attention.flashinfer_backend import ( + create_flashinfer_kv_indices_triton, +) from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -29,6 +32,15 @@ def __init__(self, model_runner: ModelRunner): self.decode_attention_fwd = decode_attention_fwd self.extend_attention_fwd = extend_attention_fwd + max_bs = model_runner.req_to_token_pool.size + self.kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.qo_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + self.num_head = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() ) @@ -45,6 +57,9 @@ def __init__(self, model_runner: ModelRunner): def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" + bs = forward_batch.batch_size + kv_indptr = self.kv_indptr + if forward_batch.forward_mode.is_decode(): attn_logits = torch.empty( ( @@ -58,11 +73,60 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) max_extend_len = None + + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = None + custom_mask = None else: + kv_indptr[1 : bs + 1] = torch.cumsum( + forward_batch.extend_prefix_lens, dim=0 + ) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.extend_prefix_lens.sum().item(), + dtype=torch.int32, + device=self.device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.extend_prefix_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.qo_indptr + qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + custom_mask = None + attn_logits = None max_extend_len = torch.max(forward_batch.extend_seq_lens).item() - self.forward_metadata = attn_logits, max_extend_len + self.forward_metadata = ( + attn_logits, + max_extend_len, + kv_indptr, + kv_indices, + qo_indptr, + custom_mask, + ) def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len @@ -73,7 +137,12 @@ def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_attn_logits = torch.empty( (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1), dtype=torch.float32, - device="cuda", + device=self.device, + ) + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.cuda_graph_max_seq_len), + dtype=torch.int32, + device=self.device, ) def init_forward_metadata_capture_cuda_graph( @@ -90,9 +159,27 @@ def init_forward_metadata_capture_cuda_graph( assert forward_mode.is_decode(), "Not supported" assert spec_info is None, "Not supported" + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + self.forward_metadata = ( self.cuda_graph_attn_logits, None, + kv_indptr, + kv_indices, + None, + None, ) def init_forward_metadata_replay_cuda_graph( @@ -109,6 +196,20 @@ def init_forward_metadata_replay_cuda_graph( self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens[:bs], + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + def get_cuda_graph_seq_len_fill_value(self): return 1 @@ -132,7 +233,9 @@ def forward_extend( layer, forward_batch.out_cache_loc, k, v ) - _, max_extend_len = self.forward_metadata + _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = ( + self.forward_metadata + ) self.extend_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k.contiguous(), @@ -140,11 +243,9 @@ def forward_extend( o.view(-1, layer.tp_q_head_num, layer.v_head_dim), forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), - forward_batch.req_to_token_pool.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.extend_seq_lens, - forward_batch.extend_start_loc, + qo_indptr, + kv_indptr, + kv_indices, max_extend_len, layer.scaling, layer.logit_cap, @@ -170,7 +271,7 @@ def forward_decode( else: o = torch.empty_like(q) - attn_logits, _ = self.forward_metadata + attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( @@ -182,9 +283,8 @@ def forward_decode( forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - forward_batch.req_to_token_pool.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, + kv_indptr, + kv_indices, attn_logits, self.num_kv_splits, layer.scaling, diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 2b4871af98c..f2274322c52 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -49,11 +49,9 @@ def _fwd_kernel_stage1( K_Buffer, V_Buffer, sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, Att_Out, - stride_req_to_tokens_b, stride_qbs, stride_qh, stride_buf_kbs, @@ -82,8 +80,9 @@ def _fwd_kernel_stage1( offs_dv = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lk mask_dv = offs_dv < Lv - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d q = tl.load(Q + off_q, mask=mask_d, other=0.0) @@ -100,7 +99,7 @@ def _fwd_kernel_stage1( for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + kv_indices + cur_batch_kv_start_idx + offs_n, mask=offs_n < split_kv_end, other=0, ) @@ -173,19 +172,21 @@ def _decode_att_m_fwd( k_buffer, v_buffer, att_out, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, num_kv_splits, sm_scale, logit_cap, ): BLOCK = 64 + # [TODO] work around SGPR limit on MI3xx + if is_hip_: + BLOCK = 8 NUM_KV_SPLITS = num_kv_splits Lk = k_buffer.shape[-1] Lv = v_buffer.shape[-1] - batch, head_num = B_req_idx.shape[0], q.shape[1] + batch, head_num = kv_indptr.shape[0] - 1, q.shape[1] grid = (batch, head_num, NUM_KV_SPLITS) kv_group_num = q.shape[1] // k_buffer.shape[1] @@ -194,6 +195,8 @@ def _decode_att_m_fwd( num_warps = 4 else: num_warps = 2 + if is_hip_: + num_warps = 1 BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DV = triton.next_power_of_2(Lv) @@ -203,11 +206,9 @@ def _decode_att_m_fwd( k_buffer, v_buffer, sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, att_out, - Req_to_tokens.stride(0), q.stride(0), q.stride(1), k_buffer.stride(0), @@ -236,11 +237,9 @@ def _fwd_grouped_kernel_stage1( K_Buffer, V_Buffer, sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, Att_Out, - stride_req_to_tokens_b, stride_qbs, stride_qh, stride_buf_kbs, @@ -279,8 +278,9 @@ def _fwd_grouped_kernel_stage1( offs_dv = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lk mask_dv = offs_dv < Lv - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) @@ -307,7 +307,7 @@ def _fwd_grouped_kernel_stage1( for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + kv_indices + cur_batch_kv_start_idx + offs_n, mask=offs_n < split_kv_end, other=0, ) @@ -395,9 +395,8 @@ def _decode_grouped_att_m_fwd( k_buffer, v_buffer, att_out, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, num_kv_splits, sm_scale, logit_cap, @@ -421,7 +420,7 @@ def _decode_grouped_att_m_fwd( BLOCK_DPE = 0 BLOCK_DV = triton.next_power_of_2(Lv) - batch, head_num = B_req_idx.shape[0], q.shape[1] + batch, head_num = kv_indptr.shape[0] - 1, q.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[1] BLOCK_H = 16 @@ -433,21 +432,21 @@ def _decode_grouped_att_m_fwd( ) extra_kargs = {} + num_stages = 2 if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py - extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} + extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} + num_stages = 1 _fwd_grouped_kernel_stage1[grid]( q, k_buffer, v_buffer, sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, att_out, - Req_to_tokens.stride(0), q.stride(0), q.stride(1), k_buffer.stride(0), @@ -467,7 +466,7 @@ def _decode_grouped_att_m_fwd( NUM_KV_SPLITS=NUM_KV_SPLITS, logit_cap=logit_cap, num_warps=4, - num_stages=2, + num_stages=num_stages, Lk=Lk, Lv=Lv, **extra_kargs, @@ -478,7 +477,7 @@ def _decode_grouped_att_m_fwd( def _fwd_kernel_stage2( Mid_O, O, - B_Seqlen, + kv_indptr, stride_mid_ob, stride_mid_oh, stride_mid_os, @@ -491,7 +490,9 @@ def _fwd_kernel_stage2( cur_batch = tl.program_id(0) cur_head = tl.program_id(1) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load( + kv_indptr + cur_batch + ) offs_d = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lv @@ -535,7 +536,7 @@ def _decode_softmax_reducev_fwd( q, o, v_buffer, - b_seq_len, + kv_indptr, num_kv_splits, ): batch, head_num = q.shape[0], q.shape[1] @@ -554,7 +555,7 @@ def _decode_softmax_reducev_fwd( _fwd_kernel_stage2[grid]( logits, o, - b_seq_len, + kv_indptr, logits.stride(0), logits.stride(1), logits.stride(2), @@ -574,9 +575,8 @@ def decode_attention_fwd_normal( k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -587,14 +587,13 @@ def decode_attention_fwd_normal( k_buffer, v_buffer, attn_logits, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, num_kv_splits, sm_scale, logit_cap, ) - _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits) def decode_attention_fwd_grouped( @@ -602,9 +601,8 @@ def decode_attention_fwd_grouped( k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -615,14 +613,13 @@ def decode_attention_fwd_grouped( k_buffer, v_buffer, attn_logits, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, num_kv_splits, sm_scale, logit_cap, ) - _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits) def decode_attention_fwd( @@ -630,9 +627,8 @@ def decode_attention_fwd( k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -648,9 +644,8 @@ def decode_attention_fwd( k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -663,9 +658,8 @@ def decode_attention_fwd( k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, diff --git a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py index 79e148e9c9c..db0fb6b4dbd 100644 --- a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py @@ -3,6 +3,13 @@ import triton.language as tl from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.utils import is_hip + +is_cuda_available = torch.cuda.is_available() +if is_cuda_available: + CUDA_CAPABILITY = torch.cuda.get_device_capability() + +is_hip_ = is_hip() if global_server_args_dict.get("attention_reduce_in_fp32", False): REDUCE_TRITON_TYPE = tl.float32 @@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq): return -import torch - - def flash_decode_attention_fwd( q, k_buffer, @@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd( ) sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ) + + +# Extend attention kernel for Double Sparsity +# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +@triton.jit +def _fwd_kernel( + Q_Extend, + K_Extend, + V_Extend, + O_Extend, + K_Buffer, + V_Buffer, + Req_to_tokens, + B_req_idx, + B_Seq_Len, + B_Start_Loc_Extend, + B_Seq_Len_Extend, + sm_scale, + kv_group_num, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_req_to_tokens_b, + logit_cap: tl.constexpr, + Lq: tl.constexpr, + Lv: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_seq = tl.program_id(0) + cur_head = tl.program_id(1) + cur_block_m = tl.program_id(2) + cur_kv_head = cur_head // kv_group_num + + cur_seq_len = tl.load(B_Seq_Len + cur_seq) + cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq) + cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend + + cur_seq_prefix_start_in_loc = 0 + cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq) + cur_batch_req_idx = tl.load(B_req_idx + cur_seq) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + offs_m = tl.arange(0, BLOCK_M) + mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend + + mask_d = offs_d < Lq + mask_dv = offs_dv < Lv + + offs_q = ( + (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] + ) + q = tl.load( + Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0 + ) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + offs_qpe = ( + (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_qbs + + cur_head * stride_qh + + offs_dpe[None, :] + ) + qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0) + + # stage 1: compute scores with prefix + offs_n = tl.arange(0, BLOCK_N) + + acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) + deno = tl.zeros([BLOCK_M], dtype=tl.float32) + e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + + for start_n in range(0, cur_seq_len_prefix, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + mask_n = (start_n + offs_n) < cur_seq_len_prefix + offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + ( + cur_seq_prefix_start_in_loc + start_n + offs_n + ) + offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0) + + # load k in transposed way + offs_buf_k = ( + offs_kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) + k = tl.load( + K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) + + qk = tl.dot(q.to(k.dtype), k) + if BLOCK_DPE > 0: + offs_kpe = ( + offs_kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_kpe, + mask=mask_n[None, :], + other=0.0, + ) + qk += tl.dot(qpe.to(kpe.dtype), kpe) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf")) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + deno = deno * re_scale + tl.sum(p, 1) + + offs_buf_v = ( + offs_kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) + p = p.to(v.dtype) + acc = acc * re_scale[:, None] + tl.dot(p, v) + + e_max = n_e_max + + # stage 2: compute the trianlge part + + cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M) + for start_n in range(0, cur_block_m_end, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + mask_n = (start_n + offs_n) < cur_block_m_end + + # load k in transposed way + offs_k = ( + (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] + ) + k = tl.load( + K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) + + qk = tl.dot(q, k, out_dtype=tl.float32) + if BLOCK_DPE > 0: + offs_kpe = ( + (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) + * stride_kbs + + cur_kv_head * stride_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Extend + offs_kpe, + mask=mask_n[None, :], + other=0.0, + ) + qk += tl.dot(qpe, kpe) + + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( + start_n + offs_n[None, :] + ) + mask_causual &= mask_m[:, None] & mask_n[None, :] + qk = tl.where(mask_causual, qk, float("-inf")) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + deno = deno * re_scale + tl.sum(p, 1) + + offs_v = ( + (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs + + cur_kv_head * stride_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) + p = p.to(v.dtype) + acc = acc * re_scale[:, None] + tl.dot(p, v) + + e_max = n_e_max + + offs_o = ( + (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_obs + + cur_head * stride_oh + + offs_dv[None, :] + ) + tl.store( + O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :] + ) + + +def extend_attention_fwd( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_seq_len, + b_seq_len_extend, + b_start_loc_extend, + max_len_extend, + sm_scale=None, + logit_cap=0.0, +): + """ + q_extend, k_extend, v_extend, o_extend: contiguous tensors + + k_buffer, v_buffer: (prefix + extend) tensors in mem_manager + """ + Lq, Lk, Lv = ( + q_extend.shape[-1], + k_extend.shape[-1], + v_extend.shape[-1], + ) + + if Lq == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lq == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + elif Lq == 192: + BLOCK_DMODEL = 128 + BLOCK_DPE = 64 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lq) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + if is_hip_: + BLOCK_M, BLOCK_N = (64, 64) + num_warps = 4 + + else: + if is_cuda_available and CUDA_CAPABILITY[0] >= 9: + if Lq <= 256: + BLOCK_M, BLOCK_N = (128, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) + elif is_cuda_available and CUDA_CAPABILITY[0] >= 8: + if Lq <= 128: + BLOCK_M, BLOCK_N = (128, 128) + elif Lq <= 256: + BLOCK_M, BLOCK_N = (64, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) + else: + BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) + + num_warps = 4 if Lk <= 64 else 8 + + sm_scale = sm_scale or 1.0 / (Lq**0.5) + batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] + kv_group_num = q_extend.shape[1] // k_extend.shape[1] + + grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) + num_stages = 1 + + extra_kargs = {} + if is_hip_: + extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} + + _fwd_kernel[grid]( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_seq_len, + b_start_loc_extend, + b_seq_len_extend, + sm_scale, + kv_group_num, + q_extend.stride(0), + q_extend.stride(1), + k_extend.stride(0), + k_extend.stride(1), + v_extend.stride(0), + v_extend.stride(1), + o_extend.stride(0), + o_extend.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), + req_to_tokens.stride(0), + logit_cap=logit_cap, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + Lq=Lq, + Lv=Lv, + num_warps=num_warps, + num_stages=num_stages, + **extra_kargs, + ) diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index b2654f1f780..6c9976931d0 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -46,11 +46,9 @@ def _fwd_kernel( O_Extend, K_Buffer, V_Buffer, - Req_to_tokens, - B_req_idx, - B_Seq_Len, - B_Start_Loc_Extend, - B_Seq_Len_Extend, + qo_indptr, + kv_indptr, + kv_indices, sm_scale, kv_group_num, stride_qbs, @@ -65,7 +63,6 @@ def _fwd_kernel( stride_buf_kh, stride_buf_vbs, stride_buf_vh, - stride_req_to_tokens_b, logit_cap: tl.constexpr, Lq: tl.constexpr, Lv: tl.constexpr, @@ -80,13 +77,10 @@ def _fwd_kernel( cur_block_m = tl.program_id(2) cur_kv_head = cur_head // kv_group_num - cur_seq_len = tl.load(B_Seq_Len + cur_seq) - cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq) - cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend - - cur_seq_prefix_start_in_loc = 0 - cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq) - cur_batch_req_idx = tl.load(B_req_idx + cur_seq) + cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq) + cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx + cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq) + cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx offs_d = tl.arange(0, BLOCK_DMODEL) offs_dv = tl.arange(0, BLOCK_DV) @@ -97,7 +91,7 @@ def _fwd_kernel( mask_dv = offs_dv < Lv offs_q = ( - (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] @@ -109,7 +103,7 @@ def _fwd_kernel( if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) offs_qpe = ( - (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_dpe[None, :] @@ -126,10 +120,9 @@ def _fwd_kernel( for start_n in range(0, cur_seq_len_prefix, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) mask_n = (start_n + offs_n) < cur_seq_len_prefix - offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + ( - cur_seq_prefix_start_in_loc + start_n + offs_n + offs_kv_loc = tl.load( + kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0 ) - offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0) # load k in transposed way offs_buf_k = ( @@ -188,7 +181,7 @@ def _fwd_kernel( # load k in transposed way offs_k = ( - (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs + (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] ) @@ -199,8 +192,7 @@ def _fwd_kernel( qk = tl.dot(q, k, out_dtype=tl.float32) if BLOCK_DPE > 0: offs_kpe = ( - (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) - * stride_kbs + (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs + cur_kv_head * stride_kh + offs_dpe[:, None] ) @@ -228,7 +220,7 @@ def _fwd_kernel( deno = deno * re_scale + tl.sum(p, 1) offs_v = ( - (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs + (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs + cur_kv_head * stride_vh + offs_dv[None, :] ) @@ -241,7 +233,7 @@ def _fwd_kernel( e_max = n_e_max offs_o = ( - (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_dv[None, :] @@ -258,11 +250,9 @@ def extend_attention_fwd( o_extend, k_buffer, v_buffer, - req_to_tokens, - b_req_idx, - b_seq_len, - b_seq_len_extend, - b_start_loc_extend, + qo_indptr, + kv_indptr, + kv_indices, max_len_extend, sm_scale=None, logit_cap=0.0, @@ -315,7 +305,7 @@ def extend_attention_fwd( num_warps = 4 if Lk <= 64 else 8 sm_scale = sm_scale or 1.0 / (Lq**0.5) - batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] + batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1] kv_group_num = q_extend.shape[1] // k_extend.shape[1] grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) @@ -332,11 +322,9 @@ def extend_attention_fwd( o_extend, k_buffer, v_buffer, - req_to_tokens, - b_req_idx, - b_seq_len, - b_start_loc_extend, - b_seq_len_extend, + qo_indptr, + kv_indptr, + kv_indices, sm_scale, kv_group_num, q_extend.stride(0), @@ -351,7 +339,6 @@ def extend_attention_fwd( k_buffer.stride(1), v_buffer.stride(0), v_buffer.stride(1), - req_to_tokens.stride(0), logit_cap=logit_cap, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, diff --git a/python/sglang/srt/layers/custom_op_util.py b/python/sglang/srt/layers/custom_op_util.py deleted file mode 100644 index 92e186cd207..00000000000 --- a/python/sglang/srt/layers/custom_op_util.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from vllm.model_executor.custom_op import CustomOp - - -def register_custom_op(op_name): - def decorator(cls): - if hasattr(CustomOp, "register"): - return CustomOp.register(op_name)(cls) - else: - return cls - - return decorator diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 207ba8d1b7a..e3b23a2a926 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -29,14 +29,11 @@ rmsnorm, ) -from vllm.model_executor.custom_op import CustomOp - -from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.custom_op import CustomOp logger = logging.getLogger(__name__) -@register_custom_op("sglang_rmsnorm") class RMSNorm(CustomOp): def __init__( self, @@ -79,7 +76,6 @@ def forward_native( return x, residual -@register_custom_op("sglang_gemma_rmsnorm") class GemmaRMSNorm(CustomOp): def __init__( self, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index bc927621a84..4d6040646b3 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -4,13 +4,12 @@ import torch from torch.nn import Module from vllm import _custom_ops as ops -from vllm.model_executor.custom_op import CustomOp +from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.moe.ep_moe.kernels import ( grouped_gemm_triton, post_reorder_triton_kernel, @@ -407,7 +406,6 @@ def _load_fp8_scale( param_data[expert_id] = loaded_weight -@register_custom_op("sglang_unquantized_ep_moe") class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): def create_weights( self, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..a7be90051f8 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..2840e9f4727 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 00000000000..6a976788f9b --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 00000000000..0a46390b2e3 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 00000000000..91011e64c7d --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..bb17743b609 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json @@ -0,0 +1,178 @@ +{ + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16384": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32768": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "65536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "131072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 00000000000..f807d4a5aba --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..92c41a28bee --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json @@ -0,0 +1,175 @@ +{ + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16384": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32768": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "65536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "131072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 32c8fcbb625..fab71809b1b 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -15,18 +15,10 @@ from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from sglang.srt.utils import ( - direct_register_custom_op, - get_device_name, - is_cuda_available, - is_hip, -) +from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip -is_cuda = is_cuda_available() is_hip_flag = is_hip() -if is_cuda: - from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size - +from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size logger = logging.getLogger(__name__) padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 @@ -415,7 +407,7 @@ def moe_align_block_size( ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) if num_experts >= 224: - if enable_moe_align_block_size_triton or is_hip_flag: + if enable_moe_align_block_size_triton: moe_align_block_size_triton( topk_ids, num_experts, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index b71a878a0ba..dc7152da934 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -5,14 +5,13 @@ from typing import Callable, List, Optional, Tuple import torch -from vllm.model_executor.custom_op import CustomOp +from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.moe.fused_moe_native import moe_forward_native from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import ( @@ -67,7 +66,6 @@ def apply( raise NotImplementedError -@register_custom_op("sglang_unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 527a7d499b6..dc53e4445db 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -17,6 +17,8 @@ import torch import torch.nn.functional as F +from sglang.srt.utils import get_compiler_backend + def fused_topk_native( hidden_states: torch.Tensor, @@ -74,6 +76,7 @@ def fused_topk( # This is used by the Deepseek-V2 model +@torch.compile(dynamic=True, backend=get_compiler_backend()) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -108,6 +111,7 @@ def grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def biased_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..c098ef2dbb9 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..77ba0d7477b --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..6f5adbb9361 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..0a5d7bfdba4 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..cb91a279d42 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..7febe3d272b --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..4225c78eb72 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..9d7658bfc41 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..5e6789d00e0 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..03dba5ad15b --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..49ac14d2a57 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..9a5ff48b894 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..dcbb0efc53e --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..386928de139 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..9c908e80406 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..f78e7060e68 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..dfe5c1e43d6 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..1d3ce5c94c2 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..a87f5de1b18 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..3ab5796ee15 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..468f9e78da0 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..3cb7eaa07c7 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index b0b5b8952a1..f5a0005a282 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -290,6 +290,13 @@ def process_weights_after_loading(self, layer: Module) -> None: weight_scale, requires_grad=False ) layer.input_scale = None + else: + layer.weight = torch.nn.Parameter( + layer.weight.data, requires_grad=False + ) + layer.weight_scale_inv = torch.nn.Parameter( + layer.weight_scale_inv.data, requires_grad=False + ) return layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) # If checkpoint not serialized fp8, quantize the weights. diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index fe57838e591..ddd614fdfd9 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -from sglang.srt.utils import get_device_name, is_hip +from sglang.srt.utils import get_device_core_count, get_device_name, is_hip is_hip_ = is_hip() fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn @@ -220,6 +220,132 @@ def _w8a8_block_fp8_matmul( tl.store(c_ptrs, c, mask=c_mask) +@triton.jit +def _w8a8_block_fp8_matmul_unrolledx4( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and store the result in output + tensor `C`. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # manually unroll to 4 iterations + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K) // 4): + # 1st iteration + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # 2nd iteration + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k_start + BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # 3rd iteration + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k_start + BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # 4th iteration + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k_start + BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + @functools.lru_cache def get_w8a8_block_fp8_configs( N: int, K: int, block_n: int, block_k: int @@ -324,7 +450,19 @@ def grid(META): triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - _w8a8_block_fp8_matmul[grid]( + # Use manually unrolledx4 kernel on AMD GPU when the grid size is small. + # Empirical testing shows the sweet spot lies when it's less than the # of + # compute units available on the device. + num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv( + N, config["BLOCK_SIZE_N"] + ) + kernel = ( + _w8a8_block_fp8_matmul_unrolledx4 + if (is_hip_ == True and num_workgroups <= get_device_core_count()) + else _w8a8_block_fp8_matmul + ) + + kernel[grid]( A, B, C, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 7093bb90d81..ef8a96c9854 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -7,9 +7,8 @@ import torch import torch.nn as nn from vllm import _custom_ops as ops -from vllm.model_executor.custom_op import CustomOp -from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.custom_op import CustomOp from sglang.srt.utils import is_cuda_available _is_cuda_available = is_cuda_available() @@ -59,7 +58,6 @@ def _apply_rotary_emb( return torch.stack((o1, o2), dim=-1).flatten(-2) -@register_custom_op("sglang_rotary_embedding") class RotaryEmbedding(CustomOp): """Original rotary positional embedding.""" diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index b24bfc8dacf..181aadeaa73 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -72,9 +72,11 @@ def forward( # NOTE: the top_p_renorm_prob from flashinfer has numerical problems, # https://github.com/flashinfer-ai/flashinfer/issues/708 # so we use the torch implementation. + + # clamp to avoid -inf logprobs = torch.log( top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ) + ).clamp(min=torch.finfo(probs.dtype).min) max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( @@ -83,7 +85,7 @@ def forward( if sampling_info.need_min_p_sampling: probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_p_renorm_prob(probs, sampling_info.top_ps) - batch_next_token_ids, success = min_p_sampling_from_probs( + batch_next_token_ids = min_p_sampling_from_probs( probs, uniform_samples, sampling_info.min_ps ) else: @@ -95,9 +97,9 @@ def forward( filter_apply_order="joint", ) - if self.use_nan_detectioin and not torch.all(success): - logger.warning("Detected errors during sampling!") - batch_next_token_ids = torch.zeros_like(batch_next_token_ids) + if self.use_nan_detectioin and not torch.all(success): + logger.warning("Detected errors during sampling!") + batch_next_token_ids = torch.zeros_like(batch_next_token_ids) elif global_server_args_dict["sampling_backend"] == "pytorch": # A slower fallback implementation with torch native operations. @@ -109,9 +111,10 @@ def forward( sampling_info.need_min_p_sampling, ) if return_logprob: + # clamp to avoid -inf logprobs = torch.log( top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ) + ).clamp(min=torch.finfo(probs.dtype).min) else: raise ValueError( f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" diff --git a/python/sglang/srt/lora/backend/__init__.py b/python/sglang/srt/lora/backend/__init__.py new file mode 100644 index 00000000000..ed377b4b4ad --- /dev/null +++ b/python/sglang/srt/lora/backend/__init__.py @@ -0,0 +1,8 @@ +from .base_backend import BaseLoraBackend +from .flashinfer_backend import FlashInferLoraBackend +from .triton_backend import TritonLoraBackend + +__all__ = [ + "FlashInferLoraBackend", + "TritonLoraBackend", +] diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py new file mode 100644 index 00000000000..d6c72a14e73 --- /dev/null +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -0,0 +1,95 @@ +from typing import Tuple, Union + +import torch + +from sglang.srt.lora.lora import LoraBatchInfo + + +def get_fuse_output_scaling_add_from_name(name: str) -> bool: + mapping = { + "triton": True, + "flashinfer": False, + } + return mapping.get(name, False) + + +def get_fuse_qkv_lora_b_from_name(name: str) -> bool: + mapping = { + "triton": True, + "flashinfer": False, + } + return mapping.get(name, False) + + +class BaseLoraBackend: + """Base class for different Lora backends. + Each backend has its own implementation of Lora kernels. + + Args: + name: name of backend + batch_info: information of current batch for use + fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward, + and the operation of scaling and adding will be fused into kernel + """ + + def __init__(self, name: str, batch_info: LoraBatchInfo = None): + self.name = name + self.batch_info = batch_info + self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name) + self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name) + + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + """Run segment Gemm of lora a modules with current backend. + The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html. + + Args: + x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths + weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank + usually input_dim is much larger than r + Returns: + result with shape (s, r) + """ + pass + + def run_lora_b_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + """Run segment Gemm of lora b modules with current backend. + The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html. + + Args: + x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank + weights: a set of lora weights with shape (num_lora, output_dim, r) + usually output_dim is much larger than r + Returns: + result with shape (s, output_dim) + """ + pass + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]], + *args, + **kwargs + ) -> torch.Tensor: + """Run the lora pass for QKV Layer. + + Args: + x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths + qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim) + qkv_lora_b: lora_b module for qkv. + If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r) + If passed in as a tuple of two tensors containing: + a lora_b module for q, with shape (1, num_lora, output_dim_q, r) + and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r) + Returns: + result with shape (s, output_dim_q + 2 * output_dim_kv) + """ + pass + + def set_batch_info(self, batch_info: LoraBatchInfo): + self.batch_info = batch_info diff --git a/python/sglang/srt/lora/backend/flashinfer_backend.py b/python/sglang/srt/lora/backend/flashinfer_backend.py new file mode 100644 index 00000000000..91c15be3c0a --- /dev/null +++ b/python/sglang/srt/lora/backend/flashinfer_backend.py @@ -0,0 +1,91 @@ +from typing import Tuple + +import torch + +from sglang.srt.lora.backend import BaseLoraBackend +from sglang.srt.lora.lora import LoraBatchInfo +from sglang.srt.utils import is_flashinfer_available + +if is_flashinfer_available(): + from flashinfer import SegmentGEMMWrapper + + +class FlashInferLoraBackend(BaseLoraBackend): + + def __init__(self, name: str, batch_info: LoraBatchInfo = None): + super().__init__(name, batch_info) + + # Set up SGemm Wrapper from flashinfer + # FIXME wait for flashinfer segment gemm update + workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") + self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) + + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + + return self.segment_gemm.run( + x=x, + weights=weights, + batch_size=self.batch_info.bs, + weight_column_major=True, + seg_indptr=self.batch_info.seg_indptr, + weight_indices=self.batch_info.weight_indices, + ) + + def run_lora_b_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + + return self.segment_gemm.run( + x=x, + weights=weights, + batch_size=self.batch_info.bs, + weight_column_major=True, + seg_indptr=self.batch_info.seg_indptr, + weight_indices=self.batch_info.weight_indices, + ) + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: Tuple[torch.Tensor], + *args, + **kwargs, + ) -> torch.Tensor: + + # Shape of lora_a_output: (s, 3 * r) + lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a) + + q_lora_b, kv_lora_b = qkv_lora_b + lora_rank = kv_lora_b.shape[-1] + output_dim_q = q_lora_b.shape[-2] + output_dim_kv = kv_lora_b.shape[-2] + lora_output = torch.empty( + (x.shape[0], output_dim_q + 2 * output_dim_kv), + device=x.device, + dtype=x.dtype, + ) + + # q + lora_output[:, :output_dim_q] = self.run_lora_b_sgemm( + x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0] + ) + + # kv + lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = ( + self.run_lora_b_sgemm( + x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(), + weights=kv_lora_b[0], + ) + ) + + lora_output[ + :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv + ] = self.run_lora_b_sgemm( + x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(), + weights=kv_lora_b[1], + ) + + return lora_output diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py new file mode 100644 index 00000000000..357040bf9d9 --- /dev/null +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -0,0 +1,61 @@ +import torch + +from sglang.srt.lora.backend import BaseLoraBackend +from sglang.srt.lora.lora import LoraBatchInfo +from sglang.srt.lora.triton_ops import ( + qkv_lora_b_fwd, + sgemm_lora_a_fwd, + sgemm_lora_b_fwd, +) + + +class TritonLoraBackend(BaseLoraBackend): + + def __init__(self, name: str, batch_info: LoraBatchInfo = None): + super().__init__(name, batch_info) + + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + return sgemm_lora_a_fwd(x, weights, self.batch_info) + + def run_lora_b_sgemm( + self, + x: torch.Tensor, + weights: torch.Tensor, + base_output: torch.Tensor = None, + scaling: float = 1.0, + *args, + **kwargs + ) -> torch.Tensor: + return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling) + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: torch.Tensor, + output_offset: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor = None, + scaling: float = 1.0, + *args, + **kwargs + ) -> torch.Tensor: + + # x: (s, input_dim) + # qkv_lora_a: (num_lora, 3 * r, input_dim) + # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r) + assert isinstance(qkv_lora_b, torch.Tensor) + + lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info) + lora_output = qkv_lora_b_fwd( + lora_a_output, + qkv_lora_b, + self.batch_info, + output_offset, + max_qkv_out_dim, + base_output, + scaling, + ) + return lora_output diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index c8cbe36602b..9de3b9236b9 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -18,12 +18,11 @@ # LoRA layers class inheritance adapted from: # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py - import re +from dataclasses import dataclass import torch from torch import nn -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -31,17 +30,36 @@ QKVParallelLinear, RowParallelLinear, ) +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_loader.loader import DefaultModelLoader +@dataclass +class LoraBatchInfo: + # Batch size + bs: int + + # Lengths of each sequence in shape (bs,) + seg_lens: torch.Tensor + + # Indice pointers of each sequence in shape (bs + 1, ) + seg_indptr: torch.Tensor + + # Maximum sequence length of current batch + max_len: int + + # The index of lora adapter used by each sequence, in shape (bs,) + weight_indices: torch.Tensor + + class BaseLayerWithLoRA(nn.Module): - def __init__(self, base_layer, segment_gemm, lora_rank, scaling): + def __init__(self, base_layer, lora_rank, scaling, lora_backend): super().__init__() self.base_layer = base_layer - self.segment_gemm = segment_gemm self.lora_rank = lora_rank self.scaling = scaling self.set_lora = False + self.lora_backend = lora_backend def forward(self, x: torch.Tensor): return self.base_layer.forward(x) @@ -52,17 +70,17 @@ def set_lora_info(self, *args): class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__( - self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling + self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) self.weight = base_layer.weight class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__( - self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling + self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: # TODO @@ -88,136 +106,127 @@ def forward(self, input_: torch.Tensor): class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): def __init__( - self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling + self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) - def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices): + def set_lora_info( + self, + A_buffer, + B_buffer, + ): self.set_lora = True self.A_buffer = A_buffer self.B_buffer = B_buffer - self.bs = bs - self.seg_indptr = seg_indptr - self.weight_indices = weight_indices def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.segment_gemm.run( - x=x, - weights=self.A_buffer, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, - ) - # FIXME + lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer) + + output_dim = base_output.shape[-1] lora_output = torch.empty_like(base_output) - output_dim = lora_output.shape[-1] // 2 - for i in range(2): - left = output_dim * i - right = left + output_dim - lora_output[:, left:right] = self.segment_gemm.run( - x=lora_a_output[ - :, self.lora_rank * i : self.lora_rank * (i + 1) - ].contiguous(), - weights=self.B_buffer[:, left:right, :].contiguous(), - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm( + x=lora_a_output[:, 0 : self.lora_rank].contiguous(), + weights=self.B_buffer[0], + ) + + lora_output[:, output_dim : 2 * output_dim] = ( + self.lora_backend.run_lora_b_sgemm( + x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(), + weights=self.B_buffer[1], ) + ) + return base_output + lora_output * self.scaling class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - def __init__( - self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling + def init__( + self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) def set_lora_info( - self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices + self, + A_buffer_qkv, + B_buffer_q, + B_buffer_kv, ): self.set_lora = True self.A_buffer_qkv = A_buffer_qkv - self.B_buffer_q = B_buffer_q - self.B_buffer_kv = B_buffer_kv - self.bs = bs - self.seg_indptr = seg_indptr - self.weight_indices = weight_indices + + if self.lora_backend.fuse_qkv_lora_b: + assert ( + B_buffer_q.shape[-1] == B_buffer_kv.shape[-1] + ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b" + output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2] + + # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r) + self.B_buffer_qkv = torch.cat( + (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2 + ).contiguous() + + # Offsets of q/k/v in output dimension + self.output_offset = torch.tensor( + [ + 0, + output_dim_q, + output_dim_q + output_dim_kv, + output_dim_q + 2 * output_dim_kv, + ], + dtype=torch.int32, + device=B_buffer_q.device, + ) + # For computing number of launched blocks + self.max_qkv_out_dim = max(output_dim_q, output_dim_kv) + else: + self.B_buffer_qkv = ( + B_buffer_q, + B_buffer_kv, + ) + self.output_offset = None + self.max_qkv_out_dim = None def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.segment_gemm.run( - x=x, - weights=self.A_buffer_qkv, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + lora_output = self.lora_backend.run_qkv_lora( + x, + self.A_buffer_qkv, + self.B_buffer_qkv, + output_offset=self.output_offset, + max_qkv_out_dim=self.max_qkv_out_dim, + base_output=base_output, + scaling=self.scaling, ) - # FIXME parallelize qkv - lora_output = torch.empty_like(base_output) - # q - output_dim_q = self.B_buffer_q.shape[-2] - lora_output[:, :output_dim_q] = self.segment_gemm.run( - x=lora_a_output[:, : self.lora_rank].contiguous(), - weights=self.B_buffer_q, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling ) - # kv - output_dim_kv = self.B_buffer_kv.shape[-2] // 2 - for i in range(2): - left = output_dim_kv * i - right = left + output_dim_kv - lora_output[:, output_dim_q + left : output_dim_q + right] = ( - self.segment_gemm.run( - x=lora_a_output[ - :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2) - ].contiguous(), - weights=self.B_buffer_kv[:, left:right, :].contiguous(), - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, - ) - ) - return base_output + lora_output * self.scaling class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__( - self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling + self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) - def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices): + def set_lora_info(self, A_buffer, B_buffer): self.set_lora = True self.A_buffer = A_buffer self.B_buffer = B_buffer - self.bs = bs - self.seg_indptr = seg_indptr - self.weight_indices = weight_indices def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_output = self.segment_gemm.run( - x=x, - weights=self.A_buffer, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) + lora_output = self.lora_backend.run_lora_b_sgemm( + lora_a_output, + self.B_buffer[0], + base_output=base_output, + scaling=self.scaling, ) - lora_output = self.segment_gemm.run( - x=lora_output, - weights=self.B_buffer, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling ) - return base_output + lora_output * self.scaling def forward(self, input_): # duplicate the logic in RowParallelLinear @@ -255,7 +264,7 @@ def forward(self, input_): def get_lora_layer( - layer: nn.Module, segment_gemm, lora_rank, scaling + layer: nn.Module, lora_rank, scaling, lora_backend ) -> BaseLayerWithLoRA: supported_layer_types = { # the order matters @@ -267,7 +276,7 @@ def get_lora_layer( } for src_layer_type, lora_layer_type in supported_layer_types.items(): if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck - ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling) + ret = lora_layer_type(layer, lora_rank, scaling, lora_backend) return ret raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") @@ -297,13 +306,14 @@ def offload_from_gpu(self): class LoRAAdapter(nn.Module): - def __init__(self, uid, config, base_hf_config, load_config): + def __init__(self, uid, config, base_hf_config, load_config, lora_backend): super().__init__() self.uid = uid self.config = config assert self.config.hf_config["peft_type"].lower() == "lora" self.base_hf_config = base_hf_config self.load_config = load_config + self.lora_backend = lora_backend self.scaling = self.config.lora_alpha / self.config.r self.layers = nn.ModuleList( @@ -376,20 +386,25 @@ def initialize_weights(self): layer.weights.pop(weight_name) layer.weights.pop(v_name) else: - layer.weights[kv_name] = torch.cat( - ( + layer.weights[kv_name] = torch.stack( + [ layer.weights[weight_name], layer.weights[v_name], - ), - 0, + ], + dim=0, ) layer.weights.pop(weight_name) layer.weights.pop(v_name) elif "gate_proj" in weight_name: up_name = weight_name.replace("gate_proj", "up_proj") gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") - layer.weights[gate_up_name] = torch.cat( - (layer.weights[weight_name], layer.weights[up_name]), 0 - ) + if "lora_A" in weight_name: + layer.weights[gate_up_name] = torch.cat( + (layer.weights[weight_name], layer.weights[up_name]), 0 + ) + else: + layer.weights[gate_up_name] = torch.stack( + [layer.weights[weight_name], layer.weights[up_name]], dim=0 + ) layer.weights.pop(weight_name) layer.weights.pop(up_name) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 0449e252453..404f3f50700 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -20,16 +20,14 @@ import torch -from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer +from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend +from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import is_flashinfer_available, replace_submodule logger = logging.getLogger(__name__) -if is_flashinfer_available(): - from flashinfer import SegmentGEMMWrapper - def get_module_name(name): # Fallback solution of mapping from config module name to module name in model class. @@ -77,6 +75,20 @@ def get_stacked_name(name): return params_mapping.get(name, (name, name)) +def get_backend_from_name(name): + backend_mapping = { + "triton": TritonLoraBackend, + "flashinfer": FlashInferLoraBackend, + } + + if name in backend_mapping: + return backend_mapping[name] + + raise Exception( + f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}" + ) + + def get_layer_id(name): match = re.search(r"layers\.(\d+)\.", name) if match is None: @@ -93,6 +105,7 @@ def __init__( max_loras_per_batch, load_config, dtype, + lora_backend, ): self.base_model = base_model self.lora_paths = lora_paths @@ -101,8 +114,9 @@ def __init__( self.load_config = load_config self.dtype = dtype - workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") - self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) + logger.info(f"Using {lora_backend} as backend of Lora kernels.") + backend_type = get_backend_from_name(lora_backend) + self.lora_backend = backend_type(lora_backend) self.init_loras() self.init_lora_memory_pool() @@ -123,7 +137,7 @@ def get_target_modules(self): def set_lora_module(self, module_name, module): lora_module = get_lora_layer( - module, self.segment_gemm, self.max_lora_dim, self.scaling + module, self.max_lora_dim, self.scaling, self.lora_backend ) replace_submodule(self.base_model, module_name, lora_module) return lora_module @@ -162,7 +176,11 @@ def init_loras(self): self.lora_id[name] = len(self.loras) self.loras.append( LoRAAdapter( - name, self.configs[name], self.base_hf_config, self.load_config + name, + self.configs[name], + self.base_hf_config, + self.load_config, + self.lora_backend, ) ) self.loras[-1].initialize_weights() @@ -226,8 +244,9 @@ def init_lora_memory_pool(self): self.B_buffer[module_B] = [ torch.empty( ( + c, self.max_loras_per_batch, - hidden_dim_B * c, + hidden_dim_B, self.max_lora_dim, ), dtype=self.dtype, @@ -263,7 +282,16 @@ def load_lora(self, uid, buffer_id): else: lora_weight_name = self.get_weight_name(name, 1) if lora_weight_name: - self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights) + c = self.loras[-1].get_stacked_multiply(lora_weight_name) + if c > 1: + for j in range(c): + self.B_buffer[lora_weight_name][i][j][buffer_id].copy_( + weights[j] + ) + else: + self.B_buffer[lora_weight_name][i][0][buffer_id].copy_( + weights + ) def prepare_lora_batch(self, forward_batch: ForwardBatch): # load active loras into lora memory pool @@ -292,20 +320,30 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): if cur_uids == set([None]): return - # setup lora in forward modules + # set up batch info shared by all lora moruldes bs = forward_batch.batch_size seg_lens = ( forward_batch.extend_seq_lens if forward_batch.forward_mode.is_extend() else torch.ones(bs, device="cuda") ) - # FIXME: reuse the data rather than recompute seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) + max_len = int(torch.max(seg_lens)) weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda") for i, lora_path in enumerate(forward_batch.lora_paths): weight_indices[i] = self.buffer_id[lora_path] + batch_info = LoraBatchInfo( + bs=bs, + seg_lens=seg_lens, + seg_indptr=seg_indptr, + max_len=max_len, + weight_indices=weight_indices, + ) + self.lora_backend.set_batch_info(batch_info) + + # call set_lora_info for each lora modules for module_name, module in self.lora_modules: layer_id = get_layer_id(module_name) @@ -314,16 +352,10 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): module.set_lora_info( self.A_buffer[weight_name][layer_id], self.B_buffer[weight_name][layer_id], - bs, - seg_indptr, - weight_indices, ) else: module.set_lora_info( self.A_buffer["qkv_proj"][layer_id], self.B_buffer["q_proj"][layer_id], self.B_buffer["kv_proj"][layer_id], - bs, - seg_indptr, - weight_indices, ) diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py new file mode 100644 index 00000000000..efc76bb8b47 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -0,0 +1,5 @@ +from .qkv_lora_b import qkv_lora_b_fwd +from .sgemm_lora_a import sgemm_lora_a_fwd +from .sgemm_lora_b import sgemm_lora_b_fwd + +__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"] diff --git a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py new file mode 100644 index 00000000000..3e090f4dc37 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py @@ -0,0 +1,182 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _qkv_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Parameters of size + K, # K = R + max_qkv_out_dim, # max(output_q_dim, output_kv_dim) + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Offsets of q/k/v slice on output dimension + n_offs, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # For fused output scaling and adding + fuse_scaling_add, + scaling, +): + # This kernel packs 3 sgemms (q/k/v) into a single kernel. + + # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank + # weights: (num_lora, N_Q + 2 * N_KV, K) + # output: (s, N_Q + 2 * N_KV) + # N_Q >> K, N_KV >> K + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len. + # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v) + batch_id = tl.program_id(axis=2) + qkv_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + n_start = tl.load(n_offs + qkv_id) + n_size = tl.load(n_offs + qkv_id + 1) - n_start + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + + x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size) + if fuse_scaling_add: + partial_sum += tl.load(output_ptr, mask=output_mask) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def qkv_lora_b_fwd( + x: torch.Tensor, + qkv_lora_b: torch.Tensor, + batch_info: LoraBatchInfo, + output_offset: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor = None, + scaling: float = 1.0, +) -> torch.Tensor: + + # x: (s, 3 * r) + # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r) + # output_offset = [0, output_dim_q, output_dim_q + output_dim_kv, + # output_dim_q + 2 * output_dim_kv] + # max_qkv_out_dim = max(output_dim_q, output_dim_kv) + # output: (s, output_dim_q + 2 * output_dim_kv) + + # Compute lora_output with shape (s, output_dim) as follows: + # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], ) + # lora_output[:, output_dim_q: output_dim_q + output_dim_kv] + # = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0]) + # lora_output[:, output_dim_q + output_dim_kv: ] + # = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1]) + + # Get dims + s = x.shape[0] + input_dim = x.shape[1] + r = qkv_lora_b.shape[-1] + output_dim = qkv_lora_b.shape[-2] + assert input_dim == 3 * r + assert output_offset.shape[0] == 4 + + BLOCK_S = 16 + BLOCK_R = 16 + BLOCK_OUT = 64 + + grid_b = ( + triton.cdiv(batch_info.max_len, BLOCK_S) + * triton.cdiv(max_qkv_out_dim, BLOCK_OUT), + 3, # this dimension decides current block computes on q, k or v + batch_info.bs, + ) + + if base_output is None: + output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype) + fuse_scaling_add = False + else: + output = base_output + fuse_scaling_add = True + + _qkv_lora_b_kernel[grid_b]( + x, + qkv_lora_b, + output, + r, + max_qkv_out_dim, + x.stride(0), + x.stride(1), + qkv_lora_b.stride(0), + qkv_lora_b.stride(1), + qkv_lora_b.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + output_offset, + BLOCK_S, + BLOCK_OUT, + BLOCK_R, + fuse_scaling_add, + scaling, + ) + + return output diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py new file mode 100644 index 00000000000..305bb8c5f0e --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py @@ -0,0 +1,143 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _sgemm_lora_a_kernel( + # Pointers to matrices + x, + weights, + output, + # Matrix dimensions + N, # r + K, # input_dim + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + + # x: (s, K), s is the sum of sequence lengths + # weights: (num_lora, N, K) + # output: (s, N) + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len + batch_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = (x + seg_start * x_stride_0) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def sgemm_lora_a_fwd( + x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo +) -> torch.Tensor: + # x: (s, input_dim) + # weights: (num_lora, r, input_dim) + # output: (s, r) + # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r + # input_dim is much larger than r + + assert x.is_contiguous() + assert weights.is_contiguous() + assert len(x.shape) == 2 + assert len(weights.shape) == 3 + + S = x.shape[0] + R = weights.shape[-2] + K = weights.shape[-1] + assert x.shape[-1] == K + + # Block shapes + BLOCK_S = 16 + BLOCK_K = 256 + BLOCK_R = 16 + + grid = ( + triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R), + batch_info.bs, + ) + + output = torch.empty((S, R), device=x.device, dtype=x.dtype) + _sgemm_lora_a_kernel[grid]( + x, + weights, + output, + R, + K, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + BLOCK_S, + BLOCK_R, + BLOCK_K, + ) + return output diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py new file mode 100644 index 00000000000..c0bc913630c --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py @@ -0,0 +1,159 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _sgemm_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Matrix dimensions + N, # output_dim + K, # r + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # For fused output scaling and adding + fuse_scaling_add, + scaling, +): + # x: (s, K), s is the sum of sequence lengths + # weights: (num_lora, N, K) + # output: (s, N) + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len + batch_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = (x + seg_start * x_stride_0) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = s_offset[:, None] < seg_len + if fuse_scaling_add: + partial_sum += tl.load(output_ptr, mask=output_mask) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def sgemm_lora_b_fwd( + x: torch.Tensor, + weights: torch.Tensor, + batch_info: LoraBatchInfo, + base_output: torch.Tensor = None, + scaling: float = 1.0, +) -> torch.Tensor: + # x: (s, r) + # weights: (num_lora, output_dim, r) + # output: (s, output_dim) + # output_dim is much larger than r + + assert x.is_contiguous() + assert weights.is_contiguous() + assert len(x.shape) == 2 + assert len(weights.shape) == 3 + + S = x.shape[0] + N = weights.shape[-2] + R = weights.shape[-1] + assert x.shape[-1] == R + + # Block shapes + BLOCK_S = 16 + BLOCK_R = 16 + BLOCK_N = 256 + + grid = ( + triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N), + batch_info.bs, + ) + + if base_output is None: + output = torch.empty((S, N), device=x.device, dtype=x.dtype) + fuse_scaling_add = False + else: + output = base_output + fuse_scaling_add = True + + _sgemm_lora_b_kernel[grid]( + x, + weights, + output, + N, + R, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + BLOCK_S, + BLOCK_N, + BLOCK_R, + fuse_scaling_add, + scaling, + ) + return output diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index e8ae866f435..5a2a53ffaa8 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -21,8 +21,8 @@ import torch import tqdm -from vllm.model_executor.custom_op import CustomOp +from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture from sglang.srt.layers.logits_processor import LogitsProcessorOutput @@ -103,72 +103,75 @@ def set_torch_compile_config(): torch._dynamo.config.cache_size_limit = 1024 +def get_batch_sizes_to_capture(model_runner: ModelRunner): + server_args = model_runner.server_args + capture_bs = server_args.cuda_graph_bs + if capture_bs is None: + if server_args.disable_cuda_graph_padding: + capture_bs = list(range(1, 33)) + [64, 128] + else: + capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] + if max(capture_bs) > model_runner.req_to_token_pool.size: + # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests + # is very samll. We add more values here to make sure we capture the maximum bs. + capture_bs = list( + sorted( + set( + capture_bs + + [model_runner.req_to_token_pool.size - 1] + + [model_runner.req_to_token_pool.size] + ) + ) + ) + capture_bs = [ + bs + for bs in capture_bs + if bs <= model_runner.req_to_token_pool.size + and bs <= server_args.cuda_graph_max_bs + ] + compile_bs = ( + [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] + if server_args.enable_torch_compile + else [] + ) + return capture_bs, compile_bs + + +# Reuse this memory pool across all cuda graph runners. +global_graph_memory_pool = None + + +def get_global_graph_memory_pool(): + return global_graph_memory_pool + + +def set_global_graph_memory_pool(val): + global global_graph_memory_pool + global_graph_memory_pool = val + + class CudaGraphRunner: """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" - def __init__(self, model_runner: "ModelRunner"): + def __init__(self, model_runner: ModelRunner): # Parse args self.model_runner = model_runner self.graphs = {} - self.input_buffers = {} self.output_buffers = {} - self.flashinfer_handlers = {} - self.graph_memory_pool = None - self.use_torch_compile = model_runner.server_args.enable_torch_compile + self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.disable_padding = model_runner.server_args.disable_cuda_graph_padding - self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder - self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention - self.enable_hip_attention = self.model_runner.server_args.enable_hip_attention - if self.enable_hip_attention: - self.hip_config = self.model_runner.server_args.hip_attention_config - self.tp_size = self.model_runner.tp_size - self.dp_size = self.model_runner.server_args.dp_size + self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder + self.enable_dp_attention = model_runner.server_args.enable_dp_attention + self.tp_size = model_runner.server_args.tp_size + self.dp_size = model_runner.server_args.dp_size # Batch sizes to capture - self.capture_bs = self.model_runner.server_args.cuda_graph_bs - if self.capture_bs is None: - if model_runner.server_args.disable_cuda_graph_padding: - self.capture_bs = list(range(1, 33)) + [64, 128] - else: - self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] - - if max(self.capture_bs) > model_runner.req_to_token_pool.size: - # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests - # is very samll. We add more values here to make sure we capture the maximum bs. - self.capture_bs = list( - sorted( - set( - self.capture_bs - + [model_runner.req_to_token_pool.size - 1] - + [model_runner.req_to_token_pool.size] - ) - ) - ) - - self.capture_bs = [ - bs - for bs in self.capture_bs - if bs <= model_runner.req_to_token_pool.size - and bs <= model_runner.server_args.cuda_graph_max_bs - ] - - self.compile_bs = ( - [ - bs - for bs in self.capture_bs - if bs <= self.model_runner.server_args.torch_compile_max_bs - ] - if self.use_torch_compile - else [] - ) - + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.capture_forward_mode = ForwardMode.DECODE self.num_tokens_per_bs = 1 if model_runner.spec_algorithm.is_eagle(): if self.model_runner.is_draft_worker: - self.num_tokens_per_bs = ( - self.model_runner.server_args.speculative_eagle_topk - ) + raise RuntimeError("This should not happen") else: self.capture_forward_mode = ForwardMode.TARGET_VERIFY self.num_tokens_per_bs = ( @@ -185,10 +188,10 @@ def __init__(self, model_runner: "ModelRunner"): # FIXME(lsyin): leave it here for now, I don't know whether it is necessary self.encoder_len_fill_value = 0 - if self.use_torch_compile: + if self.enable_torch_compile: set_torch_compile_config() - # Common inputs + # Graph inputs with torch.device("cuda"): self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) @@ -320,7 +323,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable, capture_config: tup stream = self.stream num_tokens = bs * self.num_tokens_per_bs - # Common inputs + # Graph inputs input_ids = self.input_ids[:num_tokens] req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] @@ -339,7 +342,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable, capture_config: tup global_num_tokens = None gathered_buffer = None - spec_info = self.get_spec_info(num_tokens, positions) + spec_info = self.get_spec_info(num_tokens) hip_num_cached_stages = None if self.enable_hip_attention: @@ -360,7 +363,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable, capture_config: tup seq_lens_sum=seq_lens.sum(), encoder_lens=encoder_lens, return_logprob=False, - top_logprobs_nums=[0] * bs, positions=positions, global_num_tokens=global_num_tokens, gathered_buffer=gathered_buffer, @@ -400,13 +402,14 @@ def run_once(): torch.cuda.synchronize() self.model_runner.tp_group.barrier() - with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream): + global global_graph_memory_pool + with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream): out = run_once() torch.cuda.synchronize() self.model_runner.tp_group.barrier() - self.graph_memory_pool = graph.pool() + global_graph_memory_pool = graph.pool() return graph, out def replay(self, forward_batch: ForwardBatch): @@ -467,35 +470,26 @@ def replay(self, forward_batch: ForwardBatch): ) return logits_output - def get_spec_info(self, num_tokens: int, positions: torch.Tensor): + def get_spec_info(self, num_tokens: int): spec_info = None if self.model_runner.spec_algorithm.is_eagle(): - from sglang.srt.speculative.eagle_utils import ( - EAGLEDraftInput, - EagleVerifyInput, - ) + from sglang.srt.speculative.eagle_utils import EagleVerifyInput if self.model_runner.is_draft_worker: - spec_info = EAGLEDraftInput() - spec_info.load_server_args(self.model_runner.server_args) - spec_info.hidden_states = self.hidden_states[:num_tokens] - spec_info.positions = positions - spec_info.capture_hidden_mode = CaptureHiddenMode.FULL + raise RuntimeError("This should not happen.") else: spec_info = EagleVerifyInput( - None, - None, - None, - None, - None, - None, - self.model_runner.server_args.speculative_num_draft_tokens, - ) - spec_info.custom_mask = torch.zeros( - (num_tokens * self.model_runner.model_config.context_len), - dtype=torch.bool, - device="cuda", + draft_token=None, + custom_mask=torch.zeros( + (num_tokens * self.model_runner.model_config.context_len), + dtype=torch.bool, + device="cuda", + ), + positions=None, + retrive_index=None, + retrive_cum_len=None, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + capture_hidden_mode=CaptureHiddenMode.FULL, ) - spec_info.capture_hidden_mode = CaptureHiddenMode.FULL return spec_info diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 9304fbecad3..1c38da7e5cf 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -203,64 +203,6 @@ class ForwardBatch: # For Qwen2-VL mrope_positions: torch.Tensor = None - def compute_mrope_positions( - self, model_runner: ModelRunner, batch: ModelWorkerBatch - ): - device = model_runner.device - hf_config = model_runner.model_config.hf_config - mrope_positions_list = [None] * self.seq_lens.shape[0] - if self.forward_mode.is_decode(): - for i, _ in enumerate(mrope_positions_list): - mrope_position_delta = ( - 0 - if batch.image_inputs[i] is None - else batch.image_inputs[i].mrope_position_delta - ) - mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions( - mrope_position_delta, - int(self.seq_lens[i]) - 1, - int(self.seq_lens[i]), - ) - elif self.forward_mode.is_extend(): - extend_start_loc_cpu = self.extend_start_loc.cpu().numpy() - for i, image_inputs in enumerate(batch.image_inputs): - extend_start_loc, extend_seq_len, extend_prefix_len = ( - extend_start_loc_cpu[i], - batch.extend_seq_lens[i], - batch.extend_prefix_lens[i], - ) - if image_inputs is None: - # text only - mrope_positions = [ - [ - pos - for pos in range( - extend_prefix_len, extend_prefix_len + extend_seq_len - ) - ] - ] * 3 - else: - # TODO: current qwen2-vl do not support radix cache since mrope position calculation - mrope_positions, mrope_position_delta = ( - MRotaryEmbedding.get_input_positions( - input_tokens=self.input_ids[ - extend_start_loc : extend_start_loc + extend_seq_len - ], - image_grid_thw=image_inputs.image_grid_thws, - vision_start_token_id=hf_config.vision_start_token_id, - spatial_merge_size=hf_config.vision_config.spatial_merge_size, - context_len=0, - ) - ) - batch.image_inputs[i].mrope_position_delta = mrope_position_delta - mrope_positions_list[i] = mrope_positions - - self.mrope_positions = torch.concat( - [torch.tensor(pos, device=device) for pos in mrope_positions_list], - axis=1, - ) - self.mrope_positions = self.mrope_positions.to(torch.int64) - @classmethod def init_new( cls, @@ -343,7 +285,7 @@ def init_new( ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens if model_runner.model_is_mrope: - ret.compute_mrope_positions(model_runner, batch) + ret._compute_mrope_positions(model_runner, batch) # Init HiP attention information if model_runner.hip_metadata_cache_pool is not None: @@ -356,6 +298,63 @@ def init_new( return ret + def _compute_mrope_positions( + self, model_runner: ModelRunner, batch: ModelWorkerBatch + ): + device = model_runner.device + hf_config = model_runner.model_config.hf_config + mrope_positions_list = [None] * self.seq_lens.shape[0] + if self.forward_mode.is_decode(): + for i, _ in enumerate(mrope_positions_list): + mrope_position_delta = ( + 0 + if batch.image_inputs[i] is None + else batch.image_inputs[i].mrope_position_delta + ) + mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions( + mrope_position_delta, + int(self.seq_lens[i]) - 1, + int(self.seq_lens[i]), + ) + elif self.forward_mode.is_extend(): + extend_start_loc_cpu = self.extend_start_loc.cpu().numpy() + for i, image_inputs in enumerate(batch.image_inputs): + extend_start_loc, extend_seq_len, extend_prefix_len = ( + extend_start_loc_cpu[i], + batch.extend_seq_lens[i], + batch.extend_prefix_lens[i], + ) + if image_inputs is None: + # text only + mrope_positions = [ + [ + pos + for pos in range( + extend_prefix_len, extend_prefix_len + extend_seq_len + ) + ] + ] * 3 + else: + # TODO: current qwen2-vl do not support radix cache since mrope position calculation + mrope_positions, mrope_position_delta = ( + MRotaryEmbedding.get_input_positions( + input_tokens=self.input_ids[ + extend_start_loc : extend_start_loc + extend_seq_len + ], + image_grid_thw=image_inputs.image_grid_thws, + vision_start_token_id=hf_config.vision_start_token_id, + spatial_merge_size=hf_config.vision_config.spatial_merge_size, + context_len=0, + ) + ) + batch.image_inputs[i].mrope_position_delta = mrope_position_delta + mrope_positions_list[i] = mrope_positions + self.mrope_positions = torch.concat( + [torch.tensor(pos, device=device) for pos in mrope_positions_list], + axis=1, + ) + self.mrope_positions = self.mrope_positions.to(torch.int64) + def on_model_start(self): self.token_to_kv_pool.on_model_start(self) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 88f564b950f..7d48b93700a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -56,6 +56,7 @@ MLATokenToKVPool, ReqToTokenPool, ) +from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model from sglang.srt.server_args import ServerArgs @@ -550,6 +551,7 @@ def init_lora_manager(self): max_loras_per_batch=self.server_args.max_loras_per_batch, load_config=self.load_config, dtype=self.dtype, + lora_backend=self.server_args.lora_backend, ) logger.info("LoRA manager ready.") @@ -768,8 +770,6 @@ def init_double_sparsity_channel_config(self, selected_channel): def init_cuda_graphs(self): """Capture cuda graphs.""" - from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner - self.cuda_graph_runner = None if not self.is_generation: diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 365891544e0..adc50508190 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -31,10 +31,10 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from vllm.model_executor.layers.activation import QuickGELU from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig from sglang.srt.hf_transformers_utils import get_processor +from sglang.srt.layers.activation import QuickGELU from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.logits_processor import LogitsProcessor diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 256612d9ab0..fc6ce4f2032 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -127,6 +127,7 @@ class ServerArgs: # LoRA lora_paths: Optional[List[str]] = None max_loras_per_batch: int = 8 + lora_backend: str = "triton" # Kernel backend attention_backend: Optional[str] = None @@ -287,6 +288,10 @@ def __post_init__(self): ) and check_gguf_file(self.model_path): self.quantization = self.load_format = "gguf" + # AMD-specific Triton attention KV splits default number + if is_hip(): + self.triton_attention_num_kv_splits = 16 + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args @@ -701,13 +706,19 @@ def add_cli_args(parser: argparse.ArgumentParser): nargs="*", default=None, action=LoRAPathAction, - help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}", + help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.", ) parser.add_argument( "--max-loras-per-batch", type=int, default=8, - help="Maximum number of adapters for a running batch, include base-only request", + help="Maximum number of adapters for a running batch, include base-only request.", + ) + parser.add_argument( + "--lora-backend", + type=str, + default="triton", + help="Choose the kernel backend for multi-LoRA serving.", ) # Kernel backend diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index 6412825ed8c..e0ac9fe0bb6 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -79,11 +79,13 @@ ) -def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token): +def build_tree_kernel( + parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token +): bs = seq_lens.numel() device = parent_list.device tree_mask = torch.full( - (torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,), + (seq_lens_sum * draft_token + draft_token * draft_token * bs,), True, device=device, ) diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py new file mode 100644 index 00000000000..41ff5c19e5a --- /dev/null +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import bisect +import time +from typing import TYPE_CHECKING, Callable + +import torch + +from sglang.srt.model_executor.cuda_graph_runner import ( + CudaGraphRunner, + get_batch_sizes_to_capture, + get_global_graph_memory_pool, + set_global_graph_memory_pool, + set_torch_compile_config, +) +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.speculative.eagle_utils import EagleDraftInput + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.eagle_worker import EAGLEWorker + + +class EAGLEDraftCudaGraphRunner: + def __init__(self, eagle_worker: EAGLEWorker): + # Parse args + self.eagle_worker = eagle_worker + self.model_runner = model_runner = eagle_worker.model_runner + self.graphs = {} + self.output_buffers = {} + self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.tp_size = self.model_runner.tp_size + self.dp_size = model_runner.server_args.dp_size + self.topk = model_runner.server_args.speculative_eagle_topk + self.speculative_num_steps = model_runner.server_args.speculative_num_steps + server_args = model_runner.server_args + + assert self.disable_padding + + # Batch sizes to capture + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) + self.num_tokens_per_bs = server_args.speculative_eagle_topk + + # Attention backend + self.max_bs = max(self.capture_bs) + self.max_num_token = self.max_bs * self.num_tokens_per_bs + self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token) + self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[ + 0 + ].get_cuda_graph_seq_len_fill_value() + + if self.enable_torch_compile: + set_torch_compile_config() + + # Graph inputs + with torch.device("cuda"): + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) + self.seq_lens = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) + self.out_cache_loc = torch.zeros( + (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64 + ) + self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32) + self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64) + self.hidden_states = torch.zeros( + (self.max_bs, self.model_runner.model_config.hidden_size), + dtype=self.model_runner.dtype, + ) + + # Capture + try: + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture cuda graph failed: {e}\n" + "Possible solutions:\n" + "1. disable cuda graph by --disable-cuda-graph\n" + "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" + "3. disable torch compile by not using --enable-torch-compile\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + ) + + def can_run(self, forward_batch: ForwardBatch): + is_bs_supported = ( + forward_batch.batch_size in self.graphs + if self.disable_padding + else forward_batch.batch_size <= self.max_bs + ) + return is_bs_supported + + def capture(self): + CudaGraphRunner.capture(self) + + def capture_one_batch_size(self, num_seqs: int, forward: Callable): + graph = torch.cuda.CUDAGraph() + stream = self.stream + num_tokens = num_seqs * self.num_tokens_per_bs + + # Graph inputs + req_pool_indices = self.req_pool_indices[:num_seqs] + seq_lens = self.seq_lens[:num_seqs] + out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps] + positions = self.positions[:num_tokens] + topk_p = self.topk_p[:num_seqs] + topk_index = self.topk_index[:num_seqs] + hidden_states = self.hidden_states[:num_seqs] + + spec_info = EagleDraftInput( + topk_p=topk_p, + topk_index=topk_index, + hidden_states=hidden_states, + ) + + # Forward batch + forward_batch = ForwardBatch( + forward_mode=ForwardMode.DECODE, + batch_size=num_seqs, + input_ids=None, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens.sum(), + return_logprob=False, + positions=positions, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ), + ) + + # Attention backend + self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph( + forward_batch + ) + + # Run and capture + def run_once(): + # Backup two fileds, which will be modified in-place in `draft_forward`. + output_cache_loc_backup = forward_batch.out_cache_loc + hidden_states_backup = forward_batch.spec_info.hidden_states + + ret = self.eagle_worker.draft_forward(forward_batch) + + forward_batch.out_cache_loc = output_cache_loc_backup + forward_batch.spec_info.hidden_states = hidden_states_backup + return ret + + for _ in range(2): + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + run_once() + + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + with torch.cuda.graph( + graph, pool=get_global_graph_memory_pool(), stream=stream + ): + out = run_once() + + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + set_global_graph_memory_pool(graph.pool()) + return graph, out + + def replay(self, forward_batch: ForwardBatch): + assert forward_batch.out_cache_loc is not None + raw_bs = forward_batch.batch_size + raw_num_token = raw_bs * self.num_tokens_per_bs + + # Pad + index = bisect.bisect_left(self.capture_bs, raw_bs) + bs = self.capture_bs[index] + if bs != raw_bs: + self.seq_lens.fill_(1) + self.out_cache_loc.zero_() + + # Common inputs + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) + self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_( + forward_batch.out_cache_loc + ) + self.positions[:raw_num_token].copy_(forward_batch.positions) + self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) + self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) + self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) + + # Attention backend + self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph( + forward_batch + ) + + # Replay + self.graphs[bs].replay() + + return self.output_buffers[bs] diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 97cdb264043..4abcba9550d 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses from typing import TYPE_CHECKING, List import torch @@ -9,201 +10,33 @@ from sglang.srt.layers.attention.flashinfer_backend import ( create_flashinfer_kv_indices_triton, ) -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.speculative.build_eagle_tree import build_tree_kernel -from sglang.srt.speculative.spec_info import SpecInfo if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch - from sglang.srt.server_args import ServerArgs -@triton.jit -def eagle_verify_retrive( - retrive_index, - accept_mask, - retrive_cum_len, - accept_index, - accept_length, - extract_index, - max_len: tl.constexpr, - draft_token_num: tl.constexpr, - max_len_upper: tl.constexpr, -): - pid = tl.program_id(axis=0) - - retrive_end = tl.load(retrive_cum_len + pid + 1) - retrive_start = tl.load(retrive_cum_len + pid) - retrive_len = retrive_end - retrive_start - accept_ptr = accept_mask + retrive_start - accept_offset = tl.arange(0, draft_token_num) - accept_load_mask = accept_offset < retrive_len - accept_len_list = tl.load( - accept_ptr + accept_offset, mask=accept_load_mask, other=-1 - ) - - accept_len = tl.max(accept_len_list) - max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True) - # triton is not support argmax with tie_break_right, so I need implement it by some way - mask_max = accept_len_list == accept_len - - count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32) - count = tl.sum(tl.where(mask_max, 1, count_mask)) - if count > 1: - index = tl.arange(0, draft_token_num) - mask_left = index != max_index - remained_index = tl.where(mask_max and mask_left, index, 0) - max_index = tl.max(remained_index) - - tl.store(accept_length + pid, accept_len) - retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len - retrive_offset = tl.arange(0, max_len_upper) - retrive_load_mask = retrive_offset < accept_len + 1 - data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask) - - tl.store( - accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask - ) - - extract_load_ptr = accept_index + pid * max_len + accept_len - if accept_len == max_len - 1: - extract_data = tl.load(extract_load_ptr - 1) - tl.store(extract_index + pid * 2, extract_data) - extract_data = tl.load(extract_load_ptr) - tl.store(extract_index + pid * 2 + 1, extract_data) - - else: - extract_data = tl.load(extract_load_ptr) - tl.store(extract_index + pid * 2, extract_data) - - -@triton.jit -def create_extend_spec_info( - verified_id, - seq_len, - accept_len, - accept_len_cum, - positions, - new_verified_id, - accept_len_upper: tl.constexpr, -): - pid = tl.program_id(axis=0) - offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1) - seq_length = tl.load(seq_len + pid) - accept_length = tl.load(accept_len + pid) - positions_ptr = positions + offset - data = tl.arange(0, accept_len_upper) - mask = data < accept_length - tl.store(positions_ptr + data, seq_length - accept_length + data, mask) - - offset = tl.load(accept_len_cum + pid) - 1 - verified_id_data = tl.load(verified_id + offset) - tl.store(new_verified_id + pid, verified_id_data) - - -@triton.jit -def assign_req_to_token_pool( - req_pool_indices, - req_to_token, - start_offset, - end_offset, - out_cache_loc, - pool_len: tl.constexpr, - bs_upper: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 32 - pid = tl.program_id(axis=0) - kv_start = tl.load(start_offset + pid) - kv_end = tl.load(end_offset + pid) - token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len - - length_offset = tl.arange(0, bs_upper) - start = tl.load(start_offset + length_offset, mask=length_offset < pid) - end = tl.load(end_offset + length_offset, mask=length_offset < pid) - out_offset = tl.sum(end - start, axis=0) - - out_cache_ptr = out_cache_loc + out_offset - - save_offset = tl.arange(0, BLOCK_SIZE) + kv_start - load_offset = tl.arange(0, BLOCK_SIZE) - - num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) - for _ in range(num_loop): - mask = save_offset < kv_end - data = tl.load(out_cache_ptr + load_offset, mask=mask) - tl.store(token_pool + save_offset, data, mask=mask) - save_offset += BLOCK_SIZE - load_offset += BLOCK_SIZE +@dataclasses.dataclass +class EagleDraftInput: + # The inputs for decode + # shape: (b, topk) + topk_p: torch.Tensor = None + topk_index: torch.Tensor = None + # shape: (b, hidden_size) + hidden_states: torch.Tensor = None + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL + # Inputs for extend + # shape: (b,) + verified_id: torch.Tensor = None + accept_length: torch.Tensor = None + accept_length_cpu: List[int] = None -@triton.jit -def generate_draft_decode_kv_indices( - req_pool_indices, - req_to_token, - paged_kernel_lens, - kv_indices, - iters: tl.constexpr, - topk: tl.constexpr, - pool_len: tl.constexpr, - bs_upper: tl.constexpr, - iter_upper: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 128 - bid = tl.program_id(axis=0) - topk_id = tl.program_id(axis=1) - - load_offset = tl.arange(0, bs_upper) - seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid) - seq_len = tl.load(paged_kernel_lens + bid) - cum_seq_len = tl.sum(seq_lens) - - kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters) - kv_ptr = kv_indices + kv_offset - token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len - - kv_offset = tl.arange(0, BLOCK_SIZE) - num_loop = tl.cdiv(seq_len, BLOCK_SIZE) - for _ in range(num_loop): - mask = kv_offset < seq_len - data = tl.load(token_pool_ptr + kv_offset, mask=mask) - tl.store(kv_ptr + kv_offset, data, mask=mask) - kv_offset += BLOCK_SIZE - - extend_offset = tl.arange(0, iter_upper) - extend_data = tl.load( - token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id, - mask=extend_offset < iters, - ) - tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) - - -class EAGLEDraftInput(SpecInfo): - def __init__(self): - self.prev_mode = ForwardMode.DECODE - - self.scores: torch.Tensor = None - self.score_list: List[torch.Tensor] = [] - self.token_list: List[torch.Tensor] = [] - self.origin_score_list: List[torch.Tensor] = [] # used for sampling - self.parents_list: List[torch.Tensor] = [] - self.cache_list: List[torch.Tenor] = [] - self.iter = 0 - - # shape: (b, hidden_size) - self.hidden_states: torch.Tensor = None - # shape: (b,) - self.verified_id: torch.Tensor = None - # shape: (b, vocab_size) - self.sample_output: torch.Tensor = None - - self.positions: torch.Tensor = None - self.accept_length: torch.Tensor = None - self.accept_length_cpu: List[int] = None - - def load_server_args(self, server_args: ServerArgs): - self.topk: int = server_args.speculative_eagle_topk - self.num_verify_token: int = server_args.speculative_num_draft_tokens - self.spec_steps = server_args.speculative_num_steps + # Inputs for the attention backends + # shape: (b + 1,) + kv_indptr: torch.Tensor = None + kv_indices: torch.Tensor = None def prepare_for_extend(self, batch: ScheduleBatch): req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) @@ -231,95 +64,12 @@ def prepare_for_extend(self, batch: ScheduleBatch): assert len(batch.extend_lens) == 1 batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id)) - def filter_batch( - self, - new_indices: torch.Tensor, - ): - self.sample_output = self.sample_output[: len(new_indices)] - self.hidden_states = self.hidden_states[: len(new_indices)] - self.verified_id = self.verified_id[: len(new_indices)] - - def prepare_for_decode(self, batch: ScheduleBatch): - prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab) - top = torch.topk(prob, self.topk, dim=-1) - topk_index, topk_p = ( - top.indices, - top.values, - ) # shape: (b * top_k, top_k) or (b, top_k) - - if self.prev_mode.is_decode(): - scores = torch.mul( - self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk) - ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) - topk_cs = torch.topk( - scores.flatten(start_dim=1), self.topk, dim=-1 - ) # (b, topk) - topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values - - selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange( - 0, batch.batch_size() * self.topk, step=self.topk, device="cuda" - ).repeat_interleave(self.topk) - - batch.spec_info.hidden_states = batch.spec_info.hidden_states[ - selected_input_index, : - ] - - topk_index = topk_index.reshape(-1, self.topk**2) - batch.input_ids = torch.gather( - topk_index, index=topk_cs_index, dim=1 - ).flatten() - batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids)) - - self.scores = topk_cs_p - self.score_list.append(scores) # (b, topk, topk) - self.token_list.append(topk_index) # (b, topk * topk) - self.origin_score_list.append(topk_p.reshape(topk_index.shape)) - self.parents_list.append( - topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk) - ) # shape: (b, topk) - else: - # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND - batch.spec_info.hidden_states = ( - batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0) - ) - - batch.input_ids = topk_index.flatten() - batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel()) - - self.scores = topk_p # shape: (b, topk) - self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk) - self.token_list.append(topk_index) # shape: (b, topk) - self.origin_score_list.append(topk_p) - self.parents_list.append( - torch.arange(-1, self.topk, dtype=torch.long, device="cuda") - .unsqueeze(0) - .repeat(self.scores.shape[0], 1) - ) # shape: (b, topk + 1) - self.cache_list.append(batch.out_cache_loc) - self.positions = ( - batch.seq_lens[:, None] - + torch.full( - [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long - ) - ).flatten() - - bs = len(batch.seq_lens) - assign_req_to_token_pool[(bs,)]( - batch.req_pool_indices, - batch.req_to_token_pool.req_to_token, - batch.seq_lens + self.topk * self.iter, - batch.seq_lens + self.topk * (self.iter + 1), - batch.out_cache_loc, - batch.req_to_token_pool.req_to_token.shape[1], - triton.next_power_of_2(bs), - ) - self.iter += 1 - - def prepare_extend_after_decode(self, batch: ScheduleBatch): + def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps): batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) accept_length_cpu = batch.spec_info.accept_length_cpu batch.extend_lens = [x + 1 for x in accept_length_cpu] batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend + batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend seq_lens_cpu = batch.seq_lens.tolist() pt = 0 @@ -348,86 +98,13 @@ def prepare_extend_after_decode(self, batch: ScheduleBatch): torch.cumsum(self.accept_length, axis=0, dtype=torch.int), self.positions, new_verified_id, - triton.next_power_of_2(self.spec_steps + 1), + triton.next_power_of_2(speculative_num_steps + 1), ) batch.seq_lens_sum = sum(seq_lens_cpu) batch.input_ids = self.verified_id self.verified_id = new_verified_id - def prepare_for_verify(self, batch: ScheduleBatch): - score_list = torch.cat(self.score_list, dim=1).flatten( - 1 - ) # b, n, topk; n= 1+(self.iter-1)*self.topk - ss_token_list = torch.cat( - self.token_list, dim=1 - ) # b, (self.topk+(self.iter-1)*self.topk) - origin_token_list = torch.cat(self.origin_score_list, dim=1) - top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1) - top_scores_index = top_scores.indices - top_scores_index = torch.sort(top_scores_index).values - - draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) - scores = torch.gather(origin_token_list, index=top_scores_index, dim=1) - draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1) - parent_list = torch.cat(self.parents_list[:-1], dim=1) - - tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( - parent_list, - top_scores_index, - batch.seq_lens, - self.topk, - self.iter - 1, - self.num_verify_token, - ) - - return EagleVerifyInput( - draft_tokens.flatten(), - scores.flatten(), - tree_mask, - position, - retrive_index, - retrive_cum_len, - self.num_verify_token, - ) - - def generate_attn_arg_decode( - self, - req_pool_indices: torch.Tensor, - paged_kernel_lens: torch.Tensor, - req_to_token: torch.Tensor, - ): - seq_num = req_pool_indices.numel() - bs = self.topk * req_pool_indices.numel() - seq_len = self.positions.reshape(-1).contiguous() - - cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") - cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0) - total_len = torch.sum(paged_kernel_lens).item() - - kv_indices = torch.empty( - (total_len * self.topk + seq_num * self.iter * self.topk,), - dtype=torch.int32, - device="cuda", - ) - - generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)]( - req_pool_indices, - req_to_token, - paged_kernel_lens, - kv_indices, - self.iter, - self.topk, - req_to_token.shape[1], - triton.next_power_of_2(seq_num), - triton.next_power_of_2(self.spec_steps), - ) - return bs, kv_indices, cum_kv_seq_len - - def clear_draft_cache(self, batch): - draft_cache = torch.cat(self.cache_list, dim=0) - batch.token_to_kv_pool.free(draft_cache) - def generate_attn_arg_prefill( self, req_pool_indices: torch.Tensor, @@ -454,12 +131,18 @@ def generate_attn_arg_prefill( return kv_indices, cum_kv_seq_len, qo_indptr, None - def merge_batch(self, spec_info: EAGLEDraftInput): + def filter_batch(self, new_indices: torch.Tensor): + self.topk_p = self.topk_p[: len(new_indices)] + self.topk_index = self.topk_index[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] + + def merge_batch(self, spec_info: EagleDraftInput): if self.hidden_states is None: self.hidden_states = spec_info.hidden_states self.verified_id = spec_info.verified_id - self.sample_output = spec_info.sample_output - self.prev_mode = spec_info.prev_mode + self.topk_p = spec_info.topk_p + self.topk_index = spec_info.topk_index return if spec_info.hidden_states is None: return @@ -467,32 +150,68 @@ def merge_batch(self, spec_info: EAGLEDraftInput): [self.hidden_states, spec_info.hidden_states], axis=0 ) self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) - self.sample_output = torch.cat([self.sample_output, spec_info.sample_output]) - - -class EagleVerifyInput(SpecInfo): - def __init__( - self, - draft_token: torch.Tensor, - draft_score: torch.Tensor, - tree_mask: torch.Tensor, - positions: torch.Tensor, - retrive_index: torch.Tensor, - retrive_cum_len: torch.Tensor, - draft_token_num: int, + self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) + self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) + + +@dataclasses.dataclass +class EagleVerifyInput: + draft_token: torch.Tensor + custom_mask: torch.Tensor + positions: torch.Tensor + retrive_index: torch.Tensor + retrive_cum_len: torch.Tensor + draft_token_num: int + capture_hidden_mode: CaptureHiddenMode + + @classmethod + def create( + cls, + verified_id: torch.Tensor, + score_list: List[torch.Tensor], + token_list: List[torch.Tensor], + parents_list: List[torch.Tensor], + seq_lens: torch.Tensor, + seq_lens_sum: int, + topk: int, + spec_steps: int, + num_verify_token: int, ): - self.draft_token = draft_token - self.draft_score = draft_score - self.custom_mask = tree_mask - self.positions = positions - self.retrive_index = retrive_index - self.retrive_cum_len = retrive_cum_len - self.draft_token_num = draft_token_num + score_list = torch.cat(score_list, dim=1).flatten( + 1 + ) # b, n, topk; n= 1 + (num_steps-1) * self.topk + ss_token_list = torch.cat( + token_list, dim=1 + ) # b, (self.topk + (num_steps-1) * self.topk) + top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1) + top_scores_index = top_scores.indices + top_scores_index = torch.sort(top_scores_index).values + draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) + draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1) + parent_list = torch.cat(parents_list[:-1], dim=1) + tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( + parent_list, + top_scores_index, + seq_lens, + seq_lens_sum, + topk, + spec_steps, + num_verify_token, + ) + return cls( + draft_tokens.flatten(), + tree_mask, + position, + retrive_index, + retrive_cum_len, + num_verify_token, + CaptureHiddenMode.FULL, + ) def prepare_for_verify(self, batch: ScheduleBatch): batch.input_ids = self.draft_token batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) - bs = batch.seq_lens.numel() + bs = batch.batch_size() assign_req_to_token_pool[(bs,)]( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, @@ -573,7 +292,6 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten triton.next_power_of_2(max_draft_len), ) - draft_input = EAGLEDraftInput() new_accept_index = [] unfinished_index = [] finished_extend_len = {} # {rid:accept_length + 1} @@ -625,18 +343,23 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten ) batch.seq_lens.add_(accept_length + 1) + draft_input = EagleDraftInput() if len(new_accept_index) > 0: new_accept_index = torch.tensor(new_accept_index, device="cuda") - draft_input.verified_id = predict[new_accept_index] draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index] + draft_input.verified_id = predict[new_accept_index] draft_input.accept_length = accept_length[unfinished_index] draft_input.accept_length_cpu = [ accept_length_cpu[i] for i in unfinished_index ] if has_finished: draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] + draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[ + unfinished_index + ] else: draft_input.seq_lens_for_draft_extend = batch.seq_lens + draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices logits_output.next_token_logits = logits_output.next_token_logits[accept_index] return ( @@ -646,3 +369,269 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten finished_extend_len, accept_length_cpu, ) + + +@triton.jit +def eagle_verify_retrive( + retrive_index, + accept_mask, + retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_len: tl.constexpr, + draft_token_num: tl.constexpr, + max_len_upper: tl.constexpr, +): + pid = tl.program_id(axis=0) + + retrive_end = tl.load(retrive_cum_len + pid + 1) + retrive_start = tl.load(retrive_cum_len + pid) + retrive_len = retrive_end - retrive_start + accept_ptr = accept_mask + retrive_start + accept_offset = tl.arange(0, draft_token_num) + accept_load_mask = accept_offset < retrive_len + accept_len_list = tl.load( + accept_ptr + accept_offset, mask=accept_load_mask, other=-1 + ) + + accept_len = tl.max(accept_len_list) + max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True) + # triton is not support argmax with tie_break_right, so I need implement it by some way + mask_max = accept_len_list == accept_len + + count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32) + count = tl.sum(tl.where(mask_max, 1, count_mask)) + if count > 1: + index = tl.arange(0, draft_token_num) + mask_left = index != max_index + remained_index = tl.where(mask_max and mask_left, index, 0) + max_index = tl.max(remained_index) + + tl.store(accept_length + pid, accept_len) + retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len + retrive_offset = tl.arange(0, max_len_upper) + retrive_load_mask = retrive_offset < accept_len + 1 + data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask) + + tl.store( + accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask + ) + + extract_load_ptr = accept_index + pid * max_len + accept_len + if accept_len == max_len - 1: + extract_data = tl.load(extract_load_ptr - 1) + tl.store(extract_index + pid * 2, extract_data) + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2 + 1, extract_data) + + else: + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2, extract_data) + + +@triton.jit +def create_extend_spec_info( + verified_id, + seq_len, + accept_len, + accept_len_cum, + positions, + new_verified_id, + accept_len_upper: tl.constexpr, +): + pid = tl.program_id(axis=0) + offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1) + seq_length = tl.load(seq_len + pid) + accept_length = tl.load(accept_len + pid) + positions_ptr = positions + offset + data = tl.arange(0, accept_len_upper) + mask = data < accept_length + tl.store(positions_ptr + data, seq_length - accept_length + data, mask) + + offset = tl.load(accept_len_cum + pid) - 1 + verified_id_data = tl.load(verified_id + offset) + tl.store(new_verified_id + pid, verified_id_data) + + +@triton.jit +def assign_req_to_token_pool( + req_pool_indices, + req_to_token, + start_offset, + end_offset, + out_cache_loc, + pool_len: tl.constexpr, + bs_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 32 + pid = tl.program_id(axis=0) + kv_start = tl.load(start_offset + pid) + kv_end = tl.load(end_offset + pid) + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + + length_offset = tl.arange(0, bs_upper) + start = tl.load(start_offset + length_offset, mask=length_offset < pid) + end = tl.load(end_offset + length_offset, mask=length_offset < pid) + out_offset = tl.sum(end - start, axis=0) + + out_cache_ptr = out_cache_loc + out_offset + + save_offset = tl.arange(0, BLOCK_SIZE) + kv_start + load_offset = tl.arange(0, BLOCK_SIZE) + + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = save_offset < kv_end + data = tl.load(out_cache_ptr + load_offset, mask=mask) + tl.store(token_pool + save_offset, data, mask=mask) + save_offset += BLOCK_SIZE + load_offset += BLOCK_SIZE + + +@triton.jit +def assign_draft_cache_locs( + req_pool_indices, + req_to_token, + seq_lens, + out_cache_loc, + pool_len: tl.constexpr, + topk: tl.constexpr, + speculative_num_steps: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 32 + pid = tl.program_id(axis=0) + kv_start = tl.load(seq_lens + pid) + kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps + + num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE) + for i in range(num_loop): + save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start + load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = save_offset < kv_end + data = tl.load(out_cache_ptr + load_offset, mask=mask) + tl.store(token_pool + save_offset, data, mask=mask) + + +@triton.jit +def generate_draft_decode_kv_indices( + req_pool_indices, + req_to_token, + paged_kernel_lens, + kv_indices, + kv_indptr, + positions, + num_seqs: tl.constexpr, + topk: tl.constexpr, + pool_len: tl.constexpr, + kv_indices_stride: tl.constexpr, + kv_indptr_stride: tl.constexpr, + bs_upper: tl.constexpr, + iter_upper: tl.constexpr, + num_tokens_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 128 + iters = tl.program_id(axis=0) + bid = tl.program_id(axis=1) + topk_id = tl.program_id(axis=2) + + kv_indices += kv_indices_stride * iters + kv_indptr += kv_indptr_stride * iters + iters += 1 + + load_offset = tl.arange(0, bs_upper) + seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid) + seq_len = tl.load(paged_kernel_lens + bid) + cum_seq_len = tl.sum(seq_lens) + + kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters) + kv_ptr = kv_indices + kv_offset + token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len + + kv_offset = tl.arange(0, BLOCK_SIZE) + num_loop = tl.cdiv(seq_len, BLOCK_SIZE) + for _ in range(num_loop): + mask = kv_offset < seq_len + data = tl.load(token_pool_ptr + kv_offset, mask=mask) + tl.store(kv_ptr + kv_offset, data, mask=mask) + kv_offset += BLOCK_SIZE + + extend_offset = tl.arange(0, iter_upper) + extend_data = tl.load( + token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id, + mask=extend_offset < iters, + ) + tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) + + # Update kv_indptr + bs_offset = tl.arange(0, num_tokens_upper) + + zid = bid * topk + topk_id + if zid == 0: + zid = num_seqs * topk + positions = tl.load(positions + bs_offset, mask=bs_offset < zid) + base = tl.sum(positions) + tl.store(kv_indptr + zid, base + zid * iters) + + +@torch.compile +def select_top_k_tokens( + i: int, + topk_p: torch.Tensor, + topk_index: torch.Tensor, + hidden_states: torch.Tensor, + scores: torch.Tensor, + topk: int, +): + if i == 0: + # The first step after extend + input_ids = topk_index.flatten() + hidden_states = hidden_states.repeat_interleave(topk, dim=0) + scores = topk_p # shape: (b, topk) + + tree_info = ( + topk_p.unsqueeze(1), # shape: (b, 1, topk) + topk_index, # shape: (b, topk) + torch.arange(-1, topk, dtype=torch.long, device="cuda") + .unsqueeze(0) + .repeat(topk_p.shape[0], 1), # shape: (b, topk + 1) + ) + + else: + # The later decode steps + expand_scores = torch.mul( + scores.unsqueeze(2), topk_p.reshape(-1, topk, topk) + ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) + + topk_cs_p, topk_cs_index = fast_topk( + expand_scores.flatten(start_dim=1), topk, dim=-1 + ) # (b, topk) + scores = topk_cs_p # shape: (b, topk) + + topk_index = topk_index.reshape(-1, topk**2) + input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten() + + selected_input_index = topk_cs_index.flatten() // topk + torch.arange( + 0, hidden_states.shape[0], step=topk, device="cuda" + ).repeat_interleave(topk) + hidden_states = hidden_states[selected_input_index, :] + + tree_info = ( + expand_scores, # shape: (b, topk, topk) + topk_index, # shape: (b, topk * topk) + topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk) + ) + + return input_ids, hidden_states, scores, tree_info + + +def fast_topk(values, topk, dim): + if topk == 1: + # Use max along the specified dimension to get both value and index + max_value, max_index = torch.max(values, dim=dim) + return max_value.unsqueeze(1), max_index.unsqueeze(1) + else: + # Use topk for efficiency with larger k values + return torch.topk(values, topk, dim=dim) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 06a4372fce2..6d84cc30510 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -1,3 +1,5 @@ +import logging +import time from typing import List, Optional, Union import torch @@ -12,8 +14,18 @@ ) from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs -from sglang.srt.speculative.eagle_utils import EAGLEDraftInput -from sglang.srt.utils import rank0_print +from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( + EAGLEDraftCudaGraphRunner, +) +from sglang.srt.speculative.eagle_utils import ( + EagleDraftInput, + EagleVerifyInput, + assign_draft_cache_locs, + fast_topk, + select_top_k_tokens, +) + +logger = logging.getLogger(__name__) class EAGLEWorker(TpModelWorker): @@ -40,41 +52,47 @@ def __init__( is_draft_worker=True, ) self.target_worker = target_worker - self.server_args = server_args self.finish_extend_len = [] + # Parse arguments + self.topk = server_args.speculative_eagle_topk + self.speculative_num_steps = server_args.speculative_num_steps + self.server_args = server_args + # Share the embedding and lm_head embed, head = self.target_worker.model_runner.model.get_embed_and_head() self.model_runner.model.set_embed_and_head(embed, head) self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph - self.model_runner.init_cuda_graphs() - def forward_draft_decode(self, batch: ScheduleBatch): - batch.spec_info.prepare_for_decode(batch) - batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST - model_worker_batch = batch.get_model_worker_batch() - forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - logits_output = self.model_runner.forward(forward_batch) - self.capture_for_decode(logits_output, forward_batch) + # Create multi-step attn backends and cuda graph runners + from sglang.srt.layers.attention.flashinfer_backend import ( + FlashInferMultiStepDraftBackend, + ) - def forward_draft_extend(self, batch: ScheduleBatch): - self._set_mem_pool(batch, self.model_runner) - batch.spec_info.prepare_for_extend(batch) - batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST - model_worker_batch = batch.get_model_worker_batch() - forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - logits_output = self.model_runner.forward(forward_batch) - self.capture_for_decode(logits_output, forward_batch) - self._set_mem_pool(batch, self.target_worker.model_runner) + self.draft_attn_backend = FlashInferMultiStepDraftBackend( + self.model_runner, + self.topk, + self.speculative_num_steps, + ) + self.model_runner.draft_attn_backend = self.draft_attn_backend + self.init_cuda_graphs() + + def init_cuda_graphs(self): + """Capture cuda graphs.""" + self.cuda_graph_runner = None + + if self.server_args.disable_cuda_graph: + return + + tic = time.time() + logger.info("Capture cuda graph begin. This can take up to several minutes.") + self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self) + logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s") def forward_batch_speculative_generation(self, batch: ScheduleBatch): if batch.forward_mode.is_decode(): # Draft - self._set_mem_pool(batch, self.model_runner) - for i in range(self.server_args.speculative_num_steps): - self.forward_draft_decode(batch) - batch.spec_info.clear_draft_cache(batch) - self._set_mem_pool(batch, self.target_worker.model_runner) + spec_info: EagleVerifyInput = self.draft(batch) # Verify ( @@ -84,8 +102,7 @@ def forward_batch_speculative_generation(self, batch: ScheduleBatch): self.finish_extend_len, accept_length_cpu, model_worker_batch, - ) = self.verify(batch) - next_draft_input.load_server_args(self.server_args) + ) = self.verify(batch, spec_info) batch.spec_info = next_draft_input # if it is None, means all requsets are finished if batch.spec_info.verified_id is not None: @@ -107,39 +124,156 @@ def forward_batch_speculative_generation(self, batch: ScheduleBatch): ) # Forward with the draft model. - spec_info = EAGLEDraftInput() - spec_info.load_server_args(self.server_args) - spec_info.hidden_states = logits_output.hidden_states - spec_info.verified_id = next_token_ids - batch.spec_info = spec_info + batch.spec_info = EagleDraftInput( + hidden_states=logits_output.hidden_states, + verified_id=next_token_ids, + ) self.forward_draft_extend(batch) return logits_output, next_token_ids, model_worker_batch, 0 - def verify(self, batch: ScheduleBatch): - verify_input = batch.spec_info.prepare_for_verify(batch) - verify_input.prepare_for_verify(batch) + def draft(self, batch: ScheduleBatch): + self._set_mem_pool(batch, self.model_runner) + + # Parse args + num_seqs = batch.batch_size() + spec_info = batch.spec_info + + # Allocate cache locations + out_cache_loc = batch.alloc_token_slots( + num_seqs * self.topk * self.speculative_num_steps + ) + assign_draft_cache_locs[(num_seqs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + self.topk, + self.speculative_num_steps, + ) + + batch.out_cache_loc = out_cache_loc + batch.seq_lens_sum = torch.sum(batch.seq_lens).item() + spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0) + + # Get forward batch + spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run( + forward_batch + ) + + if can_cuda_graph: + score_list, token_list, parents_list = self.cuda_graph_runner.replay( + forward_batch + ) + else: + # Initialize attention backend + self.draft_attn_backend.init_forward_metadata(forward_batch) + + # Run forward steps + score_list, token_list, parents_list = self.draft_forward(forward_batch) + + ret = EagleVerifyInput.create( + spec_info.verified_id, + score_list, + token_list, + parents_list, + batch.seq_lens, + batch.seq_lens_sum, + self.topk, + self.speculative_num_steps, + self.server_args.speculative_num_draft_tokens, + ) + + # Free cache locations + batch.token_to_kv_pool.free(out_cache_loc) + self._set_mem_pool(batch, self.target_worker.model_runner) + return ret + + def draft_forward(self, forward_batch: ForwardBatch): + # Parse args + spec_info = forward_batch.spec_info + out_cache_loc = forward_batch.out_cache_loc + topk_p, topk_index, hidden_states = ( + spec_info.topk_p, + spec_info.topk_index, + spec_info.hidden_states, + ) + + # Return values + score_list: List[torch.Tensor] = [] + token_list: List[torch.Tensor] = [] + parents_list: List[torch.Tensor] = [] + + # Forward multiple steps + scores = None + for i in range(self.speculative_num_steps): + input_ids, hidden_states, scores, tree_info = select_top_k_tokens( + i, topk_p, topk_index, hidden_states, scores, self.topk + ) + score_list.append(tree_info[0]) + token_list.append(tree_info[1]) + parents_list.append(tree_info[2]) + + # Set inputs + forward_batch.input_ids = input_ids + forward_batch.out_cache_loc = out_cache_loc[ + forward_batch.batch_size + * self.topk + * i : forward_batch.batch_size + * self.topk + * (i + 1) + ] + forward_batch.positions.add_(1) + forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i] + spec_info.hidden_states = hidden_states + + # Run forward + logits_output = self.model_runner.model.forward( + forward_batch.input_ids, forward_batch.positions, forward_batch + ) + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) + hidden_states = logits_output.hidden_states + + return score_list, token_list, parents_list + + def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): + spec_info.prepare_for_verify(batch) batch.forward_mode = ForwardMode.TARGET_VERIFY - batch.spec_info = verify_input - batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL + batch.spec_info = spec_info model_worker_batch = batch.get_model_worker_batch() logits_output, _ = self.target_worker.forward_batch_generation( model_worker_batch, skip_sample=True ) - verify_input.hidden_states = logits_output.hidden_states - res = verify_input.verify(batch, logits_output) + spec_info.hidden_states = logits_output.hidden_states + res = spec_info.verify(batch, logits_output) batch.forward_mode = ForwardMode.DECODE return res + (model_worker_batch,) + def forward_draft_extend(self, batch: ScheduleBatch): + self._set_mem_pool(batch, self.model_runner) + batch.spec_info.prepare_for_extend(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + logits_output = self.model_runner.forward(forward_batch) + self.capture_for_decode(logits_output, forward_batch) + self._set_mem_pool(batch, self.target_worker.model_runner) + def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): batch.token_to_kv_pool = runner.token_to_kv_pool batch.req_to_token_pool = runner.req_to_token_pool def forward_draft_extend_after_decode(self, batch: ScheduleBatch): seq_lens_backup = batch.seq_lens + req_pool_indices_backup = batch.req_pool_indices self._set_mem_pool(batch, self.model_runner) batch.forward_mode = ForwardMode.DRAFT_EXTEND - batch.spec_info.prepare_extend_after_decode(batch) + batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps) batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) @@ -151,17 +285,15 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch): # This is because `seq_lens` can be modified in `prepare_extend_after_decode` batch.forward_mode = ForwardMode.DECODE batch.seq_lens = seq_lens_backup + batch.req_pool_indices = req_pool_indices_backup def capture_for_decode( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch ): - sample_output = torch.softmax( - logits_output.next_token_logits, dim=-1 - ) # TODO(kavioyu): Support more sampling methods + probs = torch.softmax(logits_output.next_token_logits, dim=-1) spec_info = forward_batch.spec_info - spec_info.sample_output = sample_output + spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1) spec_info.hidden_states = logits_output.hidden_states - spec_info.prev_mode = forward_batch.forward_mode # Don't support prefix share now. def finish_request(self, reqs: Union[Req, List[Req]]): diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 898f2debb9f..090c4f7f61e 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1070,6 +1070,13 @@ def get_device_name(device_id: int = 0) -> str: return torch.hpu.get_device_name(device_id) +def get_device_core_count(device_id: int = 0) -> int: + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return torch.cuda.get_device_properties(device_id).multi_processor_count + + return 0 + + def get_device_capability(device_id: int = 0) -> Tuple[int, int]: major, minor = None, None if hasattr(torch, "cuda") and torch.cuda.is_available(): diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index bae0fcf2a49..6486b2550da 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -272,6 +272,7 @@ def __init__( port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, lora_paths: List[str] = None, max_loras_per_batch: int = 4, + lora_backend: str = "triton", disable_cuda_graph: bool = False, disable_radix_cache: bool = False, ): @@ -287,6 +288,7 @@ def __init__( is_embedding=not self.is_generation, lora_paths=lora_paths, max_loras_per_batch=max_loras_per_batch, + lora_backend=lora_backend, disable_cuda_graph=disable_cuda_graph, disable_radix_cache=disable_radix_cache, ) diff --git a/python/sglang/version.py b/python/sglang/version.py index df12433297b..615d4c40d17 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.4.2" +__version__ = "0.4.2.post2" diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index 1a059d5ff68..ffe405d5aa6 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -4,16 +4,17 @@ set -euxo pipefail # Install the dependency in CI. # Use repo from environment variable, passed from GitHub Actions -FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.4/flashinfer}" +FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer}" SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" bash "${SCRIPT_DIR}/killall_sglang.sh" pip install --upgrade pip -pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ +pip uninstall flashinfer -y +pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ # Force reinstall flashinfer and torch_memory_saver -pip install flashinfer==0.1.6 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps +pip install flashinfer_python==0.2.0.post2 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps pip install torch_memory_saver --force-reinstall pip install transformers==4.45.2 sentence_transformers accelerate peft diff --git a/sgl-kernel/3rdparty/cutlass b/sgl-kernel/3rdparty/cutlass index b78588d1630..3c28697b9f4 160000 --- a/sgl-kernel/3rdparty/cutlass +++ b/sgl-kernel/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit b78588d1630aa6643bf021613717bafb705df4ef +Subproject commit 3c28697b9f41fee4517b1758ffe83a85ac3ce2b4 diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer index 4f1f08989c7..e5a3befbe3e 160000 --- a/sgl-kernel/3rdparty/flashinfer +++ b/sgl-kernel/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 4f1f08989c71f92df181e346548c2ca48ae6daf5 +Subproject commit e5a3befbe3e63025f0158bc96b218a9c5f402ac7 diff --git a/sgl-kernel/THIRDPARTYNOTICES.txt b/sgl-kernel/THIRDPARTYNOTICES.txt index c930aa5dd3d..fcae14df3aa 100644 --- a/sgl-kernel/THIRDPARTYNOTICES.txt +++ b/sgl-kernel/THIRDPARTYNOTICES.txt @@ -223,3 +223,208 @@ BSD 3-Clause "New" License 3rdparty/cutlass include/flashinfer/attention/hopper/block_sparse_gather.cuh + +Notice for NVIDIA/TensorRT-LLM +------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index aca6f045054..bb7d6943348 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.3" +version = "0.0.3.post1" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index f887f5c19f0..9a93ae99229 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -1,5 +1,21 @@ +# Copyright 2025 SGLang Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import multiprocessing import os +import sys from pathlib import Path import torch @@ -9,14 +25,8 @@ root = Path(__file__).parent.resolve() -def _update_wheel_platform_tag(): - wheel_dir = Path("dist") - if wheel_dir.exists() and wheel_dir.is_dir(): - old_wheel = next(wheel_dir.glob("*.whl")) - new_wheel = wheel_dir / old_wheel.name.replace( - "linux_x86_64", "manylinux2014_x86_64" - ) - old_wheel.rename(new_wheel) +if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: + sys.argv.extend(["--plat-name", "manylinux2014_x86_64"]) def _get_cuda_version(): @@ -162,5 +172,3 @@ def _get_version(): }, options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) - -_update_wheel_platform_tag() diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py new file mode 100644 index 00000000000..6530cd7c743 --- /dev/null +++ b/sgl-kernel/setup_rocm.py @@ -0,0 +1,92 @@ +# Copyright 2025 SGLang Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import multiprocessing +import os +import sys +from pathlib import Path + +import torch +from setuptools import find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +root = Path(__file__).parent.resolve() + +if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: + sys.argv.extend(["--plat-name", "manylinux2014_x86_64"]) + + +def _get_version(): + with open(root / "pyproject.toml") as f: + for line in f: + if line.startswith("version"): + return line.split("=")[1].strip().strip('"') + + +operator_namespace = "sgl_kernels" +include_dirs = [ + root / "src" / "sgl-kernel" / "include", + root / "src" / "sgl-kernel" / "csrc", +] + +sources = [ + "src/sgl-kernel/torch_extension_rocm.cc", + "src/sgl-kernel/csrc/moe_align_kernel.cu", +] + +cxx_flags = ["-O3"] +libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"] +extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] + +hipcc_flags = [ + "-DNDEBUG", + f"-DOPERATOR_NAMESPACE={operator_namespace}", + "-O3", + "-Xcompiler", + "-fPIC", + "-std=c++17", + "-D__HIP_PLATFORM_AMD__=1", + "--amdgpu-target=gfx942", + "-DENABLE_BF16", + "-DENABLE_FP8", +] + +setup( + name="sgl-kernel", + version=_get_version(), + packages=find_packages(), + package_dir={"": "src"}, + ext_modules=[ + CUDAExtension( + name="sgl_kernel.ops._kernels", + sources=sources, + include_dirs=include_dirs, + extra_compile_args={ + "nvcc": hipcc_flags, + "cxx": cxx_flags, + }, + libraries=libraries, + extra_link_args=extra_link_args, + py_limited_api=True, + ), + ], + cmdclass={ + "build_ext": BuildExtension.with_options( + use_ninja=True, max_jobs=multiprocessing.cpu_count() + ) + }, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, + install_requires=["torch"], +) diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h index c83cf49ad83..f5cd4381563 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Adapted from // https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h index 33e82decc2b..3de9ff078b6 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Adapted from // https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h #pragma once diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h index 674e191a077..11fc872505f 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Adapted from // https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu index 3e33e143c0c..36b9585f349 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Adapted from // https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h // https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu index 4c4ecb966ee..a4ae14ae59d 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu @@ -1,115 +1,25 @@ -// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh -// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu -// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0 +/* Copyright 2025 SGLang Team. All Rights Reserved. -#include - -#include -#include -#include -#include - -#include "utils.h" - -using namespace flashinfer; - -template -__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight, - const uint32_t d, float eps) { - const uint32_t bx = blockIdx.x; - const uint32_t tx = threadIdx.x, ty = threadIdx.y; - constexpr uint32_t warp_size = 32; - const uint32_t num_warps = blockDim.y; - const uint32_t thread_id = tx + ty * warp_size; - const uint32_t num_threads = num_warps * warp_size; - const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); - extern __shared__ float smem[]; - - float sum_sq = 0.f; - - for (uint32_t i = 0; i < rounds; i++) { - vec_t input_vec; - input_vec.fill(0.f); - vec_t residual_vec; - residual_vec.fill(0.f); - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; j++) { - float x = float(input_vec[j]); - x += float(residual_vec[j]); - sum_sq += x * x; - residual_vec[j] = (T)x; - } - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } - } +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - // first, warp reduce sum -#pragma unroll - for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { - sum_sq += math::shfl_xor_sync(sum_sq, offset); - } + http://www.apache.org/licenses/LICENSE-2.0 - smem[ty] = sum_sq; - __syncthreads(); - // then, cross warp reduce sum using only the first warp - if (ty == 0) { - sum_sq = (tx < num_warps) ? smem[tx] : 0.f; -#pragma unroll - for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { - sum_sq += math::shfl_xor_sync(sum_sq, offset); - } - smem[0] = sum_sq; - } - __syncthreads(); +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ - float rms_rcp = math::rsqrt(smem[0] / float(d) + eps); - - for (uint32_t i = 0; i < rounds; i++) { - vec_t input_vec; - vec_t weight_vec; - vec_t residual_vec; - input_vec.fill(0.f); - weight_vec.fill(0.f); - residual_vec.fill(0.f); - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; j++) { - input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]); - } - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } - } -} - -template -cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5, - cudaStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); +#include - const uint32_t block_size = std::min(1024, d / vec_size); - const uint32_t num_warps = ceil_div(block_size, 32); - dim3 nblks(batch_size); - dim3 nthrs(32, num_warps); - const uint32_t smem_size = num_warps * sizeof(float); - void* args[] = {&input, &residual, &weight, &d, &eps}; +#include - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = FusedAddRMSNormKernel; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); +#include "utils.h" - return cudaSuccess; -} +using namespace flashinfer; void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) { CHECK_INPUT(input); @@ -130,9 +40,9 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream(); // support float16, bfloat16 and float32 DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { - cudaError_t status = - FusedAddRMSNorm(static_cast(input.data_ptr()), static_cast(residual.data_ptr()), - static_cast(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); + cudaError_t status = norm::FusedAddRMSNorm( + static_cast(input.data_ptr()), static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); return true; diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu index c77851c32b6..4a8130d667e 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include diff --git a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu index e62a154cb18..e9fc1c0ecdd 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index 19e9850b51a..d51ca517599 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Adapted from https://github.com/vllm-project/vllm/blob/v0.6.5/csrc/moe/moe_align_sum_kernels.cu #include diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu index 2ee0c98c91e..fa9e3a2c5d2 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // reference: // https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu /* diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu index fd0483e39ee..af129de52ef 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h #include diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index c5cc30c1888..1fdcc9c35ae 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #pragma once #include diff --git a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh index 46522348aaf..f4b01230cf3 100644 --- a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh +++ b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // reference: // https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp /* diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h index 55594f7b273..b714df77543 100644 --- a/sgl-kernel/src/sgl-kernel/include/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #pragma once #include diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/src/sgl-kernel/ops/utils.py index 31a6bbf9919..683748da0f5 100644 --- a/sgl-kernel/src/sgl-kernel/ops/utils.py +++ b/sgl-kernel/src/sgl-kernel/ops/utils.py @@ -1,3 +1,18 @@ +# Copyright 2025 SGLang Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, Tuple import torch diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc index 01f93199ccb..aaed142a1ef 100644 --- a/sgl-kernel/src/sgl-kernel/torch_extension.cc +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include diff --git a/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc b/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc new file mode 100644 index 00000000000..22f40da1091 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc @@ -0,0 +1,29 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "sgl_kernels_ops.h" + +TORCH_LIBRARY_EXPAND(sgl_kernels, m) { + // moe_align_block_size + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " + "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); +} + +REGISTER_EXTENSION(_kernels) diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py index 27fdca497c3..647733203b6 100644 --- a/sgl-kernel/version.py +++ b/sgl-kernel/version.py @@ -1 +1 @@ -__version__ = "0.0.3" +__version__ = "0.0.3.post1" diff --git a/test/srt/models/test_lora_backend.py b/test/srt/models/test_lora_backend.py new file mode 100644 index 00000000000..6d61633004c --- /dev/null +++ b/test/srt/models/test_lora_backend.py @@ -0,0 +1,183 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import multiprocessing as mp +import unittest + +import torch + +from sglang.test.runners import HFRunner, SRTRunner +from sglang.test.test_utils import calculate_rouge_l + +LORA_SETS = [ + {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]}, + # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]} +] +TORCH_DTYPES = [torch.float16] + +PROMPTS = [ + "AI is a field of computer science focused on", + """ + ### Instruction: + Tell me about llamas and alpacas + ### Response: + Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. + ### Question 2: + What do you know about llamas? + ### Answer: + """, +] + +BACKENDS = ["triton", "flashinfer"] + +prefill_tolerance: float = 5e-2 +decode_tolerance: float = 5e-2 +rouge_l_tolerance: float = 1 + + +class TestLoRABackend(unittest.TestCase): + + def run_backend( + self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens, backend + ): + print(f"=================== testing {backend} backend =======================") + base_path = lora_set["base"] + all_lora_paths = lora_set["loras"] + batch_lora_paths = [] + i = 0 + for _ in range(len(prompts)): + batch_lora_paths.append(all_lora_paths[i]) + i = (i + 1) % len(all_lora_paths) + print(f"batch lora paths={batch_lora_paths}") + with SRTRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + tp_size=tp_size, + lora_paths=all_lora_paths, + max_loras_per_batch=3, + lora_backend=backend, + disable_cuda_graph=True, + disable_radix_cache=True, + ) as srt_runner: + srt_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with HFRunner( + base_path, torch_dtype=torch_dtype, model_type="generation" + ) as hf_runner: + hf_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + model_type="generation", + ) as srt_runner: + srt_no_lora_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + with HFRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + ) as hf_runner: + hf_no_lora_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + for i in range(len(prompts)): + print(f"Prompt {i} with lora path {batch_lora_paths[i]}:") + + # compare input logprobs + hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) + hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i]) + srt_no_lora_logprobs = torch.Tensor( + srt_no_lora_outputs.top_input_logprobs[i] + ) + print( + "max input diff between hf_lora and srt_lora", + torch.max(abs(hf_logprobs - srt_logprobs)), + ) + print( + "max input diff between srt_base and srt_lora", + torch.max(abs(srt_no_lora_logprobs - srt_logprobs)), + ) + print( + "max input diff between srt_base and hf_base", + torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)), + ) + print( + "max input diff between hf_lora and hf_base", + torch.max(abs(hf_logprobs - hf_no_lora_logprobs)), + ) + if hf_logprobs.shape[0] <= 100: + assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), ( + f"prefill logprobs are not all close with model_path={base_path}," + f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}" + f"prefill_tolerance={prefill_tolerance}." + f"{hf_logprobs=}, {srt_logprobs=}" + ) + + # compare output logprobs + hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i]) + print( + "max output diff between hf_lora and srt_lora", + torch.max(abs(hf_logprobs - srt_logprobs)), + "\n", + ) + if hf_logprobs.shape[0] <= 100: + assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), ( + f"decode logprobs are not all close with model_path={base_path}," + f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}" + f"decode_tolerance={decode_tolerance}." + f"{hf_logprobs=}, {srt_logprobs=}" + ) + + # compare output strings + srt_output_str = srt_outputs.output_strs[i].strip(" ") + hf_output_str = hf_outputs.output_strs[i] + print(f"srt_output_str={srt_output_str}") + print(f"hf_output_str={hf_output_str}") + rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str]) + print(f"{rouge_l_scores=}") + assert ( + rouge_l_scores[0] >= rouge_l_tolerance + ), f"ROUGE-L scores of prompt {i} outputs are greater than rouge_l_tolerance={rouge_l_tolerance}" + + def test_all(self): + for lora_set in LORA_SETS: + print(f"Testing lora set {lora_set}: ") + for torch_dtype in TORCH_DTYPES: + tp_size = 1 + max_new_tokens = 32 + for backend in BACKENDS: + self.run_backend( + PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens, backend + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 603bab957bd..039fde96a72 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -8,6 +8,7 @@ "models/test_embedding_models.py", "models/test_generation_models.py", "models/test_lora.py", + "models/test_lora_backend.py", "models/test_qwen_models.py", "models/test_reward_models.py", "sampling/penaltylib", @@ -51,7 +52,6 @@ "test_vision_llm.py", "test_vision_openai_server.py", "test_w8a8_quantization.py", - "test_fp8_kvcache.py", "test_fp8_kernel.py", ], "nightly": [ diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index 5245905f79b..d9d77a9ae24 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -58,8 +58,7 @@ def run_decode( "logprob_start_len": 0, }, ) - print(json.dumps(response.json())) - print("=" * 100) + assert response.status_code == 200, "Request failed: " + response.text def test_default_values(self): self.run_decode() @@ -112,4 +111,4 @@ def test_repetition_penalty(self): if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=3) diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py index 80aeab257c3..6534a4a60d0 100644 --- a/test/srt/test_fused_moe.py +++ b/test/srt/test_fused_moe.py @@ -1,6 +1,8 @@ import unittest import torch +import torch.nn.functional as F +from tqdm import tqdm from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm from sglang.srt.layers.activation import SiluAndMul @@ -11,6 +13,37 @@ class TestFusedMOE(unittest.TestCase): NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] + @staticmethod + def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01): + """Create a random CUDA tensor + + Args: + shape: Tensor shape + dtype: Data type + mean: Mean value + std: Standard deviation + + Returns: + torch.Tensor: Randomly initialized CUDA tensor + """ + return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std) + + def get_tolerance(self, dtype): + """Get tolerance values for different data types + + Args: + dtype: Data type + + Returns: + tuple: (relative tolerance, absolute tolerance) + """ + if dtype == torch.float32: + return 1e-3, 1e-5 + elif dtype in [torch.float16, torch.bfloat16]: + return 1e-1, 1e-2 + else: + return 1e-2, 1e-2 # Default values for other types + def torch_naive_moe(self, a, w1, w2, score, topk): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) @@ -30,23 +63,25 @@ def torch_naive_moe(self, a, w1, w2, score, topk): ).sum(dim=1) def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): + rtol, atol = self.get_tolerance(dtype) + if use_fp8_w8a8: # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 capability = torch.cuda.get_device_capability() if not (capability[0] >= 9 or capability == (8, 9)): return - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + a = self.create_random_cuda_tensor((m, k), dtype) + w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_cuda_tensor((e, k, n), dtype) w1 = w1.to(torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fn) - score = torch.randn((m, e), device="cuda", dtype=dtype) + score = self.create_random_cuda_tensor((m, e), dtype) - w1_scale = torch.randn(e, dtype=torch.float32, device="cuda") - w2_scale = torch.randn(e, dtype=torch.float32, device="cuda") - a1_scale = torch.randn(1, dtype=torch.float32, device="cuda") - a2_scale = torch.randn(1, dtype=torch.float32, device="cuda") + w1_scale = self.create_random_cuda_tensor(e, torch.float32) + w2_scale = self.create_random_cuda_tensor(e, torch.float32) + a1_scale = self.create_random_cuda_tensor(1, torch.float32) + a2_scale = self.create_random_cuda_tensor(1, torch.float32) sglang_output = fused_moe( a, @@ -76,17 +111,19 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): a2_scale=a2_scale, ) - torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0) + torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol) else: - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) + a = self.create_random_cuda_tensor((m, k), dtype) + w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_cuda_tensor((e, k, n), dtype) + score = self.create_random_cuda_tensor((m, e), dtype) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = self.torch_naive_moe(a, w1, w2, score, topk) - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close( + triton_output, torch_output, rtol=rtol, atol=atol + ) def test_various_configurations(self): m_values = [1, 33, 64, 222, 1024 * 128] @@ -95,31 +132,45 @@ def test_various_configurations(self): dtypes = [torch.float16, torch.bfloat16] fp8_modes = [False, True] - for m in m_values: - for n in n_values: - for k in k_values: - for e in self.NUM_EXPERTS: - for topk in self.TOP_KS: - for dtype in dtypes: - for use_fp8_w8a8 in fp8_modes: - with self.subTest( - m=m, - n=n, - k=k, - e=e, - topk=topk, - dtype=dtype, - fp8=use_fp8_w8a8, - ): - self._test_case( - m, - n, - k, - e, - topk, - dtype, - use_fp8_w8a8=use_fp8_w8a8, - ) + # Calculate total number of tests + total_tests = ( + len(m_values) + * len(n_values) + * len(k_values) + * len(self.NUM_EXPERTS) + * len(self.TOP_KS) + * len(dtypes) + * len(fp8_modes) + ) + + # Create progress bar + with tqdm(total=total_tests, desc="Running MoE tests") as pbar: + for m in m_values: + for n in n_values: + for k in k_values: + for e in self.NUM_EXPERTS: + for topk in self.TOP_KS: + for dtype in dtypes: + for use_fp8_w8a8 in fp8_modes: + with self.subTest( + m=m, + n=n, + k=k, + e=e, + topk=topk, + dtype=dtype, + fp8=use_fp8_w8a8, + ): + self._test_case( + m, + n, + k, + e, + topk, + dtype, + use_fp8_w8a8=use_fp8_w8a8, + ) + pbar.update(1) if __name__ == "__main__": diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index 34bc4b44645..6305732509b 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -62,7 +62,12 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--trust-remote-code"], + other_args=[ + "--trust-remote-code", + "--enable-torch-compile", + "--cuda-graph-max-bs", + "2", + ], ) @classmethod diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 2398af9b0a7..3617e17be2a 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -45,16 +45,20 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): max_len_in_batch = torch.max(b_seq_len, 0)[0].item() b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") - req_to_tokens = torch.empty( - (B, max_len_in_batch), dtype=torch.int32, device="cuda" - ) b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) + kv_indices = torch.zeros( + (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda" + ) + for i in range(B): - req_to_tokens[i, : b_seq_len[i]] = torch.arange( - b_start_loc[i], b_start_loc[i] + b_seq_len[i] + kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange( + b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i] ) total_token_num = torch.sum(b_seq_len).item() @@ -90,9 +94,10 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): ) b_seq_len_extend = b_seq_len - b_seq_len_prefix - b_start_loc_extend = torch.zeros_like(b_seq_len) - b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() + qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) + extend_attention_fwd( q_extend, k_extend, @@ -100,11 +105,9 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): o_extend, k_buffer, v_buffer, - req_to_tokens, - b_req_idx, - b_seq_len, - b_seq_len_extend, - b_start_loc_extend, + qo_indptr, + kv_indptr, + kv_indices, max_len_extend, ) @@ -194,10 +197,12 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): # o will have the same shape as q o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") - req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) - b_req_idx = torch.arange(B, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0) + kv_indices = torch.arange(total_tokens, device="cuda") + attn_logits = torch.empty( (B, H_Q, num_kv_splits, D + 1), dtype=torch.float32, @@ -209,9 +214,8 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -250,10 +254,12 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") - req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) - b_req_idx = torch.arange(B, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0) + kv_indices = torch.arange(total_tokens, device="cuda") + attn_logits = torch.empty( (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, @@ -265,9 +271,8 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -284,9 +289,8 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): k_buffer, v_buffer, o_grouped, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits1, num_kv_splits, sm_scale,