From 20453cef6288bdeb5998dd2eee8955f13f88ffea Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Thu, 30 Jan 2025 02:01:23 -0800 Subject: [PATCH 01/52] [test] Lower number of top logprobs to get rid of `-inf` (#3212) --- .../sampling/penaltylib/test_srt_endpoint_with_penalizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index 5245905f79..34565c9ff6 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -36,7 +36,7 @@ def tearDownClass(cls): def run_decode( self, return_logprob=True, - top_logprobs_num=5, + top_logprobs_num=3, return_text=True, n=1, **sampling_params, From c38b5fb4f45ad8dd1c4ad1b7b05170c87c0f3ea1 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 30 Jan 2025 19:32:21 +0800 Subject: [PATCH 02/52] update 3rdparty and rms norm for sgl-kernel (#3213) --- sgl-kernel/3rdparty/cutlass | 2 +- sgl-kernel/3rdparty/flashinfer | 2 +- sgl-kernel/pyproject.toml | 2 +- .../csrc/fused_add_rms_norm_kernel.cu | 113 +----------------- sgl-kernel/version.py | 2 +- 5 files changed, 8 insertions(+), 113 deletions(-) diff --git a/sgl-kernel/3rdparty/cutlass b/sgl-kernel/3rdparty/cutlass index b78588d163..bdd641790a 160000 --- a/sgl-kernel/3rdparty/cutlass +++ b/sgl-kernel/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit b78588d1630aa6643bf021613717bafb705df4ef +Subproject commit bdd641790ad49353b40ada41330552a78d2f8b5a diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer index 4f1f08989c..e5a3befbe3 160000 --- a/sgl-kernel/3rdparty/flashinfer +++ b/sgl-kernel/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 4f1f08989c71f92df181e346548c2ca48ae6daf5 +Subproject commit e5a3befbe3e63025f0158bc96b218a9c5f402ac7 diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index aca6f04505..bb7d694334 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.3" +version = "0.0.3.post1" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu index 4c4ecb966e..f0f3a51744 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu @@ -1,116 +1,11 @@ -// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh -// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu -// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0 - #include -#include -#include -#include -#include +#include #include "utils.h" using namespace flashinfer; -template -__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight, - const uint32_t d, float eps) { - const uint32_t bx = blockIdx.x; - const uint32_t tx = threadIdx.x, ty = threadIdx.y; - constexpr uint32_t warp_size = 32; - const uint32_t num_warps = blockDim.y; - const uint32_t thread_id = tx + ty * warp_size; - const uint32_t num_threads = num_warps * warp_size; - const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); - extern __shared__ float smem[]; - - float sum_sq = 0.f; - - for (uint32_t 
i = 0; i < rounds; i++) { - vec_t input_vec; - input_vec.fill(0.f); - vec_t residual_vec; - residual_vec.fill(0.f); - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; j++) { - float x = float(input_vec[j]); - x += float(residual_vec[j]); - sum_sq += x * x; - residual_vec[j] = (T)x; - } - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } - } - - // first, warp reduce sum -#pragma unroll - for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { - sum_sq += math::shfl_xor_sync(sum_sq, offset); - } - - smem[ty] = sum_sq; - __syncthreads(); - // then, cross warp reduce sum using only the first warp - if (ty == 0) { - sum_sq = (tx < num_warps) ? smem[tx] : 0.f; -#pragma unroll - for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { - sum_sq += math::shfl_xor_sync(sum_sq, offset); - } - smem[0] = sum_sq; - } - __syncthreads(); - - float rms_rcp = math::rsqrt(smem[0] / float(d) + eps); - - for (uint32_t i = 0; i < rounds; i++) { - vec_t input_vec; - vec_t weight_vec; - vec_t residual_vec; - input_vec.fill(0.f); - weight_vec.fill(0.f); - residual_vec.fill(0.f); - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; j++) { - input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]); - } - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } - } -} - -template -cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5, - cudaStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t block_size = std::min(1024, d / vec_size); - const uint32_t num_warps = ceil_div(block_size, 32); - dim3 nblks(batch_size); - dim3 nthrs(32, num_warps); - const uint32_t smem_size = num_warps * sizeof(float); - void* args[] = {&input, &residual, &weight, &d, &eps}; - - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = FusedAddRMSNormKernel; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - - return cudaSuccess; -} - void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) { CHECK_INPUT(input); CHECK_INPUT(residual); @@ -130,9 +25,9 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream(); // support float16, bfloat16 and float32 DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { - cudaError_t status = - FusedAddRMSNorm(static_cast(input.data_ptr()), static_cast(residual.data_ptr()), - static_cast(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); + cudaError_t status = norm::FusedAddRMSNorm( + static_cast(input.data_ptr()), static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), batch_size, hidden_size, 
eps, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); return true; diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py index 27fdca497c..647733203b 100644 --- a/sgl-kernel/version.py +++ b/sgl-kernel/version.py @@ -1 +1 @@ -__version__ = "0.0.3" +__version__ = "0.0.3.post1" From 468d23cff971b3174c37938f74a007646f9cfb78 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 30 Jan 2025 19:47:50 +0800 Subject: [PATCH 03/52] update setup for sgl-kernel (#3214) --- sgl-kernel/setup.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index f887f5c19f..90c3cbc1d3 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -1,5 +1,6 @@ import multiprocessing import os +import sys from pathlib import Path import torch @@ -9,14 +10,8 @@ root = Path(__file__).parent.resolve() -def _update_wheel_platform_tag(): - wheel_dir = Path("dist") - if wheel_dir.exists() and wheel_dir.is_dir(): - old_wheel = next(wheel_dir.glob("*.whl")) - new_wheel = wheel_dir / old_wheel.name.replace( - "linux_x86_64", "manylinux2014_x86_64" - ) - old_wheel.rename(new_wheel) +if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: + sys.argv.extend(["--plat-name", "manylinux2014_x86_64"]) def _get_cuda_version(): @@ -162,5 +157,3 @@ def _get_version(): }, options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) - -_update_wheel_platform_tag() From 222ce6f1da31b6bfe168513ff85b2d5cad34fb85 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 30 Jan 2025 23:04:41 +0800 Subject: [PATCH 04/52] add tensorrt_llm common and cutlass_extensions as 3rdparty (#3216) Co-authored-by: BBuf <35585791+BBuf@users.noreply.github.com> --- .clang-format-ignore | 1 + .../tensorrt_llm/common/CMakeLists.txt | 22 + .../3rdparty/tensorrt_llm/common/assert.cpp | 34 + .../tensorrt_llm/common/cublasMMWrapper.cpp | 360 +++++++ .../tensorrt_llm/common/cublasMMWrapper.h | 148 +++ .../tensorrt_llm/common/cublasVersionCheck.h | 35 + .../tensorrt_llm/common/cudaBf16Fallbacks.cuh | 313 ++++++ .../tensorrt_llm/common/cudaDriverWrapper.cpp | 187 ++++ .../tensorrt_llm/common/cudaDriverWrapper.h | 138 +++ .../tensorrt_llm/common/cudaFp8Utils.cu | 436 +++++++++ .../tensorrt_llm/common/cudaProfilerUtils.cpp | 84 ++ .../tensorrt_llm/common/cudaTypeUtils.cuh | 752 +++++++++++++++ .../common/customAllReduceUtils.h | 36 + .../3rdparty/tensorrt_llm/common/envUtils.cpp | 214 +++++ .../3rdparty/tensorrt_llm/common/envUtils.h | 60 ++ .../3rdparty/tensorrt_llm/common/logger.cpp | 70 ++ .../3rdparty/tensorrt_llm/common/mathUtils.h | 37 + .../tensorrt_llm/common/memoryUtils.cu | 906 ++++++++++++++++++ .../tensorrt_llm/common/memoryUtils.h | 292 ++++++ .../3rdparty/tensorrt_llm/common/mpiUtils.cpp | 588 ++++++++++++ .../3rdparty/tensorrt_llm/common/nvtxUtils.h | 46 + .../3rdparty/tensorrt_llm/common/opUtils.cpp | 323 +++++++ .../3rdparty/tensorrt_llm/common/opUtils.h | 215 +++++ .../tensorrt_llm/common/quantTypeUtils.cuh | 55 ++ .../tensorrt_llm/common/reduceKernelUtils.cuh | 399 ++++++++ .../3rdparty/tensorrt_llm/common/stlUtils.h | 123 +++ .../tensorrt_llm/common/stringUtils.cpp | 76 ++ .../tensorrt_llm/common/timestampUtils.cpp | 42 + .../tensorrt_llm/common/timestampUtils.h | 25 + .../tensorrt_llm/common/tllmException.cpp | 105 ++ .../3rdparty/tensorrt_llm/common/workspace.h | 87 ++ .../arch/copy_red_global.hpp | 352 +++++++ .../include/cutlass_extensions/arch/mma.h | 120 +++ 
.../cutlass_extensions/compute_occupancy.h | 88 ++ .../collective/epilogue_moe_finalize.hpp | 550 +++++++++++ .../epilogue/thread/fused_activations.h | 105 ++ .../epilogue_per_row_per_col_scale.h | 352 +++++++ .../threadblock/epilogue_tensor_op_int32.h | 282 ++++++ .../cutlass_extensions/epilogue_helpers.h | 141 +++ .../builders/sm90_gmma_builder_gated.inl | 221 +++++ .../collective/collective_builder_gated.hpp | 58 ++ .../gemm/collective/collective_mma_gated.hpp | 59 ++ ..._mma_gated_tma_gmma_ss_warpspecialized.hpp | 642 +++++++++++++ ..._gated_tma_gmma_ss_warpspecialized_fp8.hpp | 665 +++++++++++++ .../gemm/device/gemm_universal_base_compat.h | 438 +++++++++ .../gemm/device/splitk_gemm_grouped.h | 542 +++++++++++ .../gemm/kernel/default_fpA_intB_traits.h | 162 ++++ .../gemm/kernel/default_int8_traits.h | 57 ++ .../gemm/kernel/default_splitk_gemm_grouped.h | 207 ++++ .../gemm/kernel/fpA_intB_gemm.h | 566 +++++++++++ .../gemm/kernel/fused_moe_kernel.cuh | 218 +++++ .../gemm/kernel/fused_moe_kernel_routine.cuh | 799 +++++++++++++++ .../gemm/kernel/fused_moe_kernel_traits.cuh | 215 +++++ .../gemm/kernel/gemm_moe_problem_visitor.h | 73 ++ .../gemm/kernel/gemm_universal_gated.hpp | 70 ++ .../gemm/kernel/gemm_with_epilogue_visitor.h | 585 +++++++++++ .../gemm/kernel/mixed_gemm_B_layout.h | 143 +++ .../gemm/kernel/moe_cute_util.cuh | 185 ++++ .../gemm/kernel/moe_cutlass_kernel.h | 553 +++++++++++ .../gemm/kernel/moe_problem_visitor.h | 344 +++++++ ..._gated_tma_warpspecialized_cooperative.hpp | 646 +++++++++++++ ...emm_gated_tma_warpspecialized_pingpong.hpp | 621 ++++++++++++ .../gemm/kernel/splitk_gemm_grouped.h | 494 ++++++++++ .../gemm/threadblock/default_dq_mma.h | 125 +++ .../threadblock/default_dq_mma_multistage.h | 302 ++++++ .../threadblock/default_dq_mma_pipelined.h | 284 ++++++ .../gemm/threadblock/default_mma.h | 351 +++++++ .../gemm/threadblock/default_mma_bf16.h | 353 +++++++ .../gemm/threadblock/dq_mma_base.h | 257 +++++ .../gemm/threadblock/dq_mma_multistage.h | 110 +++ .../dq_mma_multistage_finegrained.h | 708 ++++++++++++++ .../threadblock/dq_mma_multistage_percol.h | 647 +++++++++++++ .../gemm/threadblock/dq_mma_pipelined.h | 106 ++ .../dq_mma_pipelined_finegrained.h | 486 ++++++++++ .../threadblock/dq_mma_pipelined_percol.h | 399 ++++++++ .../gemm/warp/default_mma_tensor_op.h | 107 +++ .../warp/mma_tensorop_compute_B_with_f16.h | 306 ++++++ .../gemm/warp/mma_tensorop_dequantizer.h | 463 +++++++++ .../include/cutlass_extensions/gemm_configs.h | 224 +++++ .../interleaved_numeric_conversion.h | 447 +++++++++ .../tile_interleaved_layout.h | 66 ++ .../fine_grained_scale_zero_iterator.h | 250 +++++ .../cutlass_extensions/util/gather_tensor.hpp | 181 ++++ .../cutlass_extensions/weight_only_quant_op.h | 58 ++ sgl-kernel/THIRDPARTYNOTICES.txt | 205 ++++ sgl-kernel/setup.py | 4 + 86 files changed, 23201 insertions(+) create mode 100644 .clang-format-ignore create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt create mode 100755 sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h create mode 100644 
sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp create mode 100644 
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h create mode 100644 
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h diff --git a/.clang-format-ignore b/.clang-format-ignore new file mode 100644 index 0000000000..15c76cc457 --- /dev/null +++ b/.clang-format-ignore @@ -0,0 +1 @@ +sgl-kernel/3rdparty/tensorrt_llm/* diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt b/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt new file mode 100644 index 0000000000..e479b298db --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt @@ -0,0 +1,22 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & +# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+# +file(GLOB SRCS *.cpp) +file(GLOB CU_SRCS *.cu) + +add_library(common_src OBJECT ${SRCS} ${CU_SRCS}) +set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp new file mode 100755 index 0000000000..eaaf662447 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/assert.h" + +namespace +{ + +bool initCheckDebug() +{ + auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE"; + auto const debugEnabled = std::getenv(kDebugEnabled); + return debugEnabled && debugEnabled[0] == '1'; +} +} // namespace + +bool DebugConfig::isCheckDebugEnabled() +{ + static bool const debugEnabled = initCheckDebug(); + return debugEnabled; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp new file mode 100644 index 0000000000..351257f4d2 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/cublasMMWrapper.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cublasVersionCheck.h" +#include + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! 
+#endif + +namespace tensorrt_llm +{ +namespace common +{ + +CublasMMWrapper::CublasMMWrapper(std::shared_ptr cublasHandle, + std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace) + : mCublasHandle(cublasHandle) + , mCublasLtHandle(cublasltHandle) + , mStream(stream) + , mCublasWorkspace(workspace) +{ +} + +CublasMMWrapper::~CublasMMWrapper() {} + +CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper) + : mCublasHandle(wrapper.mCublasHandle) + , mCublasLtHandle(wrapper.mCublasLtHandle) + , mStream(wrapper.mStream) +{ +} + +void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc) +{ + // -------------------------------------- + // Create descriptors for the original matrices + check_cuda_error( + cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda)); + check_cuda_error( + cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb)); + check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc)); + check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType)); + check_cuda_error(cublasLtMatmulDescSetAttribute( + mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t))); + check_cuda_error(cublasLtMatmulDescSetAttribute( + mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t))); + check_cuda_error( + cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t))); +} + +void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b) +{ + check_cuda_error( + cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*))); + check_cuda_error( + cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*))); +} + +void CublasMMWrapper::destroyDescriptors() +{ + check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc)); + check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc)); + check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc)); + check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc)); + mOperationDesc = NULL; + mADesc = NULL; + mBDesc = NULL; + mCDesc = NULL; +} + +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc) +{ + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f); +} + +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, + std::optional const& heuristic) +{ + if (heuristic) + { + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /* hasAlgo */ (*heuristic).algo, + (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE, + /* usingCublasLt */ true); + } + else + { + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false, + /* usingCublasLt */ true); + } +} + +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float 
f_beta, + std::optional const& heuristic) +{ + if (heuristic) + { + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, /* hasAlgo */ (*heuristic).algo, + (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE, + /* usingCublasLt */ true); + } + else + { + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false, + /* usingCublasLt */ true); + } +} + +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta) +{ + bool usingCublasLt = mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3; + + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false, + /* usingCublasLt */ usingCublasLt); +} + +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, + cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt) +{ + half h_alpha = (half) (f_alpha); + half h_beta = (half) (f_beta); + + // TODO: default cublas libs + usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3); + bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F; + int batch_count = 1; + // fp32 use cublas as default + // fp16 use cublasLt as default + void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); + void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; + + if (usingCublasLt) + { + if (hasAlgo) + { + hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo); + } + + check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C, + mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream)); + + sync_check_cuda_error(); + } + else + { + check_cuda_error(cublasSetStream(getCublasHandle(), mStream)); + check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize)); + // Go with default heuristic to choose tactic as cuBLAS does not allow to choose tactics in Ampere+ + cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT; + check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb, + beta, C, mCType, ldc, mComputeType, static_cast(cublasAlgo))); + sync_check_cuda_error(); + } +} + +void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, + const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha, + float const f_beta) +{ + half h_alpha = (half) f_alpha; + half h_beta = (half) f_beta; + + int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; + void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); + void const* beta = isFp16ComputeType ? 
reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + + check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, + strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType, + mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); +} + +void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, + void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, + cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType) +{ + half h_alpha = (half) f_alpha; + half h_beta = (half) f_beta; + + bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; + void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); + void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + + check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda, + strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType, + mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); +} + +void CublasMMWrapper::setWorkspace(void* workspace) +{ + mCublasWorkspace = workspace; +} + +void CublasMMWrapper::setFP32GemmConfig() +{ + setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F); +} + +void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType) +{ + setGemmConfig(CUDA_R_16F, CUDA_R_16F, outputType, CUDA_R_32F); +} + +#ifdef ENABLE_BF16 +void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType) +{ + setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, outputType, CUDA_R_32F); +} +#endif + +#ifdef ENABLE_FP8 +void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType) +{ + setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F); +} +#endif + +void CublasMMWrapper::setGemmConfig( + cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType) +{ + mAType = aType; + mBType = bType; + mCType = cType; + bool isFp16ComputeType = computeType == CUDA_R_16F; + if (isFp16ComputeType) + { + mComputeType = CUBLAS_COMPUTE_16F; + mScaleType = CUDA_R_16F; + } + else + { + mComputeType = CUBLAS_COMPUTE_32F; + mScaleType = CUDA_R_32F; + } +} + +CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type) +{ + if (data_type == CUDA_R_16F) + { + return HALF_DATATYPE; + } + else if (data_type == CUDA_R_32F) + { + return FLOAT_DATATYPE; + } + else if (data_type == CUDA_R_8I) + { + return INT8_DATATYPE; + } +#ifdef ENABLE_BF16 + else if (data_type == CUDA_R_16BF) + { + return BFLOAT16_DATATYPE; + } +#endif + return FLOAT_DATATYPE; +} + +void CublasMMWrapper::setStream(cudaStream_t stream) +{ + mStream = stream; +} + +bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo) +{ + TLLM_CHECK_WITH_INFO( + descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function"); + + int workspaceSize = mCublasWorkspace == NULL ? 
0 : CUBLAS_WORKSPACE_SIZE; + + cublasLtMatmulHeuristicResult_t heurResult; + cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( + getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult); + + if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS + || heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE) + { + return false; + } + + sync_check_cuda_error(); + + return true; +} + +std::vector CublasMMWrapper::getTactics(cublasOperation_t transa, + cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc) +{ + TLLM_CHECK_WITH_INFO( + descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function"); + + auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc); + + sync_check_cuda_error(); + + return heuristics; +} + +std::vector CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc) +{ +#if TLLM_CUBLAS_VER_LE(11, 4, 2) + TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2."); + return {}; +#else + std::vector heuristics(200); + cublasLtMatmulPreference_t preference; + check_cuda_error(cublasLtMatmulPreferenceCreate(&preference)); + check_cuda_error(cublasLtMatmulPreferenceInit(preference)); + uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE; + check_cuda_error(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); + // Restrict reduction algorithms for numerical stability and better determinism + uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK; + check_cuda_error(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask))); +#if TLLM_CUBLAS_VER_LT(12, 0, 0) + uint32_t pointer_mode_mask = 0; + check_cuda_error(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask))); +#endif + + int return_count = 0; + check_cuda_error(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, + heuristics.size(), heuristics.data(), &return_count)); + heuristics.resize(return_count); + + return heuristics; +#endif +} + +} // namespace common + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h new file mode 100644 index 0000000000..79b7c92a47 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/common/cudaUtils.h" +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +class CublasMMWrapper +{ +protected: + std::shared_ptr mCublasHandle; + std::shared_ptr mCublasLtHandle; + + cudaDataType_t mAType{}; + cudaDataType_t mBType{}; + cudaDataType_t mCType{}; + cublasComputeType_t mComputeType{}; + cudaDataType_t mScaleType{}; + + cublasLtMatmulDesc_t mOperationDesc{NULL}; + cublasLtMatrixLayout_t mADesc{NULL}; + cublasLtMatrixLayout_t mBDesc{NULL}; + cublasLtMatrixLayout_t mCDesc{NULL}; + + cudaStream_t mStream; + + void* mCublasWorkspace = nullptr; + +private: + bool descriptorsCreated() const + { + return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL; + } + +public: + CublasMMWrapper(std::shared_ptr cublasHandle, std::shared_ptr cublasLtHandle, + cudaStream_t stream, void* workspace); + + ~CublasMMWrapper(); + + CublasMMWrapper(CublasMMWrapper const& wrapper); + + /********************** GEMMs **********************/ + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, + std::optional const& algo); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, + std::optional const& algo); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, + cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt); + + void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB, + void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f, + float const f_beta = 0.0f); + + void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B, + cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType, + int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType); + + /********************** Tactic selection helpers **********************/ + bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo); + + std::vector getTactics(cublasOperation_t transa, cublasOperation_t transb, + int const m, int const n, int const k, int const lda, int const ldb, int const ldc); + + std::vector getTactics(cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t 
computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc); + + using MatrixLayout = std::tuple; + using cache_idx_t = std::tuple>; + + MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc); + + /********************** Utils **********************/ + void setWorkspace(void* workspace); + + void setFP32GemmConfig(); + void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F); +#ifdef ENABLE_BF16 + void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF); +#endif +#ifdef ENABLE_FP8 + void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F); +#endif + + void setStream(cudaStream_t stream); + + void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType); + + CublasDataType getCublasDataType(cudaDataType_t data_type); + + void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + int const lda, int const ldb, int const ldc, int8_t fastAcc = 0); + void setScaleDescriptors(void* scale_a, void* scale_b); + void destroyDescriptors(); + + cublasHandle_t getCublasHandle() + { + return *(this->mCublasHandle); + } + + cublasLtHandle_t getCublasLtHandle() const + { + return *(this->mCublasLtHandle); + } +}; + +} // namespace common + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h new file mode 100644 index 0000000000..1ee72c6356 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// We don't want to include cublas_api.h. It contains the CUBLAS_VER_* macro +// definition which is not sufficient to determine if we include cublas.h, +// cublas_v2.h or cublasLt.h. 
+ +#define TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) (MAJOR * 10000 + MINOR * 100 + PATCH) +#define TLLM_CUBLAS_VER_LE(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + <= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) +#define TLLM_CUBLAS_VER_LT(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + < TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) +#define TLLM_CUBLAS_VER_GE(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + >= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) +#define TLLM_CUBLAS_VER_GT(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + > TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh new file mode 100644 index 0000000000..0519251e6f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh @@ -0,0 +1,313 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaBf16Wrapper.h" +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +#ifdef ENABLE_BF16 +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; +#else + return __bfloat1622float2(val); +#endif +} + +inline __device__ int16_t bf1622int16(__nv_bfloat162 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = max(min(__low2float(val), 127.f), -128.f); + f_val.y = max(min(__high2float(val), 127.f), -128.f); + + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = static_cast(static_cast(f_val.x)); + int8[1] = static_cast(static_cast(f_val.y)); + return int16; +#else + val = __hmin2(val, make_bfloat162(127., 127.)); + val = __hmax2(val, make_bfloat162(-128., -128.)); + + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = static_cast(static_cast(val.x)); + int8[1] = static_cast(static_cast(val.y)); + return int16; +#endif +} + +inline __device__ __nv_bfloat162 float22bf162(const float2 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __floats2bfloat162_rn(val.x, val.y); +#else + return __float22bfloat162_rn(val); +#endif +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + __nv_bfloat162 val2; + val2.x = val; + val2.y = val; + return val2; +#else + return __bfloat162bfloat162(val); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + 
fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#else + return __hadd2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y)); +#else + return __hadd(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); +#else + return __hsub2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y)); +#else + return __hsub(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); +#else + return __hmul2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y)); +#else + return __hmul(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh, fzl, fzh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + fzl = __low2float(z); + fzh = __high2float(z); + return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); +#else + return __hfma2(x, y, z); +#endif +} + +inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); +#else + return __hfma(x, y, z); +#endif +} + +inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x); + ; + return __floats2bfloat162_rn(expf(fxl), expf(fxh)); +#else + return h2exp(x); +#endif +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) +#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020) + +inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ + __nv_bfloat162 t; + t.x = x; + t.y = y; + return t; +} +#endif +#endif + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + 
__bfloat162float(d)); +#else + return (__nv_bfloat16) ((float) a + (float) b + (float) c + (float) d); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + fdl = __low2float(d); + fdh = __high2float(d); + return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); +#else + return a * b * c + d; +#endif +} + +#endif // ENABLE_BF16 + +} // namespace common +} // namespace tensorrt_llm + +// Operator definitions intentionally in global namespace +namespace +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) +#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020) + +inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ + return tensorrt_llm::common::bf16hmul2(x, y); +}; + +inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ + return tensorrt_llm::common::bf16hadd2(x, y); +}; +#endif +#endif +} // namespace diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp new file mode 100644 index 0000000000..7eca46a1ca --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#define CUDA_LIB_NAME "cuda" + +#if defined(_WIN32) +#include +#define dllOpen(name) LoadLibrary("nv" name ".dll") +#define dllClose(handle) FreeLibrary(static_cast(handle)) +#define dllGetSym(handle, name) static_cast(GetProcAddress(static_cast(handle), name)) +#else // For non-Windows platforms +#include +#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY) +#define dllClose(handle) dlclose(handle) +#define dllGetSym(handle, name) dlsym(handle, name) +#endif // defined(_WIN32) + +#include "cudaDriverWrapper.h" +#include "tensorrt_llm/common/assert.h" +#include +#include + +namespace tensorrt_llm::common +{ + +std::shared_ptr CUDADriverWrapper::getInstance() +{ + static std::mutex mutex; + static std::weak_ptr instance; + std::shared_ptr result = instance.lock(); + if (result) + { + return result; + } + + std::lock_guard lock(mutex); + result = instance.lock(); + if (!result) + { + result = std::shared_ptr(new CUDADriverWrapper()); + instance = result; + } + return result; +} + +CUDADriverWrapper::CUDADriverWrapper() + : handle(dllOpen(CUDA_LIB_NAME)) +{ + + TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); + + auto load_sym = [](void* handle, char const* name) + { + void* ret = dllGetSym(handle, name); + return ret; + }; + + *reinterpret_cast(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName"); + *reinterpret_cast(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage"); + *reinterpret_cast(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute"); + *reinterpret_cast(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete"); + *reinterpret_cast(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload"); + *reinterpret_cast(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy"); + *reinterpret_cast(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData"); + *reinterpret_cast(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2"); + *reinterpret_cast(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction"); + *reinterpret_cast(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2"); + *reinterpret_cast(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2"); + *reinterpret_cast(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2"); + *reinterpret_cast(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel"); + *reinterpret_cast(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel"); + *reinterpret_cast(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled"); + *reinterpret_cast(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2"); +} + +CUDADriverWrapper::~CUDADriverWrapper() +{ + dllClose(handle); +} + +CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const +{ + return (*_cuGetErrorName)(error, pStr); +} + +CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const +{ + return (*_cuGetErrorMessage)(error, pStr); +} + +CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const +{ + return (*_cuFuncSetAttribute)(hfunc, attrib, value); +} + +CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const +{ + return (*_cuLinkComplete)(state, cubinOut, sizeOut); +} + +CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const +{ + return (*_cuModuleUnload)(hmod); +} + +CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const +{ + return (*_cuLinkDestroy)(state); +} + +CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* 
module, void const* image) const +{ + return (*_cuModuleLoadData)(module, image); +} + +CUresult CUDADriverWrapper::cuLinkCreate( + unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const +{ + return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); +} + +CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const +{ + return (*_cuModuleGetFunction)(hfunc, hmod, name); +} + +CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const +{ + return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); +} + +CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, + unsigned int numOptions, CUjit_option* options, void** optionValues) const +{ + return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); +} + +CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, + char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const +{ + return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); +} + +CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const +{ + return (*_cuLaunchCooperativeKernel)( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams); +} + +CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const +{ + return (*_cuLaunchKernel)( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra); +} + +CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, + cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const +{ + return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides, + boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill); +} + +CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const +{ + return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h new file mode 100644 index 0000000000..c4d470a85f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CUDA_DRIVER_WRAPPER_H +#define CUDA_DRIVER_WRAPPER_H + +#include "tensorrt_llm/common/assert.h" +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +class CUDADriverWrapper +{ +public: + static std::shared_ptr getInstance(); + + ~CUDADriverWrapper(); + CUDADriverWrapper(CUDADriverWrapper const&) = delete; + CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete; + CUDADriverWrapper(CUDADriverWrapper&&) = delete; + CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete; + + CUresult cuGetErrorName(CUresult error, char const** pStr) const; + + CUresult cuGetErrorMessage(CUresult error, char const** pStr) const; + + CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; + + CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const; + + CUresult cuModuleUnload(CUmodule hmod) const; + + CUresult cuLinkDestroy(CUlinkState state) const; + + CUresult cuModuleLoadData(CUmodule* module, void const* image) const; + + CUresult cuLinkCreate( + unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; + + CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const; + + CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const; + + CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, + CUjit_option* options, void** optionValues) const; + + CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name, + unsigned int numOptions, CUjit_option* options, void** optionValues) const; + + CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const; + + CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void** kernelParams, void** extra) const; + + CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, + void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim, + cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const; + + CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const; + +private: + void* handle; + CUDADriverWrapper(); + + CUresult (*_cuGetErrorName)(CUresult, char const**); + CUresult (*_cuGetErrorMessage)(CUresult, char const**); + CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); + CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); + CUresult (*_cuModuleUnload)(CUmodule); + CUresult 
(*_cuLinkDestroy)(CUlinkState); + CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); + CUresult (*_cuModuleLoadData)(CUmodule*, void const*); + CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*); + CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*); + CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**); + CUresult (*_cuLinkAddData)( + CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**); + CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void**); + CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void** kernelParams, void** extra); + CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, + cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); + CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount); +}; + +template +void checkDriver( + T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line) +{ + if (result) + { + char const* errorName = nullptr; + char const* errorMsg = nullptr; + wrap.cuGetErrorName(result, &errorName); + wrap.cuGetErrorMessage(result, &errorMsg); + throw TllmException( + file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg)); + } +} + +} // namespace tensorrt_llm::common + +/* + * Macros compliant with TensorRT coding conventions + */ +#define TLLM_CU_CHECK(stat) \ + do \ + { \ + tensorrt_llm::common::checkDriver( \ + (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \ + } while (0) + +#endif // CUDA_DRIVER_WRAPPER_H diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu new file mode 100644 index 0000000000..8e140609f2 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu @@ -0,0 +1,436 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/reduceKernelUtils.cuh" +#include +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ +#ifdef ENABLE_FP8 + +constexpr int CTA_SIZE = 256; + +template +__inline__ __device__ float scale(float a, float b) +{ + return QUANTIZE ? a / b : a * b; +} + +template +__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda) +{ + for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x) + { + + if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL) + { + output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[i % lda]))); + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) + { + output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[i / lda]))); + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR) + { + output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[0]))); + } + } +} + +template +void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream) +{ + dim3 grid(1024); + dim3 block(CTA_SIZE); + if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + scaleMatrix + <<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TOKEN) + { + scaleMatrix<<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + scaleMatrix<<>>(output, input_scale, input, numel, lda); + } + sync_check_cuda_error(); +} + +template +void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream) +{ + dim3 grid(1024); + dim3 block(CTA_SIZE); + if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + scaleMatrix + <<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TOKEN) + { + scaleMatrix<<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + scaleMatrix + <<>>(output, input_scale, input, numel, lda); + } + sync_check_cuda_error(); +} + +template +__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel) +{ + for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x) + { + T_FAKE tmp = (T_FAKE) (static_cast(src[tid])); + dst[tid] = (T_OUT) (static_cast(tmp)); + } +} + +template +void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream) +{ + fakeQuantize<<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel); + sync_check_cuda_error(); +} + +template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>( + float* dst, float const* src, const int64_t numel, cudaStream_t stream); +template void invokeFakeQuantize( + float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream); +template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>( + half* dst, half const* src, const int64_t numel, cudaStream_t stream); +template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream); + +template void invokeFakeQuantize( + half* dst, float const* src, const int64_t numel, cudaStream_t stream); + +__device__ float atomicMaxExtd(float* address, float 
val) +{ + assert(val >= 0); + unsigned int* address_as_u = reinterpret_cast(address); + unsigned int old = atomicMax(address_as_u, __float_as_uint(val)); + return __uint_as_float(old); +} + +template +inline __device__ T atomicMaxExtdV2(T* address, T val) +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + static_assert(std::is_same_v | std::is_same_v, "T needs to be either half or bfloat16"); + // The address in 64 bits. + uint64_t address_u64 = reinterpret_cast(address); + + // Pack the input value into 32 bits. + union + { + T v[2]; + uint16_t u[2]; + } old, tmp = {}; + + int const loc = (address_u64 & 0x2) >> 1; + tmp.v[loc] = val; + + // 4B aligned pointer. + auto aligned_address = reinterpret_cast(address_u64 & ~0x3ull); + + if constexpr (std::is_same_v) + { + asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};" + : "=h"(old.u[0]), "=h"(old.u[1]) + : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1])); + } + if constexpr (std::is_same_v) + { + asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};" + : "=h"(old.u[0]), "=h"(old.u[1]) + : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1])); + } + + // Return the correct half. + return old.v[loc]; +#endif +} + +__device__ half atomicMaxExtd(half* address, half val) +{ + unsigned short int* address_as_u = reinterpret_cast(address); + unsigned short int old = *address_as_u, assumed; + + while (val > __ushort_as_half(old)) + { + assumed = old; + old = atomicCAS(address_as_u, assumed, __half_as_ushort(val)); + } + + return __ushort_as_half(old); +} + +__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val) +{ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + unsigned short int* address_as_u = reinterpret_cast(address); + unsigned short int old = *address_as_u, assumed; + + while (val > __ushort_as_bfloat16(old)) + { + assumed = old; + old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val)); + } + + return __ushort_as_bfloat16(old); +#else + assert(0); + asm volatile("brkpt;\n" ::); + return __nv_bfloat16(0); +#endif +} + +template +__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n) +{ + constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); + if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL) + { + for (int64_t col = threadIdx.x; col < n; col += blockDim.x) + { + float max = 0.f; + for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n) + { + auto val = fabs(static_cast(weights[i])); + max = max > val ? max : val; + } + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + if constexpr (std::is_same_v) + { + atomicMaxExtd(quant_ptr + col, scale); + } + else + { + auto const address_u64 = reinterpret_cast(quant_ptr + col); + if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0)) + atomicMaxExtd(quant_ptr + col, scale); + else + atomicMaxExtdV2(quant_ptr + col, scale); + } +#else // Vector atomics require __CUDA_ARCH__ >= 900 + atomicMaxExtd(quant_ptr + col, scale); +#endif + } + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) + { + auto const nrows = size / n; + for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) + { + float max = 0.f; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) + { + auto val = fabs(static_cast(weights[row * n + i])); + max = max > val ? 
max : val; + } + max = blockReduceMax(max); + if (threadIdx.x == 0) + { + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + quant_ptr[row] = scale; + } + } + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR) + { + float max = 0.f; + for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x) + { + auto val = fabs(static_cast(weights[i])); + max = max > val ? max : val; + } + max = blockReduceMax(max); + if (threadIdx.x == 0) + { + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + atomicMaxExtd(quant_ptr, scale); + } + } +} + +template +void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream) +{ + if (quantize_mode == QuantizeMode::PER_TOKEN) + { + dim3 block(CTA_SIZE); + dim3 grid(numel / lda); + computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + dim3 block(CTA_SIZE); + dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); + cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + dim3 block(1024); + dim3 grid(1024); + cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); + } + sync_check_cuda_error(); +} + +#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \ + template void invokeComputeFP8QuantizeScale(type_scale * input_scale, type_in const* weights, \ + int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream); + +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half); +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half); +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float); +#ifdef ENABLE_BF16 +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16); +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16); +#endif + +template +__global__ void dynamicQuantizeMatrixPerToken( + T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda) +{ + extern __shared__ __align__(sizeof(float)) char _shmem[]; + T_IN* shmem = reinterpret_cast(_shmem); + constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); + auto const nrows = numel / lda; + for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) + { + float max = 0.f; + for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) + { + auto const in = input[row * lda + i]; + shmem[i] = in; + auto val = fabs(static_cast(in)); + max = max > val ? 
max : val; + } + max = blockAllReduceMax(max); // __syncthreads() called so we can read shmem + auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) + { + // true means we are quantizing + output[row * lda + i] = (T_OUT) scale(static_cast(shmem[i]), static_cast(s)); + } + if (threadIdx.x == 0) + { + quant_ptr[row] = s; + } + } +} + +template +void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel, + const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream) +{ + if (quantize_mode == QuantizeMode::PER_TOKEN) + { + dim3 grid(numel / lda); + bool use_shmem = true; + auto const shmem_size = lda * sizeof(T_IN); + if (shmem_size >= (48 << 10)) + { + cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + use_shmem = ret == cudaSuccess; + } + if (use_shmem) + { + // ensure the threadblock is as large as possible to increase occupancy + dim3 block(std::min((lda + 31) / 32 * 32, static_cast(1024))); + dynamicQuantizeMatrixPerToken<<>>(output, quant_ptr, input, numel, lda); + } + else + { + dim3 block(CTA_SIZE); + computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); + sync_check_cuda_error(); + invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); + } + } + else if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + dim3 block(CTA_SIZE); + dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); + cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); + sync_check_cuda_error(); + invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + dim3 block(1024); + dim3 grid(1024); + cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); + sync_check_cuda_error(); + invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); + } + sync_check_cuda_error(); +} + +#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \ + template void invokeQuantizeMatrix(type_out * output, \ + type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ + cudaStream_t stream); \ + template void invokeDequantizeMatrix(type_out * output, \ + type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ + cudaStream_t stream); \ + template void invokeComputeScalesAndQuantizeMatrix(type_out * output, \ + type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ + cudaStream_t stream); + +#ifdef ENABLE_FP8 +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float); +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half); +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half); +DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3); +DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3); +DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3); +#ifdef ENABLE_BF16 +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16); +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3); +#endif +#endif + +#endif // ENABLE_FP8 +} // namespace common +} // namespace tensorrt_llm diff --git 
a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp new file mode 100644 index 0000000000..5576fe782f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/cudaProfilerUtils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/stringUtils.h" +#include +#include + +namespace +{ + +std::tuple, std::unordered_set> populateIterationIndexesImpl( + std::string const& envVarName) +{ + auto envVarVal = std::getenv(envVarName.c_str()); + auto envVarValStr = std::string{envVarVal != nullptr ? envVarVal : ""}; + auto values = tensorrt_llm::common::str2set(envVarValStr, ','); + std::unordered_set startSet; + std::unordered_set endSet; + for (std::string const& value : values) + { + size_t dashIdx = value.find("-"); + if (dashIdx != std::string::npos) + { + int32_t start = std::stoi(value.substr(0, dashIdx)); + startSet.insert(start); + int32_t end = std::stoi(value.substr(dashIdx + 1)); + endSet.insert(end); + } + else + { + int32_t start_end = std::stoi(value); + startSet.insert(start_end); + endSet.insert(start_end); + } + } + + return std::make_pair(startSet, endSet); +} + +} // namespace + +namespace tensorrt_llm::common +{ + +std::pair, std::unordered_set> populateIterationIndexes( + std::string const& envVarName, std::optional const& legacyEnvVarName) +{ + auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName); + + // If empty, try to use legacy env var name + if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty()) + { + std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value()); + + if (!profileIterIdxs.empty() || !stopIterIdxs.empty()) + { + TLLM_LOG_WARNING( + "Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. " + "Please " + "use %s " + "instead.", + legacyEnvVarName.value().c_str(), envVarName.c_str()); + } + } + + return std::make_pair(profileIterIdxs, stopIterIdxs); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh new file mode 100644 index 0000000000..a0463a3a49 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh @@ -0,0 +1,752 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include +#include +#include +#if ENABLE_BF16 +#include +#endif + +namespace tensorrt_llm +{ +namespace common +{ + +template +inline __device__ T ldg(T const* val) +{ + return __ldg(val); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} + +template <> +inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} +#endif // ENABLE_BF16 + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter +{ + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter +{ + using Type = half; +}; + +template <> +struct TypeConverter +{ + using Type = half2; +}; + +#if ENABLE_BF16 +template <> +struct TypeConverter<__nv_bfloat162> +{ + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> +{ + using Type = __nv_bfloat162; +}; +#endif // ENABLE_BF16 + +// Defined math operations (bfloat16 fallback to fp32 when it is not supported) +template +inline __device__ T hadd2(T a, T b) +{ + return __hadd2(a, b); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hadd2(a, b); +} +#endif // ENABLE_BF16 + +template +inline __device__ T add(T a, T b) +{ + return a + b; +} + +template <> +inline __device__ half2 add(half2 a, half2 b) +{ + return __hadd2(a, b); +} + +template <> +inline __device__ half add(half a, half b) +{ + return __hadd(a, b); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hadd2(a, b); +} + +template <> +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) +{ + return bf16hadd(a, b); +} + +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b) +{ + return bf16hadd(a, __float2bfloat16(b)); +} +#endif // ENABLE_BF16 + +// applies to all 4 values addition +template +inline __device__ T add(T a, T b, T c) +{ + return a + b + c; +} + +#if ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ + return bf16hadd(a, b, c); +} + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hadd2(a, b, c); +} +#endif // ENABLE_BF16 + +// applies to all 4 values addition +template +inline __device__ T add(T a, T b, T c, T d) +{ + return (T) ((float) a + (float) b + (float) c + (float) d); +} + +#if ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) +{ + return bf16hadd(a, b, c, d); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hsub2(T a, T b) +{ + return __hsub2(a, b); +} + +#if ENABLE_BF16 +template <> +inline 
__device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hsub2(a, b); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hmul2(T a, T b) +{ + return __hmul2(a, b); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hmul2(a, b); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hmul2(T a, T b, T c) +{ + return a * b * c; +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hmul2(a, b, c); +} +#endif // ENABLE_BF16 + +template +inline __device__ T mul(T a, T b, T c) +{ + return a * b * c; +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ + return bf16hmul(a, b, c); +} + +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hmul2(a, b, c); +} +#endif // ENABLE_BF16 + +template +inline __device__ T fma(T a, T b, T c, T d) +{ + return a * b * c + d; +} + +#if ENABLE_BF16 +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) +{ + return bf16hfma2(a, b, c, d); +} +#endif // ENABLE_BF16 + +template +inline __device__ T fma(T a, T b, T c) +{ + return a * b + c; +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(a, b, c); +} + +template <> +inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ + return bf16hfma(a, b, c); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hexp2(T a) +{ + return h2exp(a); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a) +{ + return bf16exp2(a); +} +#endif // ENABLE_BF16 + +template +__device__ inline T_OUT cuda_cast(T_IN val) +{ + return val; +} + +template <> +__device__ inline float2 cuda_cast(int2 val) +{ + return make_float2(val.x, val.y); +} + +template <> +__device__ inline float2 cuda_cast(float val) +{ + return make_float2(val, val); +} + +template <> +__device__ inline float2 cuda_cast(half2 val) +{ + return __half22float2(val); +} + +template <> +__device__ inline half2 cuda_cast(float2 val) +{ + return __float22half2_rn(val); +} + +template <> +__device__ inline half2 cuda_cast(float val) +{ + return __float2half2_rn(val); +} + +template <> +__device__ inline half2 cuda_cast(half val) +{ + return __half2half2(val); +} + +template <> +__device__ inline int8_t cuda_cast(half val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + union + { + half fp16; + int16_t int16_in; + }; + + fp16 = val; + asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); + return int8[0]; +} + +template <> +__device__ inline int16_t cuda_cast(half2 val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = cuda_cast(val.x); + int8[1] = cuda_cast(val.y); + return int16; +} + +template <> +__device__ inline int8_t cuda_cast(float val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +template <> +__device__ inline int16_t cuda_cast(float2 val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = cuda_cast(val.x); + int8[1] = cuda_cast(val.y); + return int16; +} + +template <> +__device__ inline half2 cuda_cast(int16_t val) +{ + union + { + int8_t 
int8[2]; + int16_t int16; + }; + + int16 = val; + return make_half2(int8[0], int8[1]); +} + +template <> +__device__ inline float2 cuda_cast(int16_t val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int16 = val; + return make_float2(int8[0], int8[1]); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat16 cuda_cast(int32_t val) +{ + return static_cast(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast(int8_t val) +{ + return static_cast(val); +} + +template <> +__device__ inline int8_t cuda_cast(__nv_bfloat16 val) +{ + return static_cast(val); +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) +{ + return __bfloat162float(val); +} + +template <> +__device__ inline float2 cuda_cast(__nv_bfloat162 val) +{ + return bf1622float2(val); +} + +template <> +__device__ inline half cuda_cast(__nv_bfloat16 val) +{ + return __float2half(__bfloat162float(val)); +} + +template <> +__device__ inline int16_t cuda_cast(__nv_bfloat162 val) +{ + return bf1622int16(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) +{ + return __float2bfloat16(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) +{ + return __float2bfloat16(__half2float(val)); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) +{ + return bf162bf162(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) +{ + return __float2bfloat162_rn(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) +{ + return float22bf162(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int16 = val; + __nv_bfloat162 res; + res.x = cuda_cast<__nv_bfloat16>(int8[0]); + res.y = cuda_cast<__nv_bfloat16>(int8[1]); + return res; +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) +{ + return float22bf162(__half22float2(val)); +} + +#endif // ENABLE BF16 + +template +__device__ inline T cuda_abs(T val) +{ + assert(false); + return {}; +} + +template <> +__device__ inline float cuda_abs(float val) +{ + return fabs(val); +} + +template <> +__device__ inline float2 cuda_abs(float2 val) +{ + return make_float2(fabs(val.x), fabs(val.y)); +} + +template <> +__device__ inline half cuda_abs(half val) +{ + return __habs(val); +} + +template <> +__device__ inline half2 cuda_abs(half2 val) +{ + return __habs2(val); +} + +#ifdef ENABLE_BF16 + +#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) +template <> +__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) +{ + return __habs(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) +{ + return __habs2(val); +} +#endif + +#endif // ENABLE_FP16 + +template +__device__ inline To cuda_sum(Ti val) +{ + return cuda_cast(val); +}; + +template +__device__ inline To cuda_sum(float2 val) +{ + return cuda_cast(val.x + val.y); +}; + +// Unary maximum: compute the max of a vector type +template +__device__ inline To cuda_max(Ti val) +{ + return cuda_cast(val); +}; + +template <> +__device__ inline float cuda_max(float2 val) +{ + return fmaxf(val.x, val.y); +} + +template <> +__device__ inline half cuda_max(half2 val) +{ + return __hmax(val.x, val.y); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat16 
cuda_max(__nv_bfloat162 val) +{ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hmax(val.x, val.y); +#else + assert(0); + asm volatile("brkpt;\n" ::); + return __nv_bfloat16(0); +#endif +} +#endif + +// Binary maximum: compute the max of two values. +template +__device__ inline T cuda_max(T val1, T val2) +{ + return (val1 > val2) ? val1 : val2; +} + +template <> +__device__ inline float2 cuda_max(float2 val1, float2 val2) +{ + float2 out; + out.x = fmaxf(val1.x, val2.x); + out.y = fmaxf(val1.y, val2.y); + return out; +} + +template <> +__device__ inline half2 cuda_max(half2 val1, half2 val2) +{ + return __hmax2(val1, val2); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2) +{ + return __hmax2(val1, val2); +} +#endif // ENABLE_BF16 + +// Binary maximum: compute the min of two values. +template +__device__ inline T cuda_min(T val1, T val2) +{ + return (val1 < val2) ? val1 : val2; +} + +template <> +__device__ inline float2 cuda_min(float2 val1, float2 val2) +{ + float2 out; + out.x = fminf(val1.x, val2.x); + out.y = fminf(val1.y, val2.y); + return out; +} + +template <> +__device__ inline half2 cuda_min(half2 val1, half2 val2) +{ + return __hmin2(val1, val2); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat162 cuda_min(__nv_bfloat162 val1, __nv_bfloat162 val2) +{ + return __hmin2(val1, val2); +} +#endif // ENABLE_BF16 + +// Helper function of clamping the val into the given range. +template +inline __device__ T cuda_clamp(T val, T minVal, T maxVal) +{ + return cuda_min(cuda_max(val, minVal), maxVal); +} + +#ifdef ENABLE_FP8 +template <> +__device__ inline float2 cuda_cast(__nv_fp8x2_e4m3 val) +{ + return bf1622float2(fp8x2_e4m3_to_bfloat2(&val)); +} + +template <> +__device__ inline half2 cuda_cast(__nv_fp8x2_e4m3 val) +{ + return fp8x2_e4m3_to_half2(&val); +} + +template <> +__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val) +{ + return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val))); +} + +template <> +__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, half2>(half2 val) +{ + return __nv_fp8x2_e4m3(cuda_cast(val)); +} + +template <> +__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, __nv_bfloat162>(__nv_bfloat162 val) +{ + return __nv_fp8x2_e4m3(cuda_cast(val)); +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val) +{ + return __nv_fp8_e4m3(val); +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val) +{ + return __nv_fp8_e4m3(val); +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val) +{ + return __nv_fp8_e4m3(val); +} + +template <> +__device__ inline float cuda_cast(__nv_fp8_e4m3 val) +{ + return (float) val; +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val) +{ + return fp8x2_e4m3_to_bfloat2(&val); +} + +template <> +__device__ inline int8_t cuda_cast(__nv_fp8_e4m3 val) +{ + // no impl + return 0; +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val) +{ + return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast(val))); +} + +#endif // ENABLE_FP8 + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h new file mode 100644 index 
0000000000..d7bf43b407 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace tensorrt_llm::utils::customAllReduceUtils +{ + +constexpr size_t NUM_POINTERS_PER_RANK = 7; + +// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py +inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept +{ + if (worldSize <= 2) + { + return 16 * 1000 * 1000; + } + return 8 * 1000 * 1000; +} + +} // namespace tensorrt_llm::utils::customAllReduceUtils diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp new file mode 100644 index 0000000000..64d3d44acb --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp @@ -0,0 +1,214 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "envUtils.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include + +namespace tensorrt_llm::common +{ + +std::optional getIntEnv(char const* name) +{ + char const* const env = std::getenv(name); + if (env == nullptr) + { + return std::nullopt; + } + int32_t const val = std::stoi(env); + if (val <= 0) + { + return std::nullopt; + } + return {val}; +}; + +// Returns true if the env variable exists and is set to "1" +static bool getBoolEnv(char const* name) +{ + char const* env = std::getenv(name); + return env && env[0] == '1' && env[1] == '\0'; +} + +// XQA kernels (optimized kernels for generation phase). +bool forceXQAKernels() +{ + static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0); + return forceXQA; +} + +std::optional getEnvEnableXQAJIT() +{ + static bool init = false; + static bool exists = false; + static bool enableXQAJIT = false; + if (!init) + { + init = true; + char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT"); + if (enable_xqa_jit_var) + { + exists = true; + if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0') + { + enableXQAJIT = true; + } + } + } + if (exists) + { + return enableXQAJIT; + } + else + { + return std::nullopt; + } +} + +// Tune the number of blocks per sequence for accuracy/performance purpose. 
+bool getEnvMmhaMultiblockDebug() +{ + static bool init = false; + static bool forceMmhaMaxSeqLenTile = false; + if (!init) + { + init = true; + char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG"); + if (enable_mmha_debug_var) + { + if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0') + { + forceMmhaMaxSeqLenTile = true; + } + } + } + return forceMmhaMaxSeqLenTile; +} + +int getEnvMmhaBlocksPerSequence() +{ + static bool init = false; + static int mmhaBlocksPerSequence = 0; + if (!init) + { + init = true; + char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE"); + if (mmhaBlocksPerSequenceEnv) + { + mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv); + if (mmhaBlocksPerSequence <= 0) + { + TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!"); + } + } + } + return mmhaBlocksPerSequence; +} + +int getEnvMmhaKernelBlockSize() +{ + static bool init = false; + static int mmhaKernelBlockSize = 0; + if (!init) + { + init = true; + char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE"); + if (mmhaKernelBlockSizeEnv) + { + mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv); + if (mmhaKernelBlockSize <= 0) + { + TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!"); + } + } + } + return mmhaKernelBlockSize; +} + +bool getEnvEnablePDL() +{ + static bool init = false; + static bool enablePDL = false; + if (!init) + { + init = true; + // PDL only available when arch >= 90 + if (getSMVersion() >= 90) + { + // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1` + enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL"); + } + } + return enablePDL; +} + +bool getEnvUseUCXKvCache() +{ + static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE"); + return useUCXKVCache; +} + +std::string getEnvUCXInterface() +{ + static bool init = false; + static std::string ucxInterface; + if (!init) + { + init = true; + { + char const* ucx_interface = std::getenv("TRTLLM_UCX_INTERFACE"); + if (ucx_interface) + { + ucxInterface = ucx_interface; + } + } + } + return ucxInterface; +} + +bool getEnvDisaggLayerwise() +{ + static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE"); + return disaggLayerwise; +} + +bool getEnvParallelCacheSend() +{ + static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND"); + return parallelCacheSend; +} + +bool getEnvRequestKVCacheSerial() +{ + static bool const requestKVCacheSerial = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_SERIAL"); + return requestKVCacheSerial; +} + +bool getEnvDisableKVCacheTransferOverlap() +{ + static bool const disableKVCacheTransferOverlap = getBoolEnv("TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP"); + return disableKVCacheTransferOverlap; +} + +bool getEnvDisableReceiveKVCacheParallel() +{ + static bool const disableReceiveParallel = getBoolEnv("TRTLLM_DISABLE_KVCACHE_RECEIVE_PARALLEL"); + return disableReceiveParallel; +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h new file mode 100644 index 0000000000..027c7cfbb3 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h @@ -0,0 +1,60 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace tensorrt_llm::common +{ +// Useful when you want to inject some debug code controllable with env var. +std::optional getIntEnv(char const* name); + +// XQA kernels (optimized kernels for generation phase). +bool forceXQAKernels(); + +// Whether XQA JIT is enabled. +// +// Returns the value of TRTLLM_ENABLE_XQA_JIT env var. If such env var doesn't exist, std::nullopt is returned. +std::optional getEnvEnableXQAJIT(); + +// Tune the number of blocks per sequence for accuracy/performance purpose. +bool getEnvMmhaMultiblockDebug(); + +int getEnvMmhaBlocksPerSequence(); + +int getEnvMmhaKernelBlockSize(); + +// Whether PDL is enabled. +bool getEnvEnablePDL(); + +bool getEnvUseUCXKvCache(); + +std::string getEnvUCXInterface(); + +bool getEnvDisaggLayerwise(); + +bool getEnvParallelCacheSend(); + +bool getEnvRequestKVCacheSerial(); + +bool getEnvDisableKVCacheTransferOverlap(); + +bool getEnvDisableReceiveKVCacheParallel(); + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp new file mode 100644 index 0000000000..334ad23690 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/tllmException.h" +#include + +namespace tensorrt_llm::common +{ + +Logger::Logger() +{ + char* isFirstRankOnlyChar = std::getenv("TLLM_LOG_FIRST_RANK_ONLY"); + bool isFirstRankOnly = (isFirstRankOnlyChar != nullptr && std::string(isFirstRankOnlyChar) == "ON"); + + auto const* levelName = std::getenv("TLLM_LOG_LEVEL"); + if (levelName != nullptr) + { + auto level = [levelName = std::string(levelName)]() + { + if (levelName == "TRACE") + return TRACE; + if (levelName == "DEBUG") + return DEBUG; + if (levelName == "INFO") + return INFO; + if (levelName == "WARNING") + return WARNING; + if (levelName == "ERROR") + return ERROR; + TLLM_THROW("Invalid log level: %s", levelName.c_str()); + }(); + // If TLLM_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR + if (isFirstRankOnly) + { + auto const deviceId = getDevice(); + if (deviceId != 1) + { + level = ERROR; + } + } + setLevel(level); + } +} + +void Logger::log(std::exception const& ex, Logger::Level level) +{ + log(level, "%s: %s", TllmException::demangle(typeid(ex).name()).c_str(), ex.what()); +} + +Logger* Logger::getLogger() +{ + thread_local Logger instance; + return &instance; +} +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h new file mode 100644 index 0000000000..1bad3a2c15 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace tensorrt_llm +{ +namespace common +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T divUp(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu new file mode 100644 index 0000000000..d13217b203 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu @@ -0,0 +1,906 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/memoryUtils.h" + +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +template +void deviceMalloc(T** ptr, size_t size, bool is_random_initialize) +{ + check_cuda_error(cudaMalloc((void**) (ptr), sizeof(T) * size)); + if (is_random_initialize) + { + cudaRandomUniform(*ptr, size); + } +} + +template void deviceMalloc(float** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(half** ptr, size_t size, bool is_random_initialize); +#ifdef ENABLE_BF16 +template void deviceMalloc(__nv_bfloat16** ptr, size_t size, bool is_random_initialize); +#endif +template void deviceMalloc(uint16_t** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(int** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(bool** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(char** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(int8_t** ptr, size_t size, bool is_random_initialize); +#ifdef ENABLE_FP8 +template void deviceMalloc(__nv_fp8_e4m3** ptr, size_t size, bool is_random_initialize); +#endif + +template +void deviceMemSetZero(T* ptr, size_t size) +{ + check_cuda_error(cudaMemset(static_cast(ptr), 0, sizeof(T) * size)); +} + +template void deviceMemSetZero(float* ptr, size_t size); +template void deviceMemSetZero(half* ptr, size_t size); +template void deviceMemSetZero(int* ptr, size_t size); +template void deviceMemSetZero(uint32_t* ptr, size_t size); +template void deviceMemSetZero(bool* ptr, size_t size); +#ifdef ENABLE_FP8 +template void deviceMemSetZero(__nv_fp8_e4m3* ptr, size_t size); +#endif +#ifdef ENABLE_BF16 +template void deviceMemSetZero(__nv_bfloat16* ptr, size_t size); +#endif + +template +void deviceFree(T*& ptr) +{ + if (ptr != NULL) + { + check_cuda_error(cudaFree(ptr)); + ptr = NULL; + } +} + +template void deviceFree(float*& ptr); +template void deviceFree(half*& ptr); +#ifdef ENABLE_BF16 +template void deviceFree(__nv_bfloat16*& ptr); +#endif +template void deviceFree(unsigned short*& ptr); +template void deviceFree(int*& ptr); +template void deviceFree(bool*& ptr); +template void deviceFree(char*& ptr); +template void deviceFree(int8_t*& ptr); +#ifdef ENABLE_FP8 +template void deviceFree(__nv_fp8_e4m3*& ptr); +#endif + +template +void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream) +{ + T* arr = new T[size]; + std::fill(arr, arr + size, value); + check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream)); + delete[] arr; +} + +template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream); +template void deviceFill(half* devptr, size_t size, half value, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void deviceFill(__nv_bfloat16* devptr, size_t size, __nv_bfloat16 value, cudaStream_t stream); +#endif +template void deviceFill(int* devptr, size_t size, int value, cudaStream_t stream); +template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream); + +template +void cudaD2Hcpy(T* tgt, T const* src, const size_t size) +{ + check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost)); +} + +template void cudaD2Hcpy(float* tgt, float const* src, size_t size); 
+template void cudaD2Hcpy(half* tgt, half const* src, size_t size); +#ifdef ENABLE_BF16 +template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); +#endif +template void cudaD2Hcpy(int* tgt, int const* src, size_t size); +template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size); +#ifdef ENABLE_FP8 +template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); +#endif +template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); +template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size); +template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size); + +template +void cudaH2Dcpy(T* tgt, T const* src, const size_t size) +{ + check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice)); +} + +template void cudaH2Dcpy(float* tgt, float const* src, size_t size); +template void cudaH2Dcpy(half* tgt, half const* src, size_t size); +#ifdef ENABLE_BF16 +template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); +#endif +template void cudaH2Dcpy(int* tgt, int const* src, size_t size); +template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size); +#ifdef ENABLE_FP8 +template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); +#endif +template void cudaH2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); +template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size); +template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size); + +template +void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) +{ + check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream)); +} + +template void cudaD2Dcpy(float* tgt, float const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); +#endif +template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_FP8 +template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream); +#endif +template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); + +template +__global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (T_OUT) ((float) (src[tid])); + } +} + +template +void invokeCudaCast(T_OUT* dst, T_IN const* const src, const size_t size, cudaStream_t stream) +{ + cudaCast<<<256, 256, 0, stream>>>(dst, src, size); +} + +template void invokeCudaCast(float* dst, half const* const src, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeCudaCast(float* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_bfloat16* dst, float const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_bfloat16* dst, half const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(half* dst, 
__nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_FP8 +template void invokeCudaCast(float* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast( + __nv_bfloat16* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(half* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_fp8_e4m3* dst, float const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast( + __nv_fp8_e4m3* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const size_t size, cudaStream_t stream); +#endif + +template +void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) +{ + if (stream != NULL) + { + check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDefault, stream)); + } + else + { + check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDefault)); + } +} + +template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); +#endif +template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(uint32_t* tgt, uint32_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream); + +template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaAutoCpy(__nv_bfloat16 const** tgt, __nv_bfloat16 const* const* src, size_t size, cudaStream_t stream); +#endif +template void cudaAutoCpy(int const** tgt, int const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(bool const** tgt, bool const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(int8_t const** tgt, int8_t const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy( + unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream); + +template +__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((unsigned long long int) 1337, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = (T) (curand_uniform(&local_state) * 0.2f - 0.1f); + } +} + +template <> +__global__ void cuda_random_uniform_kernel(int* buffer, 
const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = curand(&local_state); + } +} + +template <> +__global__ void cuda_random_uniform_kernel(bool* buffer, const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = (curand(&local_state) % 2 == 0); + } +} + +template <> +__global__ void cuda_random_uniform_kernel(char* buffer, const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = curand(&local_state) % 0xFF; + } +} + +template +void cudaRandomUniform(T* buffer, const size_t size) +{ + static int seq_offset = 0; + cuda_random_uniform_kernel<<<256, 256>>>(buffer, size, seq_offset); + seq_offset += 256 * 256; +} + +template void cudaRandomUniform(float* buffer, const size_t size); +template void cudaRandomUniform(half* buffer, const size_t size); +#ifdef ENABLE_BF16 +template void cudaRandomUniform(__nv_bfloat16* buffer, const size_t size); +#endif +template void cudaRandomUniform(int* buffer, const size_t size); +template void cudaRandomUniform(bool* buffer, const size_t size); +template void cudaRandomUniform(char* buffer, const size_t size); +#ifdef ENABLE_FP8 +template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size); +#endif + +// loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or +// the product of the elements in shape is 0, this function will return an empty vector. +template +std::vector loadWeightFromBinHelper(std::vector shape, std::string filename) +{ + if (shape.size() > 2) + { + printf("[ERROR] shape should have less than two dims \n"); + return std::vector(); + } + size_t dim0 = shape[0], dim1 = 1; + if (shape.size() == 2) + { + dim1 = shape[1]; + } + size_t size = dim0 * dim1; + if (size == 0) + { + TLLM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); + return std::vector(); + } + + std::vector host_array(size); + std::ifstream in(filename, std::ios::in | std::ios::binary); + if (!in.is_open()) + { + TLLM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str()); + return std::vector(); + } + + size_t loaded_data_size = sizeof(T) * size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + + TLLM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename); + in.read((char*) host_array.data(), loaded_data_size); + + size_t in_get_size = in.gcount(); + if (in_get_size != loaded_data_size) + { + TLLM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n", filename.c_str(), + in_get_size, loaded_data_size); + return std::vector(); + } + in.close(); + // If we succeed, return an array with values. 
+ return host_array; +} + +template +int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) +{ + std::vector host_array = loadWeightFromBinHelper(shape, filename); + + if (host_array.empty()) + { + return 0; + } + + if (std::is_same::value == true) + { + cudaH2Dcpy(ptr, (T*) host_array.data(), host_array.size()); + } + else + { + T_IN* ptr_2 = nullptr; + deviceMalloc(&ptr_2, host_array.size(), false); + cudaH2Dcpy(ptr_2, host_array.data(), host_array.size()); + invokeCudaD2DcpyConvert(ptr, ptr_2, host_array.size()); + deviceFree(ptr_2); + } + return 0; +} + +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(int8_t* ptr, std::vector shape, std::string filename); +#ifdef ENABLE_BF16 +template int loadWeightFromBinFunc<__nv_bfloat16, float>( + __nv_bfloat16* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc<__nv_bfloat16, half>( + __nv_bfloat16* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* ptr, std::vector shape, std::string filename); +#endif // ENABLE_BF16 +template int loadWeightFromBinFunc(int* ptr, std::vector shape, std::string filename); +#ifdef ENABLE_FP8 +template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>( + __nv_fp8_e4m3* ptr, std::vector shape, std::string filename); +#endif // ENABLE_FP8 + +template +int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type) +{ + switch (model_file_type) + { + case TRTLLMCudaDataType::FP32: loadWeightFromBinFunc(ptr, shape, filename); break; + case TRTLLMCudaDataType::FP16: loadWeightFromBinFunc(ptr, shape, filename); break; + case TRTLLMCudaDataType::INT8: loadWeightFromBinFunc(ptr, shape, filename); break; +#ifdef ENABLE_BF16 + case TRTLLMCudaDataType::BF16: loadWeightFromBinFunc(ptr, shape, filename); break; +#endif +#ifdef ENABLE_FP8 + case TRTLLMCudaDataType::FP8: loadWeightFromBinFunc(ptr, shape, filename); break; +#endif + default: TLLM_LOG_ERROR("Does not support TRTLLMCudaDataType=%d", model_file_type); TLLM_CHECK(false); + } + return 0; +} + +template <> +int loadWeightFromBin(int* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type) +{ + loadWeightFromBinFunc(ptr, shape, filename); + return 0; +} + +template int loadWeightFromBin( + float* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +template int loadWeightFromBin( + half* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +template int loadWeightFromBin( + int8_t* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +#ifdef ENABLE_BF16 +template int loadWeightFromBin( + __nv_bfloat16* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +#endif +#ifdef ENABLE_FP8 +template int loadWeightFromBin( + __nv_fp8_e4m3* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +#endif +template int 
loadWeightFromBin( + int* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); + +template +__global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = cuda_cast(src[tid]); + } +} + +template +void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cudaStream_t stream) +{ + cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size); +} + +template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream); + +#ifdef ENABLE_BF16 +template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream); +#endif // ENABLE_BF16 + +template +__global__ void cudaD2DScaleCpyConvert( + T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size) +{ + float const scale_value = invert_scale ? 
1.0f / scale[0] : scale[0]; + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = cuda_cast(cuda_cast(src[tid]) * scale_value); + } +} + +template +void invokeCudaD2DScaleCpyConvert( + T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream) +{ + cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size); +} + +// clang-format off +template void invokeCudaD2DScaleCpyConvert(float* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const float* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(half* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const half* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeCudaD2DScaleCpyConvert(__nv_bfloat16* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const __nv_bfloat16* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +template void invokeCudaD2DScaleCpyConvert(float* tgt, const __nv_fp8_e4m3* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +#endif // ENABLE_FP8 +// clang-format on + +void invokeCudaD2DcpyHalf2Float(float* dst, half* src, const size_t size, cudaStream_t stream) +{ + invokeCudaD2DcpyConvert(dst, src, size, stream); +} + +void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaStream_t stream) +{ + invokeCudaD2DcpyConvert(dst, src, size, stream); +} + +template +void saveToBinary(T const* ptr, const size_t size, std::string filename) +{ + + std::vector h_ptr(size); + cudaD2Hcpy(h_ptr.data(), ptr, size); + std::vector float_ptr(size); + for (size_t i = 0; i < size; i++) + { + float_ptr[i] = (float) h_ptr[i]; + } + + std::ofstream out(filename, std::ios::out | std::ios::binary); + TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); + + out.write((char*) float_ptr.data(), size * sizeof(float)); +} + +template void saveToBinary(float const* ptr, const size_t size, std::string filename); +template void saveToBinary(half const* ptr, const size_t size, std::string filename); +#ifdef ENABLE_BF16 +template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename); +#endif // ENABLE_BF16 + +template <> +void saveToBinary(int const* ptr, const size_t size, std::string filename) +{ + std::vector h_ptr(size); + cudaD2Hcpy(h_ptr.data(), ptr, size); + std::ofstream out(filename, std::ios::out | std::ios::binary); + TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); + out.write((char*) h_ptr.data(), size * sizeof(int)); +} + +template +__global__ void fakeCast(T_IN* input_ptr, const size_t size) +{ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) + { + T_fake_type tmp_val = (T_fake_type) ((float) input_ptr[i]); + input_ptr[i] = (T_IN) ((float) tmp_val); + } +} + +template +void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream) +{ + dim3 block(256); + dim3 
grid((size + 255) / 256); + fakeCast<<>>(input_ptr, size); +} + +#ifdef ENABLE_FP8 +__global__ void cudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (float) (src[tid]); + } +} + +void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream) +{ + cudaD2Dcpyfp82Float<<<256, 256, 0, stream>>>(dst, src, size); +} + +__global__ void cudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (half) ((float) (src[tid])); + } +} + +void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream) +{ + cudaD2Dcpyfp82Half<<<256, 256, 0, stream>>>(dst, src, size); +} + +__global__ void cudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (__nv_fp8_e4m3) src[tid]; + } +} + +void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size, cudaStream_t stream) +{ + cudaD2DcpyFloat2fp8<<<256, 256, 0, stream>>>(dst, src, size); +} + +__global__ void cudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (__nv_fp8_e4m3) src[tid]; + } +} + +void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size, cudaStream_t stream) +{ + cudaD2DcpyHalf2fp8<<<256, 256, 0, stream>>>(dst, src, size); +} + +__global__ void cudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (__nv_fp8_e4m3) src[tid]; + } +} + +void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size, cudaStream_t stream) +{ + cudaD2DcpyBfloat2fp8<<<256, 256, 0, stream>>>(dst, src, size); +} + +#endif // ENABLE_FP8 + +template +__global__ void transpose(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1; tid += blockDim.x * gridDim.x) + { + const size_t src_col_id = tid % dim1; + const size_t src_row_id = tid / dim1; + dst[src_col_id * dim0 + src_row_id] = (T_OUT) (src[tid]); + } +} + +template +void invokeInPlaceTranspose(T* data, T* workspace, const size_t dim0, const size_t dim1) +{ + // copy data to workspace, and then transpose from workspace to data + cudaD2Dcpy(workspace, data, dim0 * dim1); + transpose<<<256, 256>>>(data, workspace, dim0, dim1); +} + +#ifdef ENABLE_FP8 +template void invokeInPlaceTranspose( + __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +template void invokeInPlaceTranspose( + __nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1); +#endif // ENABLE_BF16 +template void invokeInPlaceTranspose(float* data, float* workspace, const size_t dim0, const size_t dim1); + +template +__global__ void transpose0213( + T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3) +{ + // src permutation: [0, 1, 2, 3] + // dst permutation: [0, 2, 1, 3] + for (size_t tid = 
threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2 * dim3; + tid += blockDim.x * gridDim.x) + { + size_t tmp_idx = tid; + const size_t dim_3_idx = tmp_idx % dim3; + tmp_idx = (tmp_idx - dim_3_idx) / dim3; + const size_t dim_2_idx = tmp_idx % dim2; + tmp_idx = (tmp_idx - dim_2_idx) / dim2; + const size_t dim_1_idx = tmp_idx % dim1; + tmp_idx = (tmp_idx - dim_1_idx) / dim1; + const size_t dim_0_idx = tmp_idx % dim0; + dst[dim_0_idx * dim1 * dim2 * dim3 + dim_2_idx * dim1 * dim3 + dim_1_idx * dim3 + dim_3_idx] = src[tid]; + } +} + +template +void invokeInPlaceTranspose0213( + T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3) +{ + // copy data to workspace, and then transpose from workspace to data + // Note that this kernel is used for pre-processing and not very efficient. + cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2 * dim3); + transpose0213<<<256, 256>>>(data, workspace, dim0, dim1, dim2, dim3); +} + +#ifdef ENABLE_FP8 +template void invokeInPlaceTranspose0213(__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, + const size_t dim1, const size_t dim2, const size_t dim3); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +template void invokeInPlaceTranspose0213(__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, + const size_t dim1, const size_t dim2, const size_t dim3); +#endif // ENABLE_BF16 +template void invokeInPlaceTranspose0213( + float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3); + +template +__global__ void transpose102(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2) +{ + // src permutation: [0, 1, 2] + // dst permutation: [1, 0, 2] + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2; tid += blockDim.x * gridDim.x) + { + size_t tmp_idx = tid; + const size_t dim_2_idx = tmp_idx % dim2; + tmp_idx = (tmp_idx - dim_2_idx) / dim2; + const size_t dim_1_idx = tmp_idx % dim1; + tmp_idx = (tmp_idx - dim_1_idx) / dim1; + const size_t dim_0_idx = tmp_idx % dim0; + dst[dim_1_idx * dim0 * dim2 + dim_0_idx * dim2 + dim_2_idx] = src[tid]; + } +} + +template +void invokeInPlaceTranspose102(T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2) +{ + // copy data to workspace, and then transpose from workspace to data + // Note that this kernel is used for pre-processing and not very efficient. 
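+    // Illustrative note: the [1, 0, 2] permutation sends element (d0, d1, d2) of the
+    // [dim0, dim1, dim2] input to position (d1, d0, d2) of the [dim1, dim0, dim2] output.
+    // For example, with dim0 = 2, dim1 = 3, dim2 = 4, the source element at (1, 0, 2)
+    // (flat offset 1*3*4 + 0*4 + 2 = 14) ends up at destination offset 0*2*4 + 1*4 + 2 = 6.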
+ cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2); + transpose102<<<256, 256>>>(data, workspace, dim0, dim1, dim2); +} + +#ifdef ENABLE_FP8 +template void invokeInPlaceTranspose102( + __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1, const size_t dim2); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +template void invokeInPlaceTranspose102( + __nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1, const size_t dim2); +#endif // ENABLE_BF16 +template void invokeInPlaceTranspose102( + float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2); + +template +void __global__ multiplyScale(T* tensor, float scale, const size_t size) +{ + for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x) + { + tensor[index] = (T) (((float) tensor[index]) * scale); + } +} + +template +void invokeMultiplyScale(T* tensor, float scale, const size_t size, cudaStream_t stream) +{ + int block = 256; + int grid = (size + 255) / 256; + multiplyScale<<>>(tensor, scale, size); +} + +template void invokeMultiplyScale(float* tensor, float scale, const size_t size, cudaStream_t stream); +template void invokeMultiplyScale(half* tensor, float scale, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeMultiplyScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_FP8 +template void invokeMultiplyScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream); +#endif + +template +void __global__ divideScale(T* tensor, float scale, const size_t size) +{ + for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x) + { + tensor[index] = (T) (((float) tensor[index]) / scale); + } +} + +template +void invokeDivideScale(T* tensor, float scale, const size_t size, cudaStream_t stream) +{ + int block = 256; + int grid = (size + 255) / 256; + divideScale<<>>(tensor, scale, size); +} + +template void invokeDivideScale(float* tensor, float scale, const size_t size, cudaStream_t stream); +template void invokeDivideScale(half* tensor, float scale, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeDivideScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_FP8 +template void invokeDivideScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_BF16 +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast<__nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); +#endif +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +#ifdef ENABLE_FP8 +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast<__nv_bfloat16, __nv_fp8_e4m3>( + __nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); +#endif + +size_t cuda_datatype_size(TRTLLMCudaDataType dt) +{ + static const std::unordered_map sizes{ + {TRTLLMCudaDataType::FP32, sizeof(float)}, {TRTLLMCudaDataType::FP16, 
sizeof(half)} +#ifdef ENABLE_BF16 + , + {TRTLLMCudaDataType::BF16, sizeof(__nv_bfloat16)} +#endif + }; + + return sizes.at(dt); +} + +template +__global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range) +{ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) + { + const T val = buffer[i]; + if (val < min || val > max) + { + *d_within_range = false; + } + } +} + +template +bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream) +{ + cudaMemsetAsync(d_within_range, true, sizeof(bool), stream); + + dim3 block(256); + dim3 grid((size + 255) / 256); + check_range<<>>(buffer, size, min, max, d_within_range); + + bool result; + cudaD2Hcpy(&result, d_within_range, 1); + return result; +} + +template bool invokeCheckRange( + int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream); + +/* + * Determine the total workspace size based on a vector containing multiple variable sizes. + */ +size_t calcAlignedSize(std::vector const& sizes, const size_t ALIGN_BYTES) +{ + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + // Check ALIGN_BYTES is a power of 2 + assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0); + + size_t total = 0; + for (auto sz : sizes) + { + total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } + + // We add extra "ALIGN_BYTES - 1" bytes in case the start address passed to the function calcAlignedPointers() is + // not aligned. + return total + ALIGN_BYTES - 1; +} + +/* + * Given the address of the workspace and the vector containing multiple variable sizes, calculate the start addresses + * of each variable. + */ +void calcAlignedPointers( + std::vector& outPtrs, void const* p, std::vector const& sizes, size_t ALIGN_BYTES) +{ + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + // Check ALIGN_BYTES is a power of 2 + assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0); + + // In case the start address is not aligned + char* ptr = reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); + + outPtrs.reserve(sizes.size()); + for (auto sz : sizes) + { + outPtrs.push_back(ptr); + ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } +} + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h new file mode 100644 index 0000000000..9e413a1beb --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/cudaUtils.h" + +#include + +namespace tensorrt_llm +{ +namespace common +{ + +template +void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true); + +template +void deviceMemSetZero(T* ptr, size_t size); + +template + +void deviceFree(T*& ptr); + +template +void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0); + +template +void cudaD2Hcpy(T* tgt, T const* src, size_t const size); + +template +void cudaH2Dcpy(T* tgt, T const* src, size_t const size); + +template +void cudaD2Dcpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL); + +template +void cudaAutoCpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL); + +template +void cudaRandomUniform(T* buffer, size_t const size); + +template +int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, + TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32); + +// template +// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr, +// T* scale_ptr, +// std::vector shape, +// std::string filename, +// TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32); + +void invokeCudaD2DcpyHalf2Float(float* dst, half* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyFloat2Half(half* dst, float* src, size_t const size, cudaStream_t stream); +#ifdef ENABLE_FP8 +void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream); +void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +void invokeCudaD2DcpyBfloat2Float(float* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream); +#endif // ENABLE_BF16 + +template +void invokeCudaCast(T_OUT* dst, T_IN const* const src, size_t const size, cudaStream_t stream); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The following functions implement conversion of multi-dimensional indices to an index in a flat array. +// The shape of the Tensor dimensions is passed as one array (`dims`), the indices are given as individual arguments. +// For examples on how to use these functions, see their tests `test_memory_utils.cu`. +// All of these functions can be evaluated at compile time by recursive template expansion. + +template +__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( + T const& acc, TDim dims, TIndex const& index) +{ + assert(index < dims[0]); + return acc * dims[0] + index; +} + +template +__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( + T const& acc, TDim dims, TIndex const& index, TIndices... 
indices) +{ + assert(index < dims[0]); + return flat_index(acc * dims[0] + index, dims + 1, indices...); +} + +template +__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( + [[maybe_unused]] TDim dims, T const& index) +{ + assert(index < dims[0]); + return index; +} + +template +__inline__ __host__ __device__ + std::enable_if_t::value, typename std::remove_pointer::type> constexpr flat_index( + TDim dims, TIndex const& index, TIndices... indices) +{ + assert(index < dims[0]); + return flat_index(static_cast::type>(index), dims + 1, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index( + std::array const& dims, TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(&dims[skip], index, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index( + T const& acc, std::array const& dims, TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(acc, &dims[skip], index, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index(T const (&dims)[N], TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(static_cast(dims) + skip, index, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index( + T const& acc, T const (&dims)[N], TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(acc, static_cast(dims) + skip, index, indices...); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// These are simpler functions for multi-dimensional index conversion. Indices and dimensions are passed as individual +// arguments. These functions are more suitable for usage inside kernels than the corresponding flat_index functions +// which require arrays as arguments. Usage examples can be found in `test_memory_utils.cu`. The functions can be +// evaluated at compile time. 
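For readers skimming this hunk, a minimal usage sketch of the flat_index2/flat_index3 helpers declared just below; the constexpr stand-ins and the [4, 3, 5] shape are illustrative only and are not part of the header:

    // Simplified stand-ins that mirror the row-major composition performed by the
    // templated flat_index2/flat_index3 helpers in this header.
    constexpr int flat_index2_demo(int i0, int i1, int dim1) { return i0 * dim1 + i1; }
    constexpr int flat_index3_demo(int i0, int i1, int i2, int dim1, int dim2)
    {
        return flat_index2_demo(i0, i1, dim1) * dim2 + i2;
    }

    // For a [4, 3, 5] tensor, element (2, 1, 3) sits at row-major offset 2*3*5 + 1*5 + 3 = 38,
    // and, as the comment above notes, the result can be checked at compile time.
    static_assert(flat_index3_demo(2, 1, 3, 3, 5) == 38, "row-major offset");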
+ +template +__inline__ __host__ __device__ T constexpr flat_index2(TIndex const& index_0, TIndex const& index_1, T const& dim_1) +{ + assert(index_1 < dim_1); + return index_0 * dim_1 + index_1; +} + +template +__inline__ __host__ __device__ T constexpr flat_index3( + TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& dim_1, T const& dim_2) +{ + assert(index_2 < dim_2); + return flat_index2(index_0, index_1, dim_1) * dim_2 + index_2; +} + +template +__inline__ __host__ __device__ T constexpr flat_index4(TIndex const& index_0, TIndex const& index_1, + TIndex const& index_2, TIndex const& index_3, T const& dim_1, T const& dim_2, T const& dim_3) +{ + assert(index_3 < dim_3); + return flat_index3(index_0, index_1, index_2, dim_1, dim_2) * dim_3 + index_3; +} + +template +__inline__ __host__ __device__ T constexpr flat_index5(TIndex const& index_0, TIndex const& index_1, + TIndex const& index_2, TIndex const& index_3, TIndex const& index_4, T const& dim_1, T const& dim_2, T const& dim_3, + T const& dim_4) +{ + assert(index_4 < dim_4); + return flat_index4(index_0, index_1, index_2, index_3, dim_1, dim_2, dim_3) * dim_4 + index_4; +} + +template +__inline__ __host__ __device__ T constexpr flat_index_strided3( + TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& stride_1, T const& stride_2) +{ + assert(index_1 < stride_1 / stride_2); + assert(index_2 < stride_2); + return index_0 * stride_1 + index_1 * stride_2 + index_2; +} + +template +__inline__ __host__ __device__ T constexpr flat_index_strided4(TIndex const& index_0, TIndex const& index_1, + TIndex const& index_2, TIndex const& index_3, T const& stride_1, T const& stride_2, T const& stride_3) +{ + assert(index_1 < stride_1 / stride_2); + assert(index_2 < stride_2 / stride_3); + assert(index_3 < stride_3); + return index_0 * stride_1 + index_1 * stride_2 + index_2 * stride_3 + index_3; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void invokeInPlaceTranspose(T* data, T* workspace, size_t const dim0, size_t const dim1); + +template +void invokeInPlaceTranspose0213( + T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2, size_t const dim3); + +template +void invokeInPlaceTranspose102(T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2); + +template +void invokeMultiplyScale(T* tensor, float scale, size_t const size, cudaStream_t stream); + +template +void invokeDivideScale(T* tensor, float scale, size_t const size, cudaStream_t stream); + +template +void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, size_t const size, cudaStream_t stream = 0); + +template +void invokeCudaD2DScaleCpyConvert( + T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, size_t const size, cudaStream_t stream = 0); + +inline bool checkIfFileExist(std::string const& file_path) +{ + std::ifstream in(file_path, std::ios::in | std::ios::binary); + if (in.is_open()) + { + in.close(); + return true; + } + return false; +} + +template +void saveToBinary(T const* ptr, size_t const size, std::string filename); + +template +void invokeFakeCast(T_IN* input_ptr, size_t const size, cudaStream_t stream); + +size_t cuda_datatype_size(TRTLLMCudaDataType dt); + +template +bool invokeCheckRange(T const* buffer, size_t const size, T min, T max, bool* d_within_range, cudaStream_t stream); + +constexpr size_t DEFAULT_ALIGN_BYTES = 256; + +size_t calcAlignedSize(std::vector const& sizes, 
size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES); +void calcAlignedPointers(std::vector& outPtrs, void const* p, std::vector const& sizes, + size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES); + +struct AlignedPointersUnpacker +{ + template + void operator()(T*&... outPtrs) + { + assert(sizeof...(T) == alignedPointers.size()); + auto it = alignedPointers.begin(); + ((outPtrs = static_cast(*it++)), ...); + } + + std::vector alignedPointers; +}; + +AlignedPointersUnpacker inline calcAlignedPointers( + void const* p, std::vector const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES) +{ + AlignedPointersUnpacker unpacker{}; + calcAlignedPointers(unpacker.alignedPointers, p, sizes, ALIGN_BYTES); + return unpacker; +} + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp new file mode 100644 index 0000000000..dbdaca4ee7 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp @@ -0,0 +1,588 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "tensorrt_llm/common/mpiUtils.h" + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/runtime/common.h" +#include "tensorrt_llm/runtime/iBuffer.h" + +#include +#include +#include +#include +#include +#ifndef _WIN32 +#include +#endif + +// We rely on SizeType32 being int32_t in some places with weak type checking, +// i.e. we're passing void ptr to some function. To prevent mysterious errors +// in the future, we trigger a compilation error here if SizeType32 isn't int32_t. 
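The guard described in the comment above is the usual static_assert pattern; a minimal self-contained sketch follows, where the SizeType32 alias is only a stand-in for the real tensorrt_llm::runtime::SizeType32 (the actual assert appears immediately below in the patch):

    #include <cstdint>
    #include <type_traits>

    using SizeType32 = std::int32_t;  // stand-in alias for the real runtime type

    // Compilation fails here if SizeType32 ever stops being int32_t, instead of
    // silently corrupting values that are passed through void* with no type checking.
    static_assert(std::is_same<SizeType32, std::int32_t>::value, "SizeType32 must be int32_t");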
+static_assert(std::is_same::value); + +namespace tensorrt_llm::mpi +{ + +MPI_Datatype getMpiDtype(MpiType dtype) +{ +#if ENABLE_MULTI_DEVICE + static std::unordered_map const dtype_map{ + {MpiType::kBYTE, MPI_BYTE}, + {MpiType::kHALF, MPI_UINT16_T}, + {MpiType::kFLOAT, MPI_FLOAT}, + {MpiType::kDOUBLE, MPI_DOUBLE}, + {MpiType::kBOOL, MPI_C_BOOL}, + {MpiType::kINT8, MPI_INT8_T}, + {MpiType::kUINT8, MPI_UINT8_T}, + {MpiType::kINT32, MPI_INT32_T}, + {MpiType::kUINT32, MPI_UINT32_T}, + {MpiType::kINT64, MPI_INT64_T}, + {MpiType::kUINT64, MPI_UINT64_T}, + {MpiType::kFP8, MPI_UINT8_T}, + {MpiType::kBF16, MPI_UINT16_T}, + {MpiType::kCHAR, MPI_CHAR}, + }; + return dtype_map.at(dtype); +#else + TLLM_THROW("Multi device support is disabled."); +#endif +} + +MPI_Op getMpiOp(MpiOp op) +{ +#if ENABLE_MULTI_DEVICE + static std::unordered_map const op_map{ + {MpiOp::NULLOP, MPI_OP_NULL}, + {MpiOp::MAX, MPI_MAX}, + {MpiOp::MIN, MPI_MIN}, + {MpiOp::SUM, MPI_SUM}, + {MpiOp::PROD, MPI_PROD}, + {MpiOp::LAND, MPI_LAND}, + {MpiOp::BAND, MPI_BAND}, + {MpiOp::LOR, MPI_LOR}, + {MpiOp::BOR, MPI_BOR}, + {MpiOp::LXOR, MPI_LXOR}, + {MpiOp::BXOR, MPI_BXOR}, + {MpiOp::MINLOC, MPI_MINLOC}, + {MpiOp::MAXLOC, MPI_MAXLOC}, + {MpiOp::REPLACE, MPI_REPLACE}, + }; + return op_map.at(op); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +namespace +{ + +bool mpiInitialized = false; +std::recursive_mutex mpiMutex; + +MpiComm initLocalSession() +{ +#if ENABLE_MULTI_DEVICE + MPI_Comm localComm = nullptr; + MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm); + MpiComm localSession{localComm, false}; +#else + MpiComm localSession{COMM_SESSION, false}; +#endif // ENABLE_MULTI_DEVICE + return localSession; +} + +} // namespace + +std::vector getWorldRanks(MpiComm const& comm) +{ +#if ENABLE_MULTI_DEVICE + MPI_Group group = nullptr; + MPI_Group worldGroup = nullptr; + + MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); + MPICHECK(MPI_Comm_group(comm, &group)); + + int groupSize = 0; + MPICHECK(MPI_Group_size(group, &groupSize)); + std::vector ranks(groupSize); + std::vector worldRanks(groupSize); + std::iota(ranks.begin(), ranks.end(), 0); + + MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data())); + MPICHECK(MPI_Group_free(&group)); + MPICHECK(MPI_Group_free(&worldGroup)); +#else + std::vector worldRanks{0}; +#endif + return worldRanks; +} + +void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent) +{ + // double-checked locking + if (mpiInitialized) + { + return; + } + std::lock_guard lk(mpiMutex); + if (mpiInitialized) + { + return; + } +#if ENABLE_MULTI_DEVICE + int initialized = 0; + TLLM_MPI_CHECK(MPI_Initialized(&initialized)); + if (!initialized) + { + TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode); + int providedMode = 0; + auto requiredMode = static_cast(threadMode); + MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode)); + TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed"); + std::atexit([]() { MPI_Finalize(); }); + + /* + * We only catch SIGABRT and SIGSEGV because most, of not all errors in the worker will cause one of these 2 + * signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers + * correctly. 
+ */ + for (int sig : {SIGABRT, SIGSEGV}) + { + __sighandler_t previousHandler = nullptr; + if (forwardAbortToParent) + { + previousHandler = std::signal(sig, + [](int signal) + { +#ifndef _WIN32 + pid_t parentProcessId = getppid(); + kill(parentProcessId, SIGKILL); +#endif + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); + }); + } + else + { + previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); }); + } + TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed"); + } + + // ensure local MPI communicator is initialized + MpiComm::localSession(); + TLLM_LOG_INFO("Initialized MPI"); + } +#endif // ENABLE_MULTI_DEVICE + mpiInitialized = true; +} + +void MpiComm::barrier() const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Barrier(mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +#if ENABLE_MULTI_DEVICE +template >>> +size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args) +{ + constexpr auto maxP1 = static_cast(std::numeric_limits::max()) + 1; + if (TLLM_LIKELY(size < maxP1)) + { + MPICHECK(func(buffer, size, dtype, args...)); + return 1; + } + + constexpr size_t alignment = 256; + int elementSize = 1; + MPICHECK(MPI_Type_size(dtype, &elementSize)); + elementSize = std::min(elementSize, alignment); + + // We cap at max alignment-bytes chunks that can be sent at once. + auto const step = maxP1 - (alignment / elementSize); + + using TCast = std::conditional_t, uint8_t const, uint8_t>; + size_t count = 0; + while (size != 0) + { + auto currentStep = static_cast(std::min(size, step)); + MPICHECK(func(buffer, currentStep, dtype, args...)); + size -= currentStep; + size_t diff = static_cast(currentStep) * elementSize; + buffer = static_cast(buffer) + diff; + ++count; + } + + return count; +} +#endif // ENABLE_MULTI_DEVICE + +std::shared_ptr MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const +{ + std::shared_ptr r = std::make_shared(); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + return r; +} + +std::shared_ptr MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const +{ + TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU); + return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); +} + +void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const +{ +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::bcast(runtime::IBuffer& buf, int root) const +{ + bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); +} + +std::shared_ptr MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const +{ + TLLM_LOG_DEBUG("start MPI_Isend with size %d", size); + std::shared_ptr r = std::make_shared(); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest); +#else + TLLM_THROW("Multi device support is disabled."); +#endif + TLLM_LOG_DEBUG("end MPI_Isend with size %d", size); + return r; +} + +std::shared_ptr MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const +{ + return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); +} + +void 
MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const +{ + TLLM_LOG_DEBUG("start MPI_Send with size %d", size); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + TLLM_LOG_DEBUG("end MPI_Send with size %d", size); +} + +void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const +{ + send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); +} + +MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const +{ + TLLM_LOG_DEBUG("start MPI_Recv with size %d", size); + MPI_Status status{}; +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + TLLM_LOG_DEBUG("end MPI_Recv with size %d", size); + return status; +} + +MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const +{ + return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag); +} + +MpiComm MpiComm::split(int color, int key) const +{ + MPI_Comm splitComm = nullptr; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + return MpiComm{splitComm, true}; +} + +void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf, + std::vector const& recvcounts, std::vector const& displs, MpiType recvtype) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(), + getMpiDtype(recvtype), mComm)); + +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const +{ +#if ENABLE_MULTI_DEVICE + int flag{0}; + MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status)); + return flag != 0; +#else + TLLM_THROW("Multi device support is disabled."); + return false; +#endif +} + +bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const +{ +#if ENABLE_MULTI_DEVICE + int flag{0}; + MPICHECK(MPI_Iprobe(source, tag, mComm, &flag, status)); + return flag != 0; +#else + TLLM_THROW("Multi device support is disabled."); + return false; +#endif +} + +void MpiComm::recvPoll(int source, int tag, int periodMs) const +{ + MPI_Status status; + while (!iprobe(source, tag, &status)) + { + 
std::this_thread::sleep_for(std::chrono::milliseconds(periodMs)); + } +} + +int MpiComm::getRank() const +{ + int rank = 0; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_rank(mComm, &rank)); +#endif + return rank; +} + +int MpiComm::getSize() const +{ + int world_size = 1; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_size(mComm, &world_size)); +#endif + return world_size; +} + +MpiComm const& MpiComm::world() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm commWorld{MPI_COMM_WORLD, false}; + initialize(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return commWorld; +} + +MpiComm& MpiComm::mutableSession() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm commSession{MPI_COMM_WORLD, false}; + initialize(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return commSession; +} + +MpiComm& MpiComm::mutableLocalSession() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm localSession = initLocalSession(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return localSession; +} + +void MpiComm::refreshLocalSession() +{ +#if ENABLE_MULTI_DEVICE + static std::mutex mutex; + std::unique_lock lock(mutex); + auto initSessionRanks = getWorldRanks(MpiComm::session()); + auto localSessionRanks = getWorldRanks(MpiComm::localSession()); + + // Add to intersectionRanks in order of initSessionRanks + std::vector intersectionRanks; + std::unordered_set localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end()); + for (auto rank : initSessionRanks) + { + if (localSessionRanksSet.find(rank) != localSessionRanksSet.end()) + { + intersectionRanks.push_back(rank); + } + } + + MPI_Group worldGroup = nullptr; + MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); + MPI_Group localGroup = nullptr; + MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup)); + MPI_Comm localComm = nullptr; + MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm)); + MpiComm::mutableLocalSession().mFreeComm = true; + MpiComm::mutableLocalSession() = MpiComm{localComm, false}; + TLLM_LOG_INFO("Refreshed the MPI local session"); +#endif // ENABLE_MULTI_DEVICE +} + +MpiComm::MpiComm(MPI_Comm g, bool freeComm) + : mComm{g} + , mFreeComm{freeComm} +{ + TLLM_CHECK(mComm != MPI_COMM_NULL); +} + +MpiComm::~MpiComm() noexcept +{ +#if ENABLE_MULTI_DEVICE + if (mFreeComm && mComm) + { + if (MPI_Comm_free(&mComm) != MPI_SUCCESS) + { + TLLM_LOG_ERROR("MPI_Comm_free failed"); + } + } +#endif // ENABLE_MULTI_DEVICE +} + +MpiComm::MpiComm(MpiComm&& comm) noexcept + : mComm{comm.mComm} + , mFreeComm{comm.mFreeComm} +{ + comm.mFreeComm = false; +} + +MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept +{ + this->~MpiComm(); + mComm = comm.mComm; + mFreeComm = comm.mFreeComm; + comm.mFreeComm = false; + return *this; +} + +MpiWaitThread::MpiWaitThread(std::string name, std::function funcWait, std::function funcSetup) + : mName{name.c_str()} + , mFuncWait{funcWait} + , mFuncSetup{funcSetup} +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + mThread = std::make_unique(&MpiWaitThread::sideThread, this); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +MpiWaitThread::~MpiWaitThread() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + waitStop(); + mShouldExit.store(true); + notifyStart(); + mThread->join(); + mThread.reset(nullptr); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} 
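A minimal usage sketch of the MpiComm wrappers above, included for orientation only: broadcastExample and its payload buffer are hypothetical names, not part of TensorRT-LLM, while world(), getRank(), bcast() and barrier() are the members implemented in this file. Both bcast() and barrier() are collective, so every rank of the communicator must reach the same calls with matching arguments.

// Hypothetical sketch, not part of the original file.
void broadcastExample()
{
    auto const& comm = tensorrt_llm::mpi::MpiComm::world(); // lazily initializes MPI on first use
    char payload[256] = {};
    if (comm.getRank() == 0)
    {
        payload[0] = 1; // rank 0 produces the data
    }
    // Collective calls: all ranks must participate.
    comm.bcast(payload, sizeof(payload), tensorrt_llm::mpi::MpiType::kBYTE, /*root=*/0);
    comm.barrier();
}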
+ +void MpiWaitThread::sideThread() +{ + if (mFuncSetup) + { + mFuncSetup(); + } + while (!mShouldExit.load()) + { + notifyStop(); + waitStart(); + mFuncWait(); + } +} + +void MpiWaitThread::waitStart() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::unique_lock lock(mMutex); + mCondVar.wait(lock, [this] { return mRunning; }); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::waitStop() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::unique_lock lock(mMutex); + mCondVar.wait(lock, [this] { return !mRunning; }); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::notifyStart() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::lock_guard lock(mMutex); + mRunning = true; + mCondVar.notify_one(); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::notifyStop() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::lock_guard lock(mMutex); + mRunning = false; + mCondVar.notify_one(); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +} // namespace tensorrt_llm::mpi diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h new file mode 100644 index 0000000000..0a9d51975a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h @@ -0,0 +1,46 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +namespace tensorrt_llm::common::nvtx +{ +inline nvtx3::color nextColor() +{ +#ifndef NVTX_DISABLE + constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00}, + nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}}; + constexpr auto numColors = kColors.size(); + + static thread_local std::size_t colorId = 0; + auto const color = kColors[colorId]; + colorId = colorId + 1 >= numColors ? 0 : colorId + 1; + return color; +#else + return nvtx3::color{0}; +#endif +} + +} // namespace tensorrt_llm::common::nvtx + +#define NVTX3_SCOPED_RANGE_WITH_NAME(range, name) \ + ::nvtx3::scoped_range range(::tensorrt_llm::common::nvtx::nextColor(), name) +#define NVTX3_SCOPED_RANGE(range) NVTX3_SCOPED_RANGE_WITH_NAME(range##_range, #range) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp new file mode 100644 index 0000000000..39aefda481 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp @@ -0,0 +1,323 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/common/mpiUtils.h" + +#include "cuda.h" +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#define FN_NAME __FUNCTION__ +#else +#define FN_NAME __func__ +#endif + +#if ENABLE_MULTI_DEVICE + +std::unordered_map* getDtypeMap() +{ + static std::unordered_map dtypeMap = {{nvinfer1::DataType::kFLOAT, ncclFloat32}, + {nvinfer1::DataType::kHALF, ncclFloat16}, {nvinfer1::DataType::kBF16, ncclBfloat16}}; + return &dtypeMap; +} + +namespace +{ + +// Get NCCL unique ID for a group of ranks. +ncclUniqueId getUniqueId(std::set const& group) noexcept +{ + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); + ncclUniqueId id; + if (rank == *group.begin()) + { + NCCLCHECK(ncclGetUniqueId(&id)); + for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it) + { + COMM_SESSION.sendValue(id, *it, 0); + } + } + else + { + COMM_SESSION.recvValue(id, *group.begin(), 0); + } + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); + return id; +} +} // namespace + +std::shared_ptr getComm(std::set const& group) +{ + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); + static std::map, std::shared_ptr> commMap; + static std::mutex mutex; + std::lock_guard lock(mutex); + std::ostringstream oss; + int index = 0; + for (auto const& rank : group) + { + if (index != 0) + { + oss << ","; + } + oss << rank; + index++; + } + auto groupStr = oss.str(); + auto it = commMap.find(group); + if (it != commMap.end()) + { + auto ncclComm = it->second; + TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank); + return ncclComm; + } + + TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank); + ncclUniqueId id = getUniqueId(group); + int groupRank = 0; + for (auto const& currentRank : group) + { + if (rank == currentRank) + break; + ++groupRank; + } + TLLM_CHECK(groupRank < group.size()); + std::shared_ptr ncclComm(new ncclComm_t, + [](ncclComm_t* comm) + { + ncclCommDestroy(*comm); + delete comm; + }); + NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank)); + commMap[group] = ncclComm; + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); + return ncclComm; +} +#endif // ENABLE_MULTI_DEVICE + +void const* tensorrt_llm::common::getCommSessionHandle() +{ +#if ENABLE_MULTI_DEVICE + return &COMM_SESSION; +#else + return nullptr; +#endif // ENABLE_MULTI_DEVICE +} + +namespace +{ + +// Get current cuda context, a default context will be created if there is no context. 
+inline CUcontext getCurrentCudaCtx() +{ + CUcontext ctx{}; + CUresult err = cuCtxGetCurrent(&ctx); + if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr) + { + TLLM_CUDA_CHECK(cudaFree(nullptr)); + err = cuCtxGetCurrent(&ctx); + } + TLLM_CHECK(err == CUDA_SUCCESS); + return ctx; +} + +// Helper to create per-cuda-context singleton managed by std::shared_ptr. +// Unlike conventional singletons, singleton created with this will be released +// when not needed, instead of on process exit. +// Objects of this class shall always be declared static / global, and shall never own CUDA +// resources. +template +class PerCudaCtxSingletonCreator +{ +public: + using CreatorFunc = std::function()>; + using DeleterFunc = std::function; + + // creator returning std::unique_ptr is by design. + // It forces separation of memory for T and memory for control blocks. + // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. + // creator itself must not own CUDA resources. Only the object it creates can. + PerCudaCtxSingletonCreator(CreatorFunc creator, DeleterFunc deleter) + : mCreator{std::move(creator)} + , mDeleter{std::move(deleter)} + { + } + + std::shared_ptr operator()() + { + std::lock_guard lk{mMutex}; + CUcontext ctx{getCurrentCudaCtx()}; + std::shared_ptr result = mObservers[ctx].lock(); + if (result == nullptr) + { + // Create the resource and register with an observer. + result = std::shared_ptr{mCreator().release(), + [this, ctx](T* obj) + { + if (obj == nullptr) + { + return; + } + mDeleter(obj); + + // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts + // frequently. + std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. + std::lock_guard lk{mMutex}; + // Must check observer again because another thread may created new instance for this ctx just + // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is + // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic + // operation, and the observer may be changed to observe another instance. + observedObjHolder = mObservers.at(ctx).lock(); + if (observedObjHolder == nullptr) + { + mObservers.erase(ctx); + } + }}; + mObservers.at(ctx) = result; + } + return result; + } + +private: + CreatorFunc mCreator; + DeleterFunc mDeleter; + mutable std::mutex mMutex; + // CUDA resources are per-context. + std::unordered_map> mObservers; +}; + +template +class PerThreadSingletonCreator +{ +public: + using CreatorFunc = std::function()>; + using DeleterFunc = std::function; + + // creator returning std::unique_ptr is by design. + // It forces separation of memory for T and memory for control blocks. + // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. + // creator itself must not own CUDA resources. Only the object it creates can. + PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter) + : mCreator{std::move(creator)} + , mDeleter{std::move(deleter)} + { + } + + std::shared_ptr operator()() + { + std::lock_guard lk{mMutex}; + + std::thread::id thread = std::this_thread::get_id(); + std::shared_ptr result = mObservers[thread].lock(); + + if (result == nullptr) + { + // Create the resource and register with an observer. 
+ result = std::shared_ptr{mCreator().release(), + [this, thread](T* obj) + { + if (obj == nullptr) + { + return; + } + mDeleter(obj); + + // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts + // frequently. + std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. + std::lock_guard lk{mMutex}; + // Must check observer again because another thread may created new instance for this ctx just + // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is + // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic + // operation, and the observer may be changed to observe another instance. + observedObjHolder = mObservers.at(thread).lock(); + if (observedObjHolder == nullptr) + { + mObservers.erase(thread); + } + }}; + mObservers.at(thread) = result; + } + return result; + } + +private: + CreatorFunc mCreator; + DeleterFunc mDeleter; + mutable std::mutex mMutex; + // CUDA resources are per-thread. + std::unordered_map> mObservers; +}; + +} // namespace + +std::shared_ptr getCublasHandle() +{ + static PerThreadSingletonCreator creator( + []() -> auto + { + auto handle = std::unique_ptr(new cublasHandle_t); + TLLM_CUDA_CHECK(cublasCreate(handle.get())); + return handle; + }, + [](cublasHandle_t* handle) + { + TLLM_CUDA_CHECK(cublasDestroy(*handle)); + delete handle; + }); + return creator(); +} + +std::shared_ptr getCublasLtHandle() +{ + static PerThreadSingletonCreator creator( + []() -> auto + { + auto handle = std::unique_ptr(new cublasLtHandle_t); + TLLM_CUDA_CHECK(cublasLtCreate(handle.get())); + return handle; + }, + [](cublasLtHandle_t* handle) + { + TLLM_CUDA_CHECK(cublasLtDestroy(*handle)); + delete handle; + }); + return creator(); +} + +std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, + std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace) +{ + static PerThreadSingletonCreator creator( + [cublasHandle, cublasltHandle, stream, workspace]() -> auto + { + auto wrapper = std::unique_ptr( + new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace)); + return wrapper; + }, + [](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; }); + return creator(); +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h new file mode 100644 index 0000000000..4e278e5cf2 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h @@ -0,0 +1,215 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/common/cublasMMWrapper.h" +#include "tensorrt_llm/common/workspace.h" + +#include +#include +#include +#include +#if ENABLE_MULTI_DEVICE +#include +#endif // ENABLE_MULTI_DEVICE + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +// Write values into buffer +template +void write(char*& buffer, T const& val) +{ + std::memcpy(buffer, &val, sizeof(T)); + buffer += sizeof(T); +} + +// Read values from buffer +template +void read(char const*& buffer, T& val) +{ + std::memcpy(&val, buffer, sizeof(T)); + buffer += sizeof(T); +} + +// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members. +// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and +// your clone() implementation is responsible for initializing such data members. +// With this we can simplify clone() implementation when there are many data members including at least one unique_ptr. +template > +class UniqPtrWNullCopy : public std::unique_ptr +{ +public: + using std::unique_ptr::unique_ptr; + + // for compatibility with std::make_unique + explicit UniqPtrWNullCopy(std::unique_ptr&& src) + : std::unique_ptr::unique_ptr{std::move(src)} + { + } + + // copy constructor produces nullptr + UniqPtrWNullCopy(UniqPtrWNullCopy const&) + : std::unique_ptr::unique_ptr{} + { + } +}; + +// for testing only +void const* getCommSessionHandle(); +} // namespace tensorrt_llm::common + +inline bool isBuilding() +{ + auto constexpr key = "IS_BUILDING"; + auto const val = getenv(key); + return val != nullptr && std::string(val) == "1"; +} + +#if ENABLE_MULTI_DEVICE +#define NCCLCHECK(cmd) \ + do \ + { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) \ + { \ + printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +std::unordered_map* getDtypeMap(); + +std::shared_ptr getComm(std::set const& group); + +#endif // ENABLE_MULTI_DEVICE + +//! To save GPU memory, all the plugins share the same cublas and cublasLt handle globally. +//! Get cublas and cublasLt handle for current cuda context +std::shared_ptr getCublasHandle(); +std::shared_ptr getCublasLtHandle(); +std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, + std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace); + +#ifndef DEBUG + +#define PLUGIN_CHECK(status) \ + do \ + { \ + if (status != 0) \ + abort(); \ + } while (0) + +#define ASSERT_PARAM(exp) \ + do \ + { \ + if (!(exp)) \ + return STATUS_BAD_PARAM; \ + } while (0) + +#define ASSERT_FAILURE(exp) \ + do \ + { \ + if (!(exp)) \ + return STATUS_FAILURE; \ + } while (0) + +#define CSC(call, err) \ + do \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaStatus != cudaSuccess) \ + { \ + return err; \ + } \ + } while (0) + +#define DEBUG_PRINTF(...) 
\ + do \ + { \ + } while (0) + +#else + +#define ASSERT_PARAM(exp) \ + do \ + { \ + if (!(exp)) \ + { \ + fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ + return STATUS_BAD_PARAM; \ + } \ + } while (0) + +#define ASSERT_FAILURE(exp) \ + do \ + { \ + if (!(exp)) \ + { \ + fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ + return STATUS_FAILURE; \ + } \ + } while (0) + +#define CSC(call, err) \ + do \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaStatus != cudaSuccess) \ + { \ + printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ + return err; \ + } \ + } while (0) + +#define PLUGIN_CHECK(status) \ + { \ + if (status != 0) \ + { \ + DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ + abort(); \ + } \ + } + +#define DEBUG_PRINTF(...) \ + do \ + { \ + printf(__VA_ARGS__); \ + } while (0) + +#endif // DEBUG + +#define NVML_CHECK(cmd) \ + do \ + { \ + nvmlReturn_t r = cmd; \ + if (r != NVML_SUCCESS) \ + { \ + printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh new file mode 100644 index 0000000000..a228d3f9fc --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +template +struct QuantTypeStaticVals; + +template <> +struct QuantTypeStaticVals +{ + static constexpr float MAX_VAL = 127.f; + static constexpr float MIN_SCALING_FACTOR = 0.f; + static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX; +}; + +#ifdef ENABLE_FP8 + +template <> +struct QuantTypeStaticVals<__nv_fp8_e4m3> +{ + static constexpr float MAX_VAL = 448.f; + // Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720 + static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f); + static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f); +}; + +#endif // ENABLE_FP8 + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh new file mode 100644 index 0000000000..c5a4fe0e24 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh @@ -0,0 +1,399 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) +#include +#else +#include +#endif +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + +namespace tensorrt_llm +{ +namespace common +{ + +template +struct BytesToType; + +template <> +struct BytesToType<1> +{ + using type = uint8_t; +}; + +template <> +struct BytesToType<2> +{ + using type = uint16_t; +}; + +template <> +struct BytesToType<4> +{ + using type = uint32_t; +}; + +template <> +struct BytesToType<8> +{ + using type = uint64_t; +}; + +template <> +struct BytesToType<16> +{ + using type = float4; +}; + +template +__device__ inline void copy(void const* local, void* data) +{ + using T = typename BytesToType::type; + + T const* in = static_cast(local); + T* out = static_cast(data); + *out = *in; +} + +static float constexpr HALF_FLT_MAX = 65504.F; +#define FINAL_MASK 0xffffffff + +template +__inline__ __device__ T warpReduceSum(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80 + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T) (0.0f); + val = warpReduceSum(val); + + return val; +} + +template +__inline__ __device__ T warpReduceMax(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); + return val; +} + +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + + return val; +} + +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockAllReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. 
to prevent + // blockDim.x is not divided by 32 + val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + + return val; +} + +template +__inline__ __device__ T warpReduceSumV2(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T) (0.0f); +} + +template +__inline__ __device__ T blockReduceSumV2(T* val) +{ + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSumV2(val); + + if (lane == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) + { + val[i] = is_mask ? shared[i][lane] : (T) (0.0f); + } + warpReduceSumV2(val); + return (T) 0.0f; +} + +template +__inline__ __device__ T warpReduceMaxV2(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + } + return (T) (0.0f); +} + +template +__inline__ __device__ T blockReduceMaxV2(T* val) +{ + static __shared__ T shared[32][NUM]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + warpReduceMaxV2(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[wid][i] = val[i]; + } + } + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) + { + val[i] = is_mask ? shared[lane][i] : (T) -1e20f; + } + warpReduceMaxV2(val); + + return (T) 0.0f; +} + +template +__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm) +{ + cg::thread_block cta = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta); + + int const tid = cta.thread_rank(); + int const blockz = blockDim.x; + for (int i = 0; i < NUM; i++) + { +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) + cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus()); +#else + // TODO Add implementation here + if (threadIdx.x == 0 && blockIdx.x == 0) + { + printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n"); + assert(false); + } +#endif + } + cg::sync(cta); + if (tid == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + float beta = 0.0f; + for (int j = 0; j < blockz; j += 32) + { + beta += cgBlockReduceSumElements_shm[i * blockz + j]; + } + element_list[i] = beta; + } + } +} + +template +struct TopK +{ + int p[MAX_K]; // index, being -1 at the tail if the array is not full + T u[MAX_K]; // value in descend order, being -MAX_T_VAL if the element is invalid + + __device__ __forceinline__ void insert(T const elem, int const elem_id) + { + if (elem_id < 0) + { + return; + } + // Condition of updating the array + // 1. array is not full + // 2. elem is greater than the smallest (last) element in the array + // 3. 
elem is equal to the smallest (last) element in the array but its elem_id is smaller + bool const need_update + = (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]); + if (!need_update) + { + return; + } + // Find suitable index for the new element + int i; + for (i = MAX_K - 2; i >= 0; --i) + { + bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]); + if (!need_decrease) + break; + } + // Move elements to correct positions + for (int k = MAX_K - 2; k >= i; --k) + { + p[k + 1] = p[k]; + u[k + 1] = u[k]; + } + p[i] = elem_id; + u[i] = elem; + } + + __device__ __forceinline__ void init() + { + T const MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; + for (int i = 0; i < MAX_K; i++) + { + p[i] = -1; + u[i] = -MAX_T_VAL; + } + } +}; + +template +__device__ __forceinline__ TopK reduce_topk_op(TopK const& a, TopK const& b) +{ + TopK res = a; + for (int i = 0; i < MAX_K; ++i) + res.insert(b.u[i], b.p[i]); + return res; +} + +template +struct TopK_2 +{ + int p = -1; + T u = -((std::is_same::value) ? HALF_FLT_MAX : FLT_MAX); + + __device__ __forceinline__ void insert(T elem, int elem_id) + { + if (elem > u) + { + u = elem; + p = elem_id; + } + } + + __device__ __forceinline__ void init() + { + u = -((std::is_same::value) ? HALF_FLT_MAX : FLT_MAX); + p = -1; + } +}; + +template +__device__ __forceinline__ TopK_2 reduce_topk_op_2(TopK_2 const& a, TopK_2 const& b) +{ + return a.u > b.u ? a : b; +} + +template +__device__ __forceinline__ T clamp_inf_for_half(float const input) +{ + return input; +} + +template <> +__device__ __forceinline__ half clamp_inf_for_half(float const input) +{ + // clamp inf values to enable fp16 training + return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000); +} + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h new file mode 100644 index 0000000000..9cda9fa0d4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace tensorrt_llm::common::stl_utils +{ + +template +constexpr TOutputIt basicInclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, TBinOp op) +{ + if (first != last) + { + auto val = *first; + while (true) + { + *dFirst = val; + ++dFirst; + ++first; + if (first == last) + { + break; + } + val = op(std::move(val), *first); + } + } + return dFirst; +} + +template +constexpr TOutputIt inclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst) +{ +#if defined(__GNUC__) && __GNUC__ <= 8 + return basicInclusiveScan(first, last, dFirst, std::plus<>{}); +#else + return std::inclusive_scan(first, last, dFirst); +#endif +} + +template +constexpr TOutputIt basicExclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init, TBinOp op) +{ + if (first != last) + { + while (true) + { + T tmp{op(init, *first)}; + *dFirst = init; + ++dFirst; + ++first; + if (first == last) + { + break; + } + init = std::move(tmp); + } + } + return dFirst; +} + +template +constexpr TOutputIt exclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init) +{ +#if defined(__GNUC__) && __GNUC__ <= 8 + return basicExclusiveScan(first, last, dFirst, std::move(init), std::plus<>{}); +#else + return std::exclusive_scan(first, last, dFirst, std::move(init)); +#endif +} + +template +struct HasOperatorOutput : std::false_type +{ +}; + +template +struct HasOperatorOutput() << std::declval()))>> + : std::true_type +{ +}; + +template +std::string toString(T const& t, typename std::enable_if_t::value, int> = 0) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +template +std::string toString(std::optional const& t, typename std::enable_if_t::value, int> = 0) +{ + std::ostringstream oss; + if (t) + { + oss << t.value(); + } + else + { + oss << "None"; + } + return oss.str(); +} + +} // namespace tensorrt_llm::common::stl_utils diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp new file mode 100644 index 0000000000..f1c6f88b43 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/stringUtils.h" +#include "tensorrt_llm/common/assert.h" + +#include +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +namespace +{ +std::string vformat(char const* fmt, va_list args) +{ + va_list args0; + va_copy(args0, args); + auto const size = vsnprintf(nullptr, 0, fmt, args0); + if (size <= 0) + return ""; + + std::string stringBuf(size, char{}); + auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args); + + TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno))); + + return stringBuf; +} + +} // namespace + +std::string fmtstr(char const* format, ...) 
+{ + va_list args; + va_start(args, format); + std::string result = vformat(format, args); + va_end(args); + return result; +}; + +std::unordered_set str2set(std::string const& input, char delimiter) +{ + std::unordered_set values; + if (!input.empty()) + { + std::stringstream valStream(input); + std::string val; + while (std::getline(valStream, val, delimiter)) + { + if (!val.empty()) + { + values.insert(val); + } + } + } + return values; +}; + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp new file mode 100644 index 0000000000..c00041abda --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "tensorrt_llm/common/timestampUtils.h" + +namespace tensorrt_llm::common +{ + +std::string getCurrentTimestamp() +{ + auto now = std::chrono::system_clock::now(); + auto now_t = std::chrono::system_clock::to_time_t(now); + auto tm = *std::localtime(&now_t); + + auto epoch_to_now = now.time_since_epoch(); + auto seconds = std::chrono::duration_cast(epoch_to_now); + auto us = std::chrono::duration_cast(epoch_to_now - seconds); + + std::ostringstream stream; + stream << std::put_time(&tm, "%m-%d-%Y %H:%M:%S"); + stream << "." << std::setfill('0') << std::setw(6) << us.count(); + return stream.str(); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h new file mode 100644 index 0000000000..f52f23028c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace tensorrt_llm::common +{ + +/// @brief Get the current timestamp in the format "MM-DD-YYYY HH:MM:SS:uuuuuu" +std::string getCurrentTimestamp(); + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp new file mode 100644 index 0000000000..b410613d05 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/tllmException.h" +#include "tensorrt_llm/common/stringUtils.h" + +#include +#if !defined(_MSC_VER) +#include +#include +#include +#endif +#include + +namespace tensorrt_llm::common +{ + +namespace +{ +int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2; +} + +#if !defined(_MSC_VER) + +TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) + : std::runtime_error{""} +{ + mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES); + auto const trace = getTrace(); + std::runtime_error::operator=( + std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())}); +} +#else +TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) + : mNbFrames{} + , std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)} +{ +} +#endif + +TllmException::~TllmException() noexcept = default; + +std::string TllmException::getTrace() const +{ +#if defined(_MSC_VER) + return ""; +#else + auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames); + std::ostringstream buf; + for (auto i = 1; i < mNbFrames; ++i) + { + Dl_info info; + if (dladdr(mCallstack[i], &info) && info.dli_sname) + { + auto const clearName = demangle(info.dli_sname); + buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(), + static_cast(mCallstack[i]) - static_cast(info.dli_saddr)); + } + else + { + buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]); + } + if (i < mNbFrames - 1) + buf << std::endl; + } + + if (mNbFrames == MAX_FRAMES) + buf << std::endl << "[truncated]"; + + std::free(trace); + return buf.str(); +#endif +} + +std::string TllmException::demangle(char const* name) +{ +#if defined(_MSC_VER) + return name; +#else + std::string clearName{name}; + auto status = -1; + auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status); + if (status == 0) + { + clearName = demangled; + std::free(demangled); + } + return clearName; +#endif +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h b/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h new file mode 100644 index 0000000000..1406e82133 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace tensorrt_llm::common +{ + +std::uintptr_t constexpr kCudaMemAlign = 128; + +inline int8_t* alignPtr(int8_t* ptr, uintptr_t to) +{ + uintptr_t addr = (uintptr_t) ptr; + if (addr % to) + { + addr += to - addr % to; + } + return (int8_t*) addr; +} + +constexpr size_t alignSize(size_t size, size_t to) +{ + if ((size % to) != 0U) + { + size += to - size % to; + } + return size; +} + +inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment) +{ + uintptr_t addr = (uintptr_t) ptr; + addr += previousWorkspaceSize; + return alignPtr((int8_t*) addr, alignment); +} + +inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize) +{ + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign); +} + +inline int8_t* nextWorkspacePtr( + int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign) +{ + uintptr_t curr_offset = offset; + uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment; + int8_t* newptr = size == 0 ? nullptr : base + curr_offset; + offset = next_offset; + return newptr; +} + +inline int8_t* nextWorkspacePtrWithAlignment( + int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign) +{ + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment); +} + +inline size_t calculateTotalWorkspaceSize( + size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign) +{ + size_t total = 0; + for (int i = 0; i < count; i++) + { + total += workspaces[i]; + if (workspaces[i] % alignment) + { + total += alignment - (workspaces[i] % alignment); + } + } + return total; +} + +}; // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp new file mode 100644 index 0000000000..61a41031bf --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp @@ -0,0 +1,352 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +// Config + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10)) +#define CUTE_ARCH_RED_F16_SM70_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +#define CUTE_ARCH_RED_VEC_SM90_ENABLED +#define CUTE_ARCH_RED_BF16_SM90_ENABLED +#endif + +namespace cute +{ + +////////////////////////////////// +// Wrapper around CUDA's atomicAdd +////////////////////////////////// + +template +struct TypedAtomicAdd +{ + using SRegisters = T[1]; + using DRegisters = T[1]; + + CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) + { + atomicAdd(&dst, src); + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// +// F16 ADD PTX +////////////////////////////////// + +struct SM70_RED_ADD_NOFTZ_F16 +{ + using SRegisters = uint16_t[1]; + using DRegisters = uint16_t[1]; + + CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) + asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +struct SM70_RED_ADD_NOFTZ_F16x2 +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) + asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx 
(one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +struct SM90_RED_ADD_NOFTZ_F16x2_V2 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint64_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) + asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +struct SM90_RED_ADD_NOFTZ_F16x2_V4 +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void copy( + uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) + asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), + "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// +// BF16 ADD PTX +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16 +{ + using SRegisters = uint16_t[1]; + using DRegisters = uint16_t[1]; + + CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16x2 +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from 
(dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16x2_V2 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint64_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16x2_V4 +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void copy( + uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), + "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +} // end namespace cute diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h new file mode 100644 index 0000000000..2362da4f7f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing architecture support for multiply-add operations +*/ + +#pragma once +#include "cutlass_extensions/weight_only_quant_op.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace arch +{ + +// Tag which triggers MMA which will trigger +struct OpMultiplyAddDequantizeInterleavedBToA; + +/* + Below we have extra tags to signal what kind of dequantization we want to do + (per col, scale only fine grained, finegrained with zero). This still lets us + the existing template infrastructure (incl. that in CUTLASS). However, we + split out the template below into OpMultiplyAddDequantizeInterleavedBToA along + with the quantization op before instantiating the GEMM pieces. + + Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of + code we need to duplicate. + */ +struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; + +// The default just forwards the original operator +template +struct TagOperator +{ + using TaggedOperator = MmaOp; +}; + +// Specializations below attach more information to the operator +template <> +struct TagOperator +{ + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +}; + +template <> +struct TagOperator +{ + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +}; + +template <> +struct TagOperator +{ + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; +}; + +// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original +// operator + the extra information. If no extra info was tagged, the dequant op per column scaling +// as a default. 
+template +struct DetagOperator +{ + using Operator = TaggedMmaOp; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator +{ + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator +{ + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +}; + +template <> +struct DetagOperator +{ + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +}; + +} // namespace arch +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h new file mode 100644 index 0000000000..c83a9a074d --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "cutlass/device_kernel.h" +#include "tensorrt_llm/common/cudaUtils.h" + +namespace tensorrt_llm +{ +namespace cutlass_extensions +{ + +template +inline int compute_occupancy_for_kernel() +{ + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size > (48 << 10)) + { + cudaFuncAttributes attr; + int device = 0; + int max_smem_per_block = 0; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + if constexpr (enable_cutlass_3x) + { + tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); + } + else + { + tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel)); + } + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) + { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) + // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this + // configuration. 
+ return 0; + } + + if constexpr (enable_cutlass_3x) + { + tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute( + cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + else + { + tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute( + cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + } + + int max_active_blocks = -1; + if constexpr (enable_cutlass_3x) + { + tensorrt_llm::common::check_cuda_error( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel, + 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); + } + else + { + tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); + } + + return max_active_blocks; +} + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp new file mode 100644 index 0000000000..bba25ec23a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp @@ -0,0 +1,550 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" + +#include "cute/numeric/numeric_types.hpp" +#include "cute/tensor.hpp" +#include "cutlass/trace.h" + +#include "cutlass_extensions/arch/copy_red_global.hpp" +#include "cutlass_extensions/util/gather_tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class EpilogueMoeFusedFinalize +{ +public: + using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized; + using DispatchPolicy = PtrArrayNoSmemWarpSpecialized; + + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementIntermediate = typename ThreadEpilogueOp::ElementD; + + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + + static_assert(!is_same_v, "Stride C must be a pointer"); + static_assert(is_same_v, "Stride D must not be a pointer"); + + using CopyAtomR2S = Copy_Atom; + using CopyAtomS2R = Copy_Atom; + using CopyAtomR2G = Copy_Atom; + static constexpr int AlignmentD = CopyAtomR2G::NumValSrc; + + using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{})); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + + struct SharedStorage + { + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D; + }; + + struct TensorMapStorage + { + }; + + struct Arguments + { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C{}; + StrideC dC{}; + ElementD* ptr_D{}; + StrideD dD{}; + ElementBias const* ptr_bias; + StrideBias dBias{}; + ElementScale const* ptr_scale; + StrideScale dScale{}; + int64_t const* group_offset{}; + int32_t const* scatter_index{}; + cutlass::FastDivmod num_rows_in_final_output; + }; + + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace) + { + return args; + } + + template + static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) + { + return 0; + } + + template + static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, + void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) + { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement( + [[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args) + { + bool implementable = true; + if (problem_shape.is_host_problem_shape_available()) + { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shape.groups(); i++) + { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, N, L), InternalStrideD{}); + } + } + + if 
(!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global " + "reduction instruction.\n"); + } + return implementable; + } + + CUTLASS_HOST_DEVICE + EpilogueMoeFusedFinalize(Params const& params_) + : params(params_) + { + } + + CUTLASS_DEVICE + bool is_source_needed() + { + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. + return params.ptr_C != nullptr + && (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0); + } + + template + CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, cute::Tensor const& accumulators, TiledMma tiled_mma, + ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + auto synchronize = [&]() + { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto mma_tile_m = tile_size<0>(tiled_mma); + auto mma_tile_n = tile_size<1>(tiled_mma); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + + // Batches are managed by using appropriate pointers to C and D matrices + int32_t const mock_L = 1; + int32_t const mock_l_coord = 0; + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. + ThreadEpilogueOp epilogue_op(params.thread, l_coord); + + SharedStorage& storage = *reinterpret_cast(smem_buf); + + Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{}); + Tensor sD = as_position_independent_swizzle_tensor(sD_); + + // Function to scatter output rows + auto& num_rows = params.num_rows_in_final_output; + auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord])); + auto get_scatter_idx = [&](auto i) + { + auto scatter = read_scatter_map(i); + int quot, rem; + num_rows(quot, rem, scatter); + return rem; + }; + + // Represent the full output tensor + ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr; + auto dC = epilogue_op.is_source_needed() ? 
params.dC[l_coord] : InternalStrideC{}; + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l) + Tensor mD_mnl = make_gather_tensor( + make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l) + + // Use fake shape for bias, it doesn't matter + bool const is_bias_needed = params.ptr_bias != nullptr; + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias); + Tensor mScale_mnl = make_tensor( + make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale); + + Tensor gC_mnl + = local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl + = local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) + + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor gBias_mnl + = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gScale_mnl + = local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N) + Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N) + + Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Get the smallest tiled copy we can use to retile the accumulators + TiledCopy tiled_copy_C_atom + = make_tiled_copy_C_atom(Copy_Atom{}, tiled_mma); + TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom); + + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N) + Tensor tRS_rD = make_tensor(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N) + + // Make a tiled copy vectorized along major direction of D + auto tiled_s2r = [&]() + { + if constexpr (cutlass::gemm::detail::is_k_major()) + { + constexpr int NumThreadsMajor = epi_tile_n / AlignmentD; + constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; + return make_tiled_copy(CopyAtomS2R{}, + Layout, Int>, Stride, _1>>{}, + Layout>>{}); + } + else if constexpr (cutlass::gemm::detail::is_mn_major()) + { + constexpr int NumThreadsMajor = epi_tile_m / AlignmentD; + constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; + return make_tiled_copy(CopyAtomS2R{}, + Layout, Int>, Stride<_1, Int>>{}, + Layout, _1>>{}); + } + else + { + static_assert(cute::is_void_v, "Unsupported D gmem layout."); + } + }(); + + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // 
((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + + // Allocate intermediate registers for a single subtile + Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rBias = make_tensor(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rScale = make_tensor(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) + + // Make an identity coordinate tensor for predicating our output MN tile + Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + + // epilogue subtile loop + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) + { + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) + { + int mma_m = (epi_m * epi_tile_m) / mma_tile_m; + int mma_n = (epi_n * epi_tile_n) / mma_tile_n; + Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n); + + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rD); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v) + { + tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v); + } + + copy(tiled_r2s, tRS_rD, tRS_sD); + synchronize(); + + copy(tiled_s2r, tSR_sD, tSR_rD); + synchronize(); + + Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n); + Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n); + Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n); + Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n); + Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n); + + if (epilogue_op.is_source_needed()) + { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_rD); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_rD); ++n) + { + if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) + { + copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n)); + if (is_bias_needed) + { + copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); + } + copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rD); ++i) + { + auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n)); + if (is_bias_needed) + { + epi_value += static_cast(tSR_rBias(i, m, n)); + } + tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); + } + copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); + } + } + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_rD); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_rD); ++n) + { + if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) + { + if (is_bias_needed) + { + copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); + } + copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rD); ++i) + { + auto epi_value = epilogue_op(tSR_rD(i, m, n)); + if (is_bias_needed) + { + epi_value += static_cast(tSR_rBias(i, m, n)); + } + tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); + } + copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); + } + } + } + } + } + } + } + +private: + Params params; +}; + +namespace detail +{ + +template +constexpr auto get_vectorized_atomic_add_op() 
+{ + using namespace cute; + + auto constexpr MaxVecSize = size(MaxVec{}); + + if constexpr (is_same_v) + { + if constexpr (MaxVecSize >= 8) + { + return SM90_RED_ADD_NOFTZ_F16x2_V4{}; + } + else if constexpr (MaxVecSize >= 4) + { + return SM90_RED_ADD_NOFTZ_F16x2_V2{}; + } + else if constexpr (MaxVecSize >= 2) + { + return SM70_RED_ADD_NOFTZ_F16x2{}; + } + else + { + return SM70_RED_ADD_NOFTZ_F16{}; + } + } + else if constexpr (is_same_v) + { + if constexpr (MaxVecSize >= 8) + { + return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; + } + else if constexpr (MaxVecSize >= 4) + { + return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; + } + else if constexpr (MaxVecSize >= 2) + { + return SM90_RED_ADD_NOFTZ_BF16x2{}; + } + else + { + return SM90_RED_ADD_NOFTZ_BF16{}; + } + } + else + { + // non-vectorized atomic add for all other types until supported + return TypedAtomicAdd{}; + } +} + +} // namespace detail + +template +struct EpilogueMoeFusedFinalizeBuilder +{ + + // assuming cooperative kernel schedule + using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{})); + using EpilogueTile = Shape<_128, EpiTileN>; + + // Output of linear combination is ElementCompute instead of ElementD + // since we will be doing more computate on it, no need to cast yet. + using ThreadEpilogueOp + = cutlass::epilogue::thread::LinearCombination; + + using SmemLayoutAtomD + = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()); + using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator()); + using CopyAtomS2R = DefaultCopy; + using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op()); + + template + struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter + { + // We need to override this one using declaration because otherwise we double up on the smem + using TensorMapStorage = typename EpilogueOp::TensorMapStorage; + + using Base = detail::Sm90TmaWarpSpecializedAdapter; + + CUTLASS_HOST_DEVICE + Sm90TmaWarpSpecializedAdapterWithSmemStorage( + typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors) + : Base(params) + { + } + + // These functions depend on the type of TensorMapStorage + template + CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch) + { + } + + template + CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate) + { + } + }; + + using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage< + EpilogueMoeFusedFinalize>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h new file mode 100644 index 0000000000..f3c622b88a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h @@ -0,0 +1,105 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination with a maximum operation used by epilogues. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace thread +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float copysignf_pos(float a, float b) +{ + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) +{ +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct GELU_taylor +{ + static bool const kIsHeavy = true; + + CUTLASS_DEVICE + float operator()(float const& z) const + { + + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); + + return float(cutlass::constants::half() * z + * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const + { + return this->operator()(scalar); + } +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h new file mode 100644 index 0000000000..d3d4d0a45a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -0,0 +1,352 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. + + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_conversion.h" +#include "tensorrt_llm/common/quantization.h" + +namespace tk = tensorrt_llm::common; + +namespace cutlass +{ +namespace epilogue +{ +namespace threadblock +{ + +template +class EpilogueVisitorPerRowPerCol +{ +public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments + { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments() + : batch_stride_alpha(0) + , batch_stride_C(0) + , batch_stride_D(0) + { + } + + Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_) + , batch_stride_alpha(0) + , batch_stride_C(0) + , batch_stride_D(0) + { + } + + Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, + int64_t batch_stride_C_, int64_t batch_stride_D_) + : elementwise(elementwise_) + , batch_stride_alpha(batch_stride_alpha_) + , batch_stride_C(batch_stride_C_) + , batch_stride_D(batch_stride_D_) + { + } + }; + + struct Params + { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t 
batch_stride_D; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args) + : elementwise(args.elementwise) + , batch_stride_alpha(args.batch_stride_alpha) + , batch_stride_C(args.batch_stride_C) + , batch_stride_D(args.batch_stride_D) + { + } + }; + + /// Shared storage + struct SharedStorage + { + }; + +private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + bool const per_token_quant_; + bool const per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + +public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + : params_(params) + , shared_storage_(shared_storage) + , extent_(problem_size) + , elementwise_(params.elementwise) + , per_token_quant_(quant_option.hasPerTokenScaling()) + , per_channel_quant_(quant_option.hasPerChannelScaling()) + , ptr_alpha_row_(ptr_alpha_row) + , ptr_alpha_col_(ptr_alpha_col) + , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset) + , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset) + , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) + , extent_real_(problem_size_real) + { + beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) + { + iterator_C_.clear_mask(); + } + + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) + { + element_alpha_col_ = *ptr_alpha_col_; + } + + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) + { + element_alpha_row_ = *ptr_alpha_row_; + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) + { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) + { + iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() + { + if (per_channel_quant_) + { + iterator_alpha_col_.load(fragment_alpha_col_); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) + { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) + { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) + { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) + { + int thread_offset_row + = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); + + arch::global_load( + element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); + } + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) + { + + NumericArrayConverter source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) + { + ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); + } + else + { + result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) + { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + +private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) + { + + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) + { + result[i] = accum[i] * (scale_col[i] * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, AlphaScaleElementType const& 
scale_col, AlphaScaleElementType const& scale_row) + { + + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) + { + result[i] = accum[i] * (scale_col * scale_row); + } + + return result; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h new file mode 100644 index 0000000000..6f26d79017 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -0,0 +1,282 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h + +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/platform/platform.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/linear_combination_gelu.h" +#include "cutlass/epilogue/thread/linear_combination_hardswish.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_relu0.h" +#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" + +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ + +/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. +template +struct DefaultIteratorsTensorOp +{ + using WarpTileIterator + = cutlass::epilogue::warp::TileIteratorTensorOpMixed; + + using SharedLoadIterator + = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; + + static int const kFragmentsPerIteration = 2; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator +/// +template +class SharedLoadIteratorMixed +{ +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + +private: + // + // Data members + // + + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx) + : stride_((ref.stride(0) / LoadType::kElements)) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] + += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const + { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) + { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) + { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) + { + + int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx + = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) + { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) + { + + int vector_idx + = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * 
kLoadsPerAccess); + + LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const + { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h new file mode 100644 index 0000000000..233d633a82 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h @@ -0,0 +1,141 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. 
+ * + */ + +#pragma once + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "cutlass_extensions/epilogue/thread/fused_activations.h" +#include + +namespace tensorrt_llm +{ +namespace cutlass_extensions +{ + +struct EpilogueOpBiasSilu +{ +}; + +struct EpilogueOpBiasReLU +{ +}; + +struct EpilogueOpBiasFtGelu +{ +}; + +struct EpilogueOpBias +{ +}; + +struct EpilogueOpDefaultSilu +{ +}; + +struct EpilogueOpDefaultReLU +{ +}; + +struct EpilogueOpDefaultFtGelu +{ +}; + +struct EpilogueOpDefault +{ +}; + +template +struct Epilogue +{ + static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); +}; + +constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl new file mode 100644 index 0000000000..593eca06e3 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl @@ -0,0 +1,221 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout stage_count) +{ + // 32 bytes to account for barriers etc. + constexpr int stage_barrier_bytes = 32; + constexpr int a_bits = static_cast(sizeof_bits::value); + constexpr int b_bits = static_cast(sizeof_bits::value); + constexpr int stage_bytes = [&]() -> int + { + if constexpr (SwapAB) + { + return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes; + } + else + { + return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes; + } + }(); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS +template class Activation, bool SwapAB> +struct CollectiveBuilderGated + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v) &¬ detail:: + is_use_rmem_A()>> +{ + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm + = (cute::is_same_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), + "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); + + // For fp32 types, map to tf32 MMA value type + using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t + || 
IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_gated(StageCountType{}); + using DispatchPolicy = cute::conditional_t, + /* For FP8 use a separate mainloop compared to other datatypes */ + cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMmaGated, + ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS +template class Activation, bool SwapAB> +struct CollectiveBuilderGated + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v>> +{ + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert( + detail::is_input_fp8(), "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsArrayOfPointersGemm + = (cute::is_same_v); + using AtomLayoutMNK + = cute::conditional_t + || IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_gated(StageCountType{}); + using DispatchPolicy = cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMmaGated, + ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; +}; + 
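The stage-count helper earlier in this file sizes the software pipeline from the shared-memory budget: whichever operand also carries the gate matrix (A when SwapAB is set, otherwise B) is stored twice per stage, and roughly 32 bytes per stage are reserved for the pipeline barriers. Below is a minimal standalone sketch of that arithmetic in plain C++; the tile extents, element widths, and names are illustrative assumptions, not the builder's real types.

    // Sketch of the per-stage shared-memory arithmetic behind
    // compute_stage_count_or_override_gated (illustrative inputs only).
    #include <cstdio>

    constexpr int gated_stage_count(int capacity_bytes,  // total smem budget per CTA
                                    int carveout_bytes,  // smem reserved for the epilogue
                                    int a_bits, int b_bits,
                                    int tile_m, int tile_n, int tile_k,
                                    bool swap_ab)
    {
        constexpr int barrier_bytes = 32;  // per-stage barrier bookkeeping, as in the source
        // The gated operand (A when swap_ab, else B) is staged twice: value + gate.
        int a_bytes = a_bits * tile_m * tile_k / 8 * (swap_ab ? 2 : 1);
        int b_bytes = b_bits * tile_n * tile_k / 8 * (swap_ab ? 1 : 2);
        int stage_bytes = a_bytes + b_bytes + barrier_bytes;
        return (capacity_bytes - carveout_bytes) / stage_bytes;
    }

    int main()
    {
        // e.g. a 128x128x64 fp16 tile against a ~228 KB SM90 smem budget, no carve-out
        std::printf("stages = %d\n",
                    gated_stage_count(228 * 1024, 0, 16, 16, 128, 128, 64, /*swap_ab=*/false));
        return 0;
    }

With these example numbers the budget yields four stages; the real builder feeds the result into the dispatch policy's stage count rather than printing it.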
+///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp new file mode 100644 index 0000000000..2f2422c991 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp" + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class Activation, + bool SwapAB = false, class Enable = void> +struct CollectiveBuilderGated +{ + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp new file mode 100644 index 0000000000..d850f36df5 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class Activation, bool SwapAB = false> +struct CollectiveMmaGated +{ + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000..dcba6ee637 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,642 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template class Activation_, bool SwapAB_> +struct CollectiveMmaGated, TileShape_, + ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> +{ + static constexpr bool isGated = true; + static constexpr bool SwapAB = SwapAB_; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using Activation = Activation_; + + using ElementAux = cute::conditional_t; + using ValTypeAux = cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
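The gated collective above declares a third operand alongside A and B: ElementAux and ValTypeAux alias whichever of the two the gate matrix shadows, selected by SwapAB. The net effect on the math is that every tile feeds two accumulators that share one operand. The following is a reference sketch of that data flow in plain C++; shapes, names, and the final combination are illustrative assumptions, not the warp-specialized code.

    // Reference data flow of the gated mainloop for one output element (m, n):
    // one accumulator for A*B and one for A*Gate, sharing the same A values.
    #include <vector>
    #include <cstddef>

    struct GatedAccum { float d0; float d1; };

    // A is MxK row-major; B and Gate are treated as NxK for simplicity.
    GatedAccum gated_dot(const std::vector<float>& A, const std::vector<float>& B,
                         const std::vector<float>& Gate, std::size_t K,
                         std::size_t m, std::size_t n)
    {
        GatedAccum acc{0.f, 0.f};
        for (std::size_t k = 0; k < K; ++k)
        {
            float a = A[m * K + k];
            acc.d0 += a * B[n * K + k];     // accum0: the "value" GEMM
            acc.d1 += a * Gate[n * K + k];  // accum1: the "gate" GEMM reusing the same A tile
        }
        return acc;  // a downstream epilogue would combine these, e.g. act(d0) * d1
    }

The epilogue that consumes the two accumulators lives outside these files; the scale_d0/scale_d1 members of Arguments suggest each accumulator is rescaled before the gating activation is applied.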
+ using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutAux = cute::conditional_t; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value + && cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + using InternalElementAux = cute::conditional_t; + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned> smem_Aux; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments + { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params + { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Aux = cute::conditional_t; + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Aux tma_load_aux; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) + { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = 
make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + if constexpr (SwapAB) + { + auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; + } + else + { + auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; + } + } + + template + static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) + { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + + if (!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes + = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) + / 8; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const + { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (SwapAB) + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + else + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) + { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) + { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id + = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) + : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAuxgAux = block_tma_aux.partition_S(gAux); + Tensor tAuxsAux = block_tma_aux.partition_D(sAux); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_aux = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) + { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) + { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + if constexpr (SwapAB) + { + mcast_mask_aux = mcast_mask_a; + } + else + { + mcast_mask_aux = mcast_mask_b; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) + { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) + { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, + FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, + Params const& mainloop_params) + { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + 
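The producer path above builds three 16-bit multicast masks before issuing its TMA loads; each set bit is the rank of a CTA in the cluster that should receive the same copy. The sketch below reproduces that mask construction in plain C++, assuming an illustrative 2x2 cluster and a column-major rank mapping (both assumptions; the source derives them from the dispatch policy's cluster shape).

    // Standalone sketch of the TMA multicast-mask construction in load().
    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const int cluster_m = 2, cluster_n = 2;  // illustrative 2x2 CTA cluster
        const int block_m = 1, block_n = 0;      // this CTA's coordinates within the cluster
        const bool swap_ab = false;              // which operand the gate shadows

        // Column-major CTA rank within the cluster (an assumption for this sketch).
        auto rank = [&](int m, int n) { return m + n * cluster_m; };

        std::uint16_t mcast_mask_a = 0, mcast_mask_b = 0;
        for (int n = 0; n < cluster_n; ++n)      // the A tile is multicast across the cluster's N extent
            mcast_mask_a |= std::uint16_t(1) << rank(block_m, n);
        for (int m = 0; m < cluster_m; ++m)      // the B tile is multicast across the cluster's M extent
            mcast_mask_b |= std::uint16_t(1) << rank(m, block_n);

        // As in the source, the gate operand reuses the mask of the operand it shadows.
        const std::uint16_t mcast_mask_aux = swap_ab ? mcast_mask_a : mcast_mask_b;

        std::printf("mask_a=0x%x mask_b=0x%x mask_aux=0x%x\n",
                    unsigned(mcast_mask_a), unsigned(mcast_mask_b), unsigned(mcast_mask_aux));
        return 0;
    }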
static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + auto tCsAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.partition_A(sAux); + } + else + { + return thread_mma.partition_B(sAux); + } + }(); + auto tCrAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.make_fragment_A(tCsAux); + } + else + { + return thread_mma.make_fragment_B(tCsAux); + } + }(); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) + { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } + else + { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, 
read_stage), tCrB(_, _, k_block, read_stage), accum0); + if constexpr (SwapAB) + { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); + } + else + { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); + if constexpr (SwapAB) + { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); + } + else + { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) + { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp new file mode 100644 index 0000000000..72c1adf293 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp @@ -0,0 +1,665 @@ 
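The file that follows is the FP8 variant of the same gated mainloop. Its main difference is accumulation: both accumulators are wrapped in GmmaFP8Accumulation and folded into full-precision storage every mma_promotion_interval MMAs, and can_implement requires that interval to be a multiple of 4 because each mainloop iteration issues four MMA instructions. Below is a plain C++ sketch of that promotion pattern, with illustrative names only.

    // Sketch of the periodic-promotion idea behind GmmaFP8Accumulation: keep a
    // short-lived partial accumulator and fold it into a long-lived fp32 accumulator
    // on a fixed interval, which bounds the error of low-precision accumulation.
    #include <vector>
    #include <cstddef>

    float fp8_style_accumulate(const std::vector<float>& products, std::size_t promotion_interval)
    {
        float main_acc = 0.f;     // long-lived fp32 accumulator
        float partial_acc = 0.f;  // short-lived accumulator, reset at every promotion
        std::size_t issued = 0;

        for (float p : products)
        {
            partial_acc += p;                        // work done by the MMA instructions
            if (++issued % promotion_interval == 0)  // promote on the interval boundary
            {
                main_acc += partial_acc;
                partial_acc = 0.f;
            }
        }
        return main_acc + partial_acc;               // promote whatever residue remains
    }

With the default interval of 4 from Arguments, this amounts to one promotion per mainloop iteration, per the comment in can_implement; larger intervals trade some accuracy headroom for fewer promotion steps.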
+/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template class Activation_, bool SwapAB_> +struct CollectiveMmaGated, TileShape_, + ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> +{ + static constexpr bool isGated = true; + static constexpr bool SwapAB = SwapAB_; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using 
SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using Activation = Activation_; + + using ElementAux = cute::conditional_t; + using ValTypeAux = cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutAux = cute::conditional_t; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value + && cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned> smem_Aux; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments + { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params + { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, 0), 
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Aux = cute::conditional_t; + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Aux tma_load_aux; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) + { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + if constexpr (SwapAB) + { + auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + } + else + { + auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + } + } + + template + static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) + { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA + * instructions. 
*/ + implementable = implementable && (args.mma_promotion_interval % 4 == 0); + + if (!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes + = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) + / 8; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const + { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (SwapAB) + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + else + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) + { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) + { + Tensor sA = 
make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id + = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) + : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAuxgAux = block_tma_aux.partition_S(gAux); + Tensor tAuxsAux = block_tma_aux.partition_D(sAux); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_aux = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) + { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) + { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + if constexpr (SwapAB) + { + mcast_mask_aux = mcast_mask_a; + } + else + { + mcast_mask_aux = mcast_mask_b; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void 
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) + { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) + { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, + FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, + Params const& mainloop_params) + { + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + auto tCsAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.partition_A(sAux); + } + else + { + return thread_mma.partition_B(sAux); + } + }(); + auto tCrAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.make_fragment_A(tCsAux); + } + else + { + return thread_mma.make_fragment_B(tCsAux); + } + }(); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) + { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } + else + { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in 
flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation0.prepare_if_needed()) + { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); + if constexpr (SwapAB) + { + cute::gemm( + tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); + } + else + { + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + accumulation0.promote_if_needed(); + accumulation1.promote_if_needed(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + if (accumulation0.prepare_if_needed()) + { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); + if constexpr (SwapAB) + { + cute::gemm( + tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); + } + else + { + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + + accumulation0.promote_if_needed(); + accumulation1.promote_if_needed(); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, 
done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation0.promote_residue_if_needed(); + accumulation1.promote_residue_if_needed(); + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) + { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h new file mode 100644 index 0000000000..2edd5a228b --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -0,0 +1,438 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
+ \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// #include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace device +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. + */ + +template +class GemmUniversalBaseCompat +{ +public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + +protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + +protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) + { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + int const kAlignK + = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) + { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + +public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} + + /// Determines whether the GEMM can execute the given problem. 
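The split-K fix-up in get_grid_shape_ above first rounds each K partition up to the operand access alignment and then recomputes how many partitions are actually launched, so the effective slice count can differ from the requested batch_count. A standalone sketch of that arithmetic follows; the problem size, alignment, and slice count are hypothetical, and ceil_div/round_up simply stand in for the CUTLASS helpers of the same name.

#include <cstdio>

static int ceil_div(int a, int b) { return (a + b - 1) / b; }
static int round_up(int a, int b) { return ceil_div(a, b) * b; }

int main()
{
    int k = 4096;        // problem K dimension (hypothetical)
    int batch_count = 3; // requested split-K partitions
    int kAlignK = 8;     // max(128 / sizeof_bits(A), 128 / sizeof_bits(B)) for 16-bit operands

    int gemm_k_size = round_up(ceil_div(k, batch_count), kAlignK); // 1368
    int grid_k = ceil_div(k, gemm_k_size);                         // 3 slices actually launched
    std::printf("gemm_k_size=%d grid_k=%d\n", gemm_k_size, grid_k);
    return 0;
}

With these numbers the requested 3-way split survives, but a smaller K or a coarser alignment can collapse it to fewer slices than were asked for.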
+ static Status can_implement(Arguments const& args) + { + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) + { + + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } + else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) + { + + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); + + return result; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + if (smem_size <= (48 << 10)) + { + + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); + + if (result == cudaSuccess) + { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } + else + { + + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); + + if (result != cudaSuccess) + { + + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + + return -1; + } + + if (smem_capacity < 0) + { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) + { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) + { + return -1; + } + + 
smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; + } + + CUTLASS_TRACE_HOST(" returning internal error"); + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) + { + + if (!workspace) + { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) + { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) + { + cudaError_t result + = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) + { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) + { + return run(stream); + } + + /// Runs the kernel using initialized state. 
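As a usage note, GemmUniversalBaseCompat is driven the same way as CUTLASS' regular device-layer GEMMs: check can_implement, size and allocate the workspace, then initialize and run on a stream. A hedged host-side sketch is below; MyGemmKernel, the caller-supplied arguments, and the stream are assumptions, and only member functions defined in this header plus the CUDA runtime are used.

#include <cstddef>
#include <cuda_runtime.h>
#include "cutlass_extensions/gemm/device/gemm_universal_base_compat.h"

// Hypothetical driver; MyGemmKernel is a placeholder kernel type supplied by the caller.
template <typename MyGemmKernel>
cutlass::Status run_compat_gemm(
    typename cutlass::gemm::device::GemmUniversalBaseCompat<MyGemmKernel>::Arguments const& args,
    cudaStream_t stream)
{
    using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<MyGemmKernel>;

    if (Gemm::can_implement(args) != cutlass::Status::kSuccess)
    {
        return cutlass::Status::kErrorNotSupported;
    }

    // Split-K modes need scratch memory; initialize() also zeroes it for serial split-K.
    void* workspace = nullptr;
    size_t workspace_bytes = Gemm::get_workspace_size(args);
    if (workspace_bytes && cudaMalloc(&workspace, workspace_bytes) != cudaSuccess)
    {
        return cutlass::Status::kErrorInternal;
    }

    Gemm gemm;
    cutlass::Status status = gemm.initialize(args, workspace, stream);
    if (status == cutlass::Status::kSuccess)
    {
        status = gemm.run(stream); // asynchronous launch on `stream`
    }

    cudaStreamSynchronize(stream); // make sure the kernel is done before freeing scratch
    if (workspace)
    {
        cudaFree(workspace);
    }
    return status;
}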
+ Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) + { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h new file mode 100644 index 0000000000..bfd3666b9c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h @@ -0,0 +1,542 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
+ \file + \brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace device +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, + int64_t* splitk_buffer_offsets) +{ + // in_tensor: [problem_idx, k_partition, hidden_size] + // Note that different requests of in_tensor might have different hidden_size (=m*n) + // so, we need to use splitk_buffer_offsets. + // out_tensor: problem_idx * [hidden_size] + + int const problem_idx = blockIdx.y; + GemmCoord problem = problem_sizes[problem_idx]; + int const hidden_size = problem.m() * problem.n(); + const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; + T_OUT* out_tensor_ = out_tensor[problem_idx]; + + for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) + { + float sum = 0.0f; + for (int k_idx = 0; k_idx < splitk; k_idx++) + { + sum += (float) in_tensor_[k_idx * hidden_size + i]; + } + out_tensor_[i] = (T_OUT) (sum); + } +} + +/// GEMM Grouped +template +class BaseSplitkGrouped +{ +public: + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using 
ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + +protected: + /// Kernel parameters object + typename BaseKernel::Params gemm_params_; + +private: + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) + { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) + { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) + { + cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const& args, int32_t tile_count, void* workspace) + { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute( + args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data()); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, std::vector const& indices) + { + // For now, simply create a copy of the data and then copy over to the original. + std::vector copy(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) + { + copy.at(i) = data[indices[i]]; + } + + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + +public: + /// Constructs the GEMM. + BaseSplitkGrouped() {} + + /// Determines whether the GEMM can execute the given problem. 
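The indexing in splitkReduction earlier in this header is compact enough to misread: partial accumulators live in one buffer laid out per problem as [split_k][m * n], with splitk_buffer_offsets giving each problem's starting element (in units of m * n). The plain CPU reference below mirrors that addressing; it is illustrative only, with element types fixed to float and std::vector containers as assumptions.

#include <cstdint>
#include <utility>
#include <vector>

void splitk_reduce_ref(std::vector<float*> const& out_tensors,      // one output pointer per problem
                       float const* in_tensor,                      // concatenated partial sums
                       std::vector<std::pair<int, int>> const& mn,  // (m, n) per problem
                       int splitk,
                       std::vector<int64_t> const& splitk_buffer_offsets)
{
    for (size_t p = 0; p < out_tensors.size(); ++p)
    {
        int const hidden_size = mn[p].first * mn[p].second;
        // Problem p's block starts at offset[p] * splitk elements, as in the kernel.
        float const* in = in_tensor + splitk_buffer_offsets[p] * splitk;
        for (int i = 0; i < hidden_size; ++i)
        {
            float sum = 0.f;
            for (int k = 0; k < splitk; ++k)
            {
                sum += in[k * hidden_size + i];
            }
            out_tensors[p][i] = sum;
        }
    }
}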
+ static Status can_implement(Arguments const& args) + { + + return BaseKernel::can_implement(args); + } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) + { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const& args) + { + if (args.host_problem_sizes == nullptr) + { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; + } + + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) + { + size_t total_mn = 0; + for (int i = 0; i < args.problem_count; i++) + { + total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); + } + size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) + { + workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( + args.host_problem_sizes, args.problem_count, args.threadblock_count); + } + return workSpaceSize; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) + { + + return dim3(args.threadblock_count, 1, 1); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) + { + + CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) + { + result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, BaseKernel::kThreadCount, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. 
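Note that the workspace computed above scales with both the summed output extents and the split-K factor, which adds up quickly for large groups. A rough standalone sketch of the same sizing arithmetic follows; the problem shapes, split-K factor, and the float accumulator type are made up for illustration.

#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical group of three GEMMs, reduced over four split-K slices.
    struct Extent { int m, n; };
    Extent const problems[] = {{128, 4096}, {64, 4096}, {256, 11008}};
    int const split_k_slices = 4;

    std::size_t total_mn = 0;
    for (Extent const& p : problems)
    {
        total_mn += std::size_t(p.m) * std::size_t(p.n);
    }

    // Mirrors BaseSplitkGrouped::get_workspace_size with ElementAccumulator = float.
    std::size_t bytes = total_mn * sizeof(float) * split_k_slices;
    std::printf("split-K accumulator workspace: %zu bytes (%.1f MiB)\n",
                bytes, bytes / (1024.0 * 1024.0));

    // Problem visitors with kRequiresPrecomputation append their precomputed
    // schedule on top of this buffer.
    return 0;
}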
+ static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, + int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) + { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), indices.end(), + [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient( + cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) + { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); + return 0; + } + + int multiprocessor_count; + result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); + return 0; + } + + bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); + if (override_sm_count) + { + available_sm_count = multiprocessor_count; + } + + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) + { + return 0; + } + + int occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) + { + return occupancy_based_block_count; + } + + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); + + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return total_tiles + // unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) + { + return total_tiles; + } + + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating through + // problem sizes to determine that they have no work to do. This competes for cycles + // with those threadblocks that are assigned tiles to compute. + return std::min(total_tiles, occupancy_based_block_count); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) + { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) + { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) + { + return status; + } + + gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); + } + else + { + gemm_params_ = typename BaseKernel::Params(args, workspace); + } + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) + { + cudaError_t result + = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) + { + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) + { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) + { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) + { + return status; + } + + gemm_params_.update(args, workspace, tile_count); + } + else + { + gemm_params_.update(args, workspace); + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) + { + if (!gemm_params_.problem_visitor.problem_count) + { + return Status::kSuccess; + } + + // + // Launch kernel + // + + // Launch splitk grouped gemm + { + dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + cutlass::Kernel<<>>(gemm_params_); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + // Launch splitkReduction + { + dim3 grid(32, gemm_params_.problem_visitor.problem_count); + dim3 block(256); + splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, + gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices, + gemm_params_.splitk_buffer_offsets); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) + { + return run(stream); + } + + /// Initializes and runs the kernel. 
+ Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) + { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) + { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class SplitkGemmGrouped : public BaseSplitkGrouped +{ +public: + using GemmKernel = GemmKernel_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h new file mode 100644 index 0000000000..100a1161a8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -0,0 +1,162 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/half.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct MixedGemmArchTraits +{ + static_assert(dependent_false, "Unrecognised parameterization"); +}; + +template +struct MixedGemmArchTraits +{ + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::ColumnMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. 
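Before the per-architecture specializations that follow, it may help to see how MixedGemmArchTraits is consumed when composing a weight-only GEMM. The fp16-activation, uint8 weight, SM80 combination below is an assumption chosen for illustration, not a claim about which instantiations this patch builds.

#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h"

// Illustrative instantiation; the alias names are placeholders.
using ArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits<
    cutlass::half_t,       // activation type
    uint8_t,               // quantized weight type
    cutlass::arch::Sm80>;  // target architecture

// The kernel builder reads everything it needs from the trait members:
//   ArchTraits::OperatorClass      -> OpClassTensorOp on tensor-core architectures
//   ArchTraits::InstructionShape   -> the per-arch MMA shape for 16-bit activations
//   ArchTraits::ThreadblockK       -> taken from LayoutDetailsB for the weight type
//   ArchTraits::ElementsPerAccessA -> 128 / sizeof_bits of the activation type
static_assert(ArchTraits::ElementsPerAccessA == 8, "128-bit loads of fp16 activations");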
+template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ada Traits ============================== +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +// FP8 A/B = fp8, C/D = fp32 +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + // be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t + using TypeC = __nv_bfloat16; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h new file mode 
100644 index 0000000000..3fd722994e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -0,0 +1,57 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +}; + +// ======================= Turing Traits ============================== +template <> +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +}; + +// ======================= Ampere Traits ============================== +template <> +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h new file mode 100644 index 0000000000..1dbd0b1765 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h @@ -0,0 +1,207 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" + +#include "cutlass/layout/permute.h" + +#include "splitk_gemm_grouped.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, + /// Operation performed by GEMM + typename Operator = typename device::DefaultGemmConfiguration::Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Permute result D + 
typename PermuteDLayout = layout::NoPermute, + /// + typename Enable = void> +struct DefaultSplitkGemmGrouped; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Permute result D + typename PermuteDLayout> +struct DefaultSplitkGemmGrouped::value>::type> +{ + + // If true, we must construct a 'transposed-and-exchanged' Mma operator. + static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::MapArguments; + + // Define the default GEMM kernel + using DefaultGemmKernel = typename kernel::DefaultGemm::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::SplitkGemmGrouped; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 0000000000..0baec58ea9 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,566 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ +template +inline constexpr bool dependent_false_v = false; +} + +template +struct GemmFpAIntB +{ + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. 
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments + { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, + typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size) + , group_size(group_size) + , ref_A(ref_A) + , ref_B(ref_B) + , ref_scale(ref_scale) + , ref_zero(ref_zero) + , ref_C(ref_C) + , ref_D(ref_D) + , batch_count(serial_split_k_factor) + , output_op(output_op) + , gather_A_indices(gather_A_indices) + , gather_B_indices(gather_B_indices) + , scatter_D_indices(scatter_D_indices) + { + } + }; + + /// Parameters structure + struct Params + { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* 
gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0) + , semaphore(0) + , gemm_k_size(0) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size) + , group_size(args.group_size) + , grid_tiled_shape(grid_tiled_shape) + , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)) + , params_A(args.ref_A.layout()) + , ref_A(args.ref_A) + , params_B(args.ref_B.layout()) + , ref_B(args.ref_B) + , params_scale(args.ref_scale.layout()) + , ref_scale(args.ref_scale) + , ref_zero(args.ref_zero) + , params_C(args.ref_C.layout()) + , ref_C(args.ref_C) + , params_D(args.ref_D.layout()) + , ref_D(args.ref_D) + , output_op(args.output_op) + , semaphore(static_cast(workspace)) + , gemm_k_size(gemm_k_size) + , gather_A_indices(args.gather_A_indices) + , gather_B_indices(args.gather_B_indices) + , scatter_D_indices(args.scatter_D_indices) + { + } + }; + + /// Shared memory storage structure + union SharedStorage + { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const& args) + { + static int const kAlignmentA + = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB + = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) + { + return Status::kErrorMisalignedOperand; + } + + if (!args.ref_scale.good()) + { + return Status::kErrorNotSupported; + } + + if constexpr (hasZero(Mma::QuantOp)) + { + if (!args.ref_zero.good()) + { + return Status::kErrorNotSupported; + } + } + else + { + if (args.ref_zero.good()) + { + return Status::kErrorNotSupported; + } + } + + if constexpr (isFinegrained(Mma::QuantOp)) + { + if (args.group_size != 64 && args.group_size != 128) + { + return Status::kErrorNotSupported; + } + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + + // Initializes the fine grained scale+bias iterator. 
Needed since the fine grained iterator + // has a different constructor signature than a regular cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) + { + + return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) + { + + return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) + { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 + || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() + || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) + { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); + + typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, + params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? 
problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = initialize_scale( + params.params_scale, params.ref_scale.data(), params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) + { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) + { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) + { + + // The final threadblock resets the semaphore for subsequent grids. 
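// Host-side sketch (illustrative, not the device Semaphore class) of the serial split-K
// ordering enforced above: the partition owning k-slice k spins until the lock equals k,
// folds its partial tile into D, and then releases k + 1; the final partition stores 0 so
// the lock is reset for the next grid, mirroring the release logic that follows.
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

inline void serial_split_k_sketch(int k_partitions, std::vector<float>& D, std::vector<std::vector<float>> const& partials)
{
    std::atomic<int> lock{0};
    std::vector<std::thread> workers;
    for (int k = 0; k < k_partitions; ++k)
    {
        workers.emplace_back([&, k] {
            while (lock.load(std::memory_order_acquire) != k) { std::this_thread::yield(); } // semaphore.wait(k)
            for (std::size_t i = 0; i < D.size(); ++i)
            {
                D[i] = (k == 0 ? 0.f : D[i]) + partials[k][i]; // slice 0 sources C, later slices accumulate into D
            }
            lock.store(k + 1 == k_partitions ? 0 : k + 1, std::memory_order_release); // semaphore.release(lock)
        });
    }
    for (auto& t : workers) { t.join(); }
}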
+ lock = 0; + } + else + { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) + { + if constexpr (platform::is_same::value) + { + run_kernel_(params, shared_storage); + } + else + { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ == 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. +#else + static_assert( + false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh new file mode 100644 index 0000000000..1bd0a3f11a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh @@ -0,0 +1,218 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +namespace fused_moe +{ +template +struct Fused_Moe_Kernel_sm80 +{ + static constexpr int kMaxTileM = MaxTileM_; + static constexpr int kTileN = isGateActivation(activation_type_) ? 
TileN_ / 2 : TileN_; + static constexpr int kTileK = TileK_; + static constexpr int kStages = Stages_; + static constexpr Activation_Type activation_type = activation_type_; + + using ElementInput = ElementInput_; + using ElementWeight = ElementWeight_; + using ElementOutput = ElementOutput_; + using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80; + using Routine_Arguments = Routine_Arguments; + using Routine_Params = Routine_Params; + using ProblemVisitor + = cutlass::gemm::kernel::MoeProblemVisitor, false>, + cutlass::gemm::GemmShape, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>; + + struct Arguments + { + Routine_Arguments routine_args; + int problem_count{}; + int threadblock_count{}; + }; + + struct Params + { + Routine_Params routine_params; + int threadblock_count{}; + typename ProblemVisitor::Params problem_visitor_param; + }; + + using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80; + static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64 + + static constexpr int kSmemSize = use_m16 + ? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize + : BaseKernelTraits_m16::kSmemSize) + : BaseKernelTraits::kSmemSize; + static constexpr int kThreadCount = BaseKernelTraits::kThreadCount; + + static constexpr bool can_implement(int const avaliable_smem_size) + { + return BaseKernelTraits::can_implement(avaliable_smem_size); + } + + static Params to_underlying_arguments(Arguments const& args) + { + return { + {args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias, + args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, + args.routine_args.gemm_k, args.routine_args.num_expert, args.routine_args.bias_is_broadcast}, + args.threadblock_count, + {args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k, + args.problem_count, nullptr, 0}}; + } + + CUTE_DEVICE + void run_device(Params const& params) + { +#define ROUTINE_PATH(kTileM_size) \ + { \ + constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 
32 : (kTileM_size)); \ + using RoutineTraits = Fused_Moe_Kernel_routine_sm80; \ + RoutineTraits routine{}; \ + int const block_m_idx = (block_m_idx_temp) *kMaxTileM / kTileM; \ + routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \ + } + typename ProblemVisitor::SharedStorage dummy_storage{}; + ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x); + while (problem_visitor.next_tile()) + { + auto problem_size = problem_visitor.problem_size(); + auto grid_size = problem_visitor.grid_shape(problem_size); + auto problem_index = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + int const gemm_m = problem_size.m(); + const int32_t block_m_idx_temp = cta_idx / grid_size.n(); + const int32_t block_n_idx = cta_idx % grid_size.n(); + + int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp; + if (residue_m > kMaxTileM / 2) + { + using RoutineTraits = Fused_Moe_Kernel_routine_sm80; + RoutineTraits routine{}; + routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m); + } + else + { + + if constexpr (kMaxTileM >= 128) + { + if (residue_m > 32) + { + ROUTINE_PATH(64); + } + else if (residue_m > 16) + { + ROUTINE_PATH(32); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + else if (kMaxTileM == 64) + { + if (residue_m > 16) + { + ROUTINE_PATH(32); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + else if (kMaxTileM == 32) + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + problem_visitor.advance(gridDim.x); + } +#undef ROUTINE_PATH + } +}; + +template +__global__ void run_global(__grid_constant__ typename GemmType::Params const params) +{ + GemmType gemm; + gemm.run_device(params); +} + +/// Computes the maximum number of active blocks per multiprocessor +template +static int fused_gemm_maximum_active_blocks(int smem_capacity = -1) +{ + + CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); + + constexpr int smem_size = GemmType::kSmemSize; + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) + { + result = cudaFuncSetAttribute(run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, run_global, GemmType::kThreadCount, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; +} +} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh new file mode 100644 index 0000000000..4c46a541ef --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh @@ -0,0 +1,799 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace fused_moe +{ + +template +struct Fused_Moe_Kernel_routine_sm80; + +template +struct Fused_Moe_Kernel_routine_sm80> +{ + using KT = Fused_Moe_Kernel_traits_sm80; + using Params = Routine_Params; + + CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) + { + using X = cute::Underscore; + + int const M = gemm_m; + int const N1 = params.gemm_n; + int const K1 = params.gemm_k; + bool const bias_is_broadcast = params.bias_is_broadcast; + + int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]); + typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; + typename KT::ElementWeight const* ptr_fc1_gate_ + = params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we only focus on gated activation.. + typename KT::ElementWeight const* ptr_fc1_ + = params.ptr_fc1 + 2 * problem_index * N1 * K1; // TODO: we only focus on gated activation.. + typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr) + ? nullptr + : (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1); + typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr) + ? nullptr + : (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1 + : params.ptr_bias + (2 * row_jump + 1) * N1); + typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; + + cute::Tensor mInput_mk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), + cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_gate_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_gate_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mBias_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), + cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2, + cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. + + cute::Tensor mBias_gate_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_gate_)), cute::make_shape(M, N1), + cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2, + cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. 
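// Host-side sketch (illustrative) of the pointer arithmetic used by gmem_tensor_init above.
// Assuming total_tokens_including_expert[] holds the inclusive prefix sum of tokens routed
// to each expert, expert e owns rows [row_jump, row_jump + num_rows) of the concatenated
// activation matrix, and its fc1 weights are the e-th N x K slab of ptr_fc1 (the 2e-th and
// (2e+1)-th slabs when the activation is gated). Names here are hypothetical.
#include <cstdint>

struct ExpertSlice
{
    int64_t input_row_offset; // first activation row owned by this expert (row_jump)
    int64_t num_rows;         // tokens routed to this expert
    int64_t fc1_elem_offset;  // element offset of this expert's first N x K weight slab
};

inline ExpertSlice expert_slice(int64_t const* total_tokens_including_expert, int expert, int gemm_n, int gemm_k,
    bool gated /* two interleaved N x K slabs per expert when true */)
{
    int64_t const row_jump = (expert == 0) ? 0 : total_tokens_including_expert[expert - 1];
    int64_t const num_rows = total_tokens_including_expert[expert] - row_jump;
    int64_t const slabs_per_expert = gated ? 2 : 1;
    return {row_jump, num_rows, slabs_per_expert * static_cast<int64_t>(expert) * gemm_n * gemm_k};
}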
+ + cute::Tensor mOutput_mn + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), + cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); + + cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) + cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + + cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn); + } + + // be careful, m_idx will change when use another tile shape.. + CUTE_DEVICE void run_routine( + Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) + { + extern __shared__ char smem_[]; + typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); + int const thread_idx = threadIdx.x; + bool const bias_is_broadcast = params.bias_is_broadcast; + // gmem tensor partition .. + auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn] + = gmem_tensor_init(problem_index, gemm_m, params); + int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); + auto const n_tile_count = cute::size<2>(gfc1_gate_nk); + + // smem tensor .. 
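// Plain-index sketch (illustrative) of the tiling performed by the cute::local_tile calls
// above: tile (bm, bk) of a row-major M x K tensor starts at element (bm * BLK_M, bk * BLK_K),
// and element (i, j) inside that tile maps back to the global offset computed below. Shapes
// that do not divide evenly are handled in the kernel by the residue_m predication shown later.
#include <cstddef>

inline std::size_t tiled_offset(std::size_t K,                        // global row stride (row-major)
                                std::size_t BLK_M, std::size_t BLK_K, // tile extents
                                std::size_t bm, std::size_t bk,       // tile coordinates
                                std::size_t i, std::size_t j)         // coordinates inside the tile
{
    return (bm * BLK_M + i) * K + (bk * BLK_K + j);
}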
+ cute::Tensor sInput = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) + cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sfc1_gate_weight + = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sO = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) + + // (1) first step, get the fc1_res and fc1_gate + + // (1.1) get partition for gmem -> smem + cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) + cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + + typename KT::GmemTiledCopyA gmem_tiled_copy_A; + typename KT::GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) + cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) + cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + + // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) + cute::Tensor tInputpInput + = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), + cute::Stride{}); + // Construct identity layout for sInput + cute::Tensor cInput = make_identity_tensor( + make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<0>(tInputpInput); ++m) + { + tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m + } + + // (1.2) prefetch gmem -> smem + cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. 
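// Scalar sketch (illustrative) of the residue_m predication set up above: only rows whose
// block-local m coordinate falls inside the problem are copied from global memory; rows past
// residue_m keep the zeros written by cute::clear so the MMAs still read well-defined data.
#include <cstring>

inline void masked_tile_load(float const* gmem_tile, float* smem_tile, int BLK_M, int BLK_K, int ld_gmem, int residue_m)
{
    std::memset(smem_tile, 0, sizeof(float) * BLK_M * BLK_K); // mirrors cute::clear(tInputsInput)
    for (int m = 0; m < BLK_M; ++m)
    {
        if (m < residue_m) // predicate: blk_m coordinate < residue_m
        {
            std::memcpy(smem_tile + m * BLK_K, gmem_tile + m * ld_gmem, sizeof(float) * BLK_K);
        }
    }
}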
+ auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0 + int k_tile_count = cute::size<2>(gInput); + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) + { + if (k_tile_count <= 0) + { + cute::clear(tInputpInput); + } + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + // use copy_if + cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, k_pipe)); + cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe)); + cute::cp_async_fence(); + k_tile_count--; + if (k_tile_count > 0) + { + ++k_tile_iter; + } + } + + // (1.3) get partition for rf + typename KT::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K) + cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) + cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) + + cute::Tensor accum + = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) + cute::Tensor accum_gate + = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) + cute::clear(accum); + cute::clear(accum_gate); + // checkout the shape + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K + CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g)); // MMA_K + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma)); + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma)); + + // (1.4)retiling the smem and rf for copy.. 
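// Illustrative sketch of the stage bookkeeping behind the cp.async pipeline above:
// Stages - 1 k-tiles are prefetched before the main loop, and the read/write stage indices
// then chase each other around a ring of Stages shared-memory buffers. Only the index logic
// is modeled here; the real loop overlaps the copies with MMA work via cp_async_fence /
// cp_async_wait.
inline void pipeline_index_model(int Stages, int k_tile_count)
{
    int smem_pipe_read = 0;           // stage the MMAs consume next
    int smem_pipe_write = Stages - 1; // stage the next gmem -> smem copy fills

    for (int k = 0; k < k_tile_count; ++k)
    {
        // ... wait until stage smem_pipe_read is ready, run the MMAs on it,
        //     and issue the next async copy into stage smem_pipe_write ...
        smem_pipe_write = smem_pipe_read; // the stage just consumed is refilled next
        smem_pipe_read = (smem_pipe_read + 1 == Stages) ? 0 : smem_pipe_read + 1;
    }
}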
+ auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) + cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K + + auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) + cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K + + // (1.5) mainloop + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = KT::Stages - 1; + + cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + + constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); + // prefetch register pipeline + if constexpr (K_BLOCK_MAX > 1) + { + cute::cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), + tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{})); + } + // k loop for mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), + tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // cute::copy(gmem_tiled_copy_A, 
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy_if(gmem_tiled_copy_A, tInputpInput, + tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::cp_async_fence(); + if (k_tile_count - 1 > 0) + { + ++k_tile_iter; + } + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), + accum); + cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), + tOrfc1g(cute::_, cute::_, k_block), accum_gate); + }); + } + + // load tail + cute::for_each(cute::make_int_sequence{}, + [&](auto WaitIndex) + { + k_tile_count--; + using WaitIndex_t = decltype(WaitIndex); + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), + tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); + if (k_block == 0) + { + // only update smem_pipe_read + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 
0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), + tOrfc1(cute::_, cute::_, k_block), accum); + cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), + tOrfc1g(cute::_, cute::_, k_block), accum_gate); + }); + }); + // mma tail + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), + tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); + // Thread-level register gemm for k_block + cute::gemm( + tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum); + cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), + tOrfc1g(cute::_, cute::_, k_block), accum_gate); + }); + // if (cute::thread0()) { + // cute::print(accum_gate(0, 0, 0)); + // printf("\n"); + // } + // (2) add bias if it has.. + if (params.ptr_bias != nullptr) + { + cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); + cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); + cute::Tensor tOgBias = thr_mma.partition_C(gBias); + cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate); + for (int i = 0; i < cute::size(accum); i++) + { + accum(i) += tOgBias(i); + accum_gate(i) += tOgBiasg(i); + } + } + + // (3) calculate swiglu + using ActivationFn = typename KT::ActivationFn; + ActivationFn fn{}; + CUTLASS_PRAGMA_UNROLL + for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) + { + accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter); + } + + // (4) push all the result to smem + // (4.1) convert result from ElementAccum to ElementInput + cute::Tensor temp_accum = util_convert_type(accum); + // if (cute::thread0()) { + // cute::print(temp_accum(0, 0, 0)); + // printf("\n"); + // } + // (4.2) retile rf and smem for copy back.. + auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + // cute::clear(sO); + cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); + cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); + + // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) + cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); + __syncthreads(); + + // (4.4) sO -> rO -> gO + + typename KT::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // + // remember, for all the threads in the same col, they have the same idx for bias.. + cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); + // cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row.. 
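// Scalar sketch (illustrative) of the gated epilogue applied above: the two accumulators hold
// the up projection and the gate projection of the same output tile, and the fused activation
// is act(gate) * up (SiLU for SwiGLU, GELU for GeGLU) applied after the optional bias add and
// before the result is converted back to the input element type. The non-gated specialization
// later in this file simply applies act(accum) instead.
#include <cmath>

inline float silu(float x)
{
    return x / (1.f + std::exp(-x));
}

inline void swiglu_epilogue(float* accum, float const* accum_gate, float const* bias, float const* bias_gate, int n)
{
    for (int i = 0; i < n; ++i)
    {
        float const up = accum[i] + (bias ? bias[i] : 0.f);
        float const gate = accum_gate[i] + (bias_gate ? bias_gate[i] : 0.f);
        accum[i] = silu(gate) * up;
    }
}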
+ auto tOsO = gmem_thr_copy_O.partition_S(sO); + auto tOgO = gmem_thr_copy_O.partition_D(gO); + // auto tOgBias = gmem_thr_copy_O.partition_D(gBias); + cute::Tensor cOutput = cute::make_identity_tensor( + cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); + cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(tOgO); ++m) + { + if (cute::get<0>(tOcO(0, m, 0)) < residue_m) + { + cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); + } + } + } +}; + +template +struct Fused_Moe_Kernel_routine_sm80> +{ + + using KT = Fused_Moe_Kernel_traits_sm80; + using Params = Routine_Params; + + CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) + { + using X = cute::Underscore; + + int const M = gemm_m; + int const N1 = params.gemm_n; + int const K1 = params.gemm_k; + bool const bias_is_broadcast = params.bias_is_broadcast; + + int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]); + typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; + typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1; + typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr) + ? nullptr + : (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1); + typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; + + cute::Tensor mInput_mk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), + cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mBias_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), + cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1, + cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. + + cute::Tensor mOutput_mn + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), + cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); + + cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) + cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + + cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn); + } + + // be careful, m_idx will change when use another tile shape.. 
+ CUTE_DEVICE void run_routine( + Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) + { + extern __shared__ char smem_[]; + typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); + int const thread_idx = threadIdx.x; + bool const bias_is_broadcast = params.bias_is_broadcast; + // gmem tensor partition .. + auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params); + int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); + auto const n_tile_count = cute::size<2>(gfc1_nk); + + // smem tensor .. + cute::Tensor sInput = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) + cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sO = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) + + // (1) first step, get the fc1_res and fc1_gate + + // (1.1) get partition for gmem -> smem + cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) + cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + + typename KT::GmemTiledCopyA gmem_tiled_copy_A; + typename KT::GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) + cute::Tensor tInputsInput = gmem_thr_copy_A.partition_S(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) + cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + + // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) + cute::Tensor tInputpInput + = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), + cute::Stride{}); + // Construct identity layout for sInput + cute::Tensor cInput = make_identity_tensor( + make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<0>(tInputpInput); ++m) + { + tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m + } + + // (1.2) prefetch gmem -> smem + cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. 
+ auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0 + int k_tile_count = cute::size<2>(gInput); + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) + { + if (k_tile_count <= 0) + { + cute::clear(tInputpInput); + } + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + // use copy_if + cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, k_pipe)); + cute::cp_async_fence(); + k_tile_count--; + if (k_tile_count > 0) + { + ++k_tile_iter; + } + } + + // (1.3) get partition for rf + typename KT::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K) + cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) + + cute::Tensor accum + = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) + cute::clear(accum); + // checkout the shape + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma)); + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma)); + + // (1.4)retiling the smem and rf for copy.. 
+ auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) + cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K + + auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K + + // (1.5) mainloop + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = KT::Stages - 1; + + cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + + constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); + // prefetch register pipeline + if constexpr (K_BLOCK_MAX > 1) + { + cute::cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), + tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); + } + // k loop for mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy_if(gmem_tiled_copy_A, tInputpInput, + tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::cp_async_fence(); + if (k_tile_count - 1 > 0) + { + ++k_tile_iter; + } + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 
0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), + accum); + }); + } + // load tail + cute::for_each(cute::make_int_sequence{}, + [&](auto WaitIndex) + { + k_tile_count--; + using WaitIndex_t = decltype(WaitIndex); + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + if (k_block == 0) + { + // only update smem_pipe_read + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), + tOrfc1(cute::_, cute::_, k_block), accum); + }); + }); + // mma tail + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + // Thread-level register gemm for k_block + cute::gemm( + tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum); + }); + // if (cute::thread0()) { + // cute::print(accum_gate(0, 0, 0)); + // printf("\n"); + // } + // (2) add bias if it has.. + if (params.ptr_bias != nullptr) + { + cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); + cute::Tensor tOgBias = thr_mma.partition_C(gBias); + for (int i = 0; i < cute::size(accum); i++) + { + accum(i) += tOgBias(i); + } + } + // (3) calculate swiglu + using ActivationFn = typename KT::ActivationFn; + ActivationFn fn{}; + CUTLASS_PRAGMA_UNROLL + for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) + { + accum(temp_iter) = fn(accum(temp_iter)); + } + + // (4) push all the result to smem + // (4.1) convert result from ElementAccum to ElementInput + cute::Tensor temp_accum = util_convert_type(accum); + // if (cute::thread0()) { + // cute::print(temp_accum(0, 0, 0)); + // printf("\n"); + // } + // (4.2) retile rf and smem for copy back.. + auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + // cute::clear(sO); + cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); + cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); + + // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) 
+ cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); + __syncthreads(); + + // (4.4) sO -> rO -> gO + + typename KT::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // + cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); + auto tOsO = gmem_thr_copy_O.partition_S(sO); + auto tOgO = gmem_thr_copy_O.partition_D(gO); + cute::Tensor cOutput = cute::make_identity_tensor( + cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); + cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(tOgO); ++m) + { + if (cute::get<0>(tOcO(0, m, 0)) < residue_m) + { + cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); + } + } + } +}; + +} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh new file mode 100644 index 0000000000..b4c90085db --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh @@ -0,0 +1,215 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace fused_moe +{ +template +struct Routine_Arguments +{ + ElementInput* ptr_input{}; + ElementWeight* ptr_fc1{}; + ElementInput* ptr_bias{}; + ElementOutput* ptr_output{}; + int64_t const* total_tokens_including_expert{}; + int gemm_n{}; + int gemm_k{}; + int num_expert{}; + bool bias_is_broadcast{}; +}; + +template +struct Routine_Params +{ + ElementInput* ptr_input{}; + ElementWeight* ptr_fc1{}; + ElementInput* ptr_bias{}; + ElementOutput* ptr_output{}; + int64_t const* total_tokens_including_expert{}; + int gemm_n{}; + int gemm_k{}; + int num_expert{}; + bool bias_is_broadcast{}; +}; + +enum class Activation_Type +{ + Gelu = 0, + Relu, + Silu, + Swiglu, + Geglu, + Identity, + InvalidType +}; + +constexpr bool isGateActivation(Activation_Type const& activation_type) +{ + return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu; +} + +template +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::InvalidType; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::Identity; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::Relu; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool is_gate) +{ + return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool is_gate) +{ + return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu; +} + +/* fusing all three kernels has many limitations. This is the simpler version. Just fuse first two kernels..*/ +template +struct Fused_Moe_Kernel_traits_sm80 +{ + using ElementInput = ElementInput_; + using ElementWeight = ElementWeight_; + using ElementAccum = float; + using ElementOutput = ElementOutput_; + + using index_t = uint32_t; + static_assert(TileM_ % 16 == 0); + static_assert(TileN_ % 32 == 0); + static_assert(TileK_ % 32 == 0); + static constexpr int Stages = Stages_; + static constexpr int kTileM = TileM_; + static constexpr int kTileN = TileN_; + static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64); + + // tile shape + using TileShape = cute::Shape, cute::Int, cute::Int>; + static constexpr int kWarpsCount = 4; + static constexpr int kThreadCount = kWarpsCount * 32; + + // MMA atom arch and layout + using MMA_Atom_Arch = std::conditional_t, + cute::MMA_Atom, cute::MMA_Atom>; + // using ValLayoutMNK = cute::Layout>; + using ThreadLayoutMNK + = std::conditional_t, cute::_1>>, + cute::Layout, cute::_1>>>; + using ValLayoutMNK = std::conditional_t, + cute::Tile>; + using TiledMma = cute::TiledMMA; // 32x32x16 or 16x64x16 MMA for LDSM if kWarp = 4 + static constexpr int kAlignment = 8; + static constexpr int kBlcokKSmem = (kTileM == 16) ? 
64 : 32; + // A memory copy operand + using DefaultOperandA + = DefaultGemm_TensorOpSm80_OperandA; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B memory copy operand + using DefaultOperandB + = DefaultGemm_TensorOpSm80_OperandB; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Output memory copy operand + using SmemLayoutAtomO = SmemLayoutAtomA; + using SmemCopyAtomO = cute::Copy_Atom; + static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput); + static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad; + using GmemLayoutAtomO + = cute::Layout, cute::Int>, + cute::Stride, cute::_1>>; + using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom{}, + GmemLayoutAtomO{}, cute::Layout>{})); + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2); + static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M + static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K + static_assert(cute::rank(SmemLayoutAtomB{}) == 2); + static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N + static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K + + using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{}, + cute::make_shape( + cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_M, BLK_K, Stages + using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{}, + cute::make_shape( + cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_N, BLK_K, Stages + using SmemLayoutO = decltype(cute::tile_to_shape( + SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N + + // we need at least 2 stages.. + static_assert(Stages >= 2); + + struct SharedStorageNormal : cute::aligned_struct<128> + { + cute::array_aligned> smem_input; + cute::array_aligned> smem_fc1_weight; + cute::array_aligned> smem_o; + }; + + struct SharedStorageGate : cute::aligned_struct<128> + { + cute::array_aligned> smem_input; + cute::array_aligned> smem_fc1_gate_weight; + cute::array_aligned> smem_fc1_weight; + cute::array_aligned> smem_o; + }; + + using SharedStorage = std::conditional_t; + + using ActivationFn = std::conditional_t, + std::conditional_t, + std::conditional_t, cutlass::epilogue::thread::Identity>>>; + + static constexpr int kSmemSize = static_cast(sizeof(SharedStorage)); + + static constexpr bool can_implement(int const avaliable_smem_size) + { + return avaliable_smem_size > kSmemSize; + } + + // #endif +}; +} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h new file mode 100644 index 0000000000..80a4d85608 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h @@ -0,0 +1,73 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Scheduler for grouped GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" +#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct GemmMoeProblemVisitor + : public MoeProblemVisitor, ThreadblockShape, + GroupScheduleMode_, PrefetchTileCount, ThreadCount> +{ + + static bool const kTransposed = Transposed; + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base + = MoeProblemVisitor; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, shared_storage_, block_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp new file mode 100644 index 0000000000..3a084ee04f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +//////////////////////////////////////////////////////////////////////////////// + +/* + * Stateless universal device GEMM kernel type that treats GEMM as + * a composition of a collective mainloop and a collective epilogue. + * + * Supports both the 2.x and 3.x APIs based on whether the first type is + * a cute::tuple<> or not. + * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h + * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp + * + * In the following declaration, the name preceding the 'Or' refers to + * 3.x API type argument order, and the name succeeding the 'Or' refers to + * 2.x API type argument order. Template arguments without two names + * belong to the 3.x API only. + **/ +template +class GemmUniversalGated; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel + +//////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp" +#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp" +//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h new file mode 100644 index 0000000000..0650ca8ded --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h @@ -0,0 +1,585 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. + + original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/trace.h" + +#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" + +namespace tk = tensorrt_llm::common; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithEpilogueVisitor +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename 
Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment + = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() + : mode(GemmUniversalMode::kGemm) + , batch_count(1) + { + } + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, + TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, + int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_) + , problem_size(problem_size_) + , batch_count(batch_count_) + , ref_A(ref_A_) + , ref_B(ref_B_) + , quant_option(quant_option_) + , ref_alpha_col(ref_alpha_col_) + , ref_alpha_row(ref_alpha_row_) + , ref_C(ref_C_) + , ref_D(ref_D_) + , batch_stride_A(batch_stride_A_) + , batch_stride_B(batch_stride_B_) + , batch_stride_D(0) + , epilogue_visitor(epilogue_visitor_) + { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0) + , params_A(0) + , params_B(0) + , params_alpha_col(0) + , params_C(0) + , params_D(0) + , batch_count(0) + , gemm_k_size(0) + , 
mode(cutlass::gemm::GemmUniversalMode::kGemm) + , ptr_A(nullptr) + , ptr_B(nullptr) + , ptr_alpha_col(nullptr) + , ptr_alpha_row(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , batch_stride_A(0) + , batch_stride_B(0) + { + } + + Params( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) + : problem_size(args.problem_size) + , swizzle_log_tile(0) + , params_A(args.ref_A.layout()) + , params_B(args.ref_B.layout()) + , params_alpha_col(args.ref_alpha_col.layout()) + , params_alpha_row(args.ref_alpha_col.layout()) + , params_C(args.ref_C.layout()) + , params_D(args.ref_D.layout()) + , mode(args.mode) + , batch_count(args.batch_count) + , gemm_k_size(args.problem_size.k()) + , ptr_A(args.ref_A.data()) + , ptr_B(args.ref_B.data()) + , quant_option(args.quant_option) + , ptr_alpha_col(args.ref_alpha_col.data()) + , ptr_alpha_row(args.ref_alpha_row.data()) + , ptr_C(args.ref_C.data()) + , ptr_D(args.ref_D.data()) + , batch_stride_A(args.batch_stride_A) + , batch_stride_B(args.batch_stride_B) + , epilogue_visitor(args.epilogue_visitor) + { + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + int const kAlignK + = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) + { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage + { + + typename Mma::SharedStorage main_loop; + + struct + { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) + { + isAMisaligned = problem_size.k() % kAlignmentA; + } + else if (platform::is_same::value) + { + isAMisaligned = problem_size.m() % kAlignmentA; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) + { + isBMisaligned = problem_size.n() % kAlignmentB; + } + else if (platform::is_same::value) + { + isBMisaligned = problem_size.k() % kAlignmentB; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) + { + isCMisaligned = problem_size.n() % kAlignmentC; + } + else if (platform::is_same::value) + { + isCMisaligned = problem_size.m() % kAlignmentC; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + 
isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + +#define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) + { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() + || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) + { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) + { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) + { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) + { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, + params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, + params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, + params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) + { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) + { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) + { + if constexpr (platform::is_same::value) + { + run_kernel_(params, shared_storage); + } + else + { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. + run_kernel(params, shared_storage); +#else + static_assert( + false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h new file mode 100644 index 0000000000..6dc6ffc1a9 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -0,0 +1,143 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct LayoutDetailsB +{ +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. 
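The quantized-B specializations further down in this header derive ThreadblockK and the column-interleave factor purely from element bit-widths and a 128-byte cache line. As a quick sanity check, that arithmetic can be reproduced in a small standalone C++ sketch; the template and its parameters (InterleaveMath, BitsA, BitsB) are illustrative names and are not part of the patch:

#include <cstdio>

// Standalone sanity check: mirrors the cache-line arithmetic of the quantized-B
// LayoutDetailsB specializations, assuming a 128-byte line and bit-sized elements.
// BitsA = bits per activation element, BitsB = bits per weight element.
template <int BitsA, int BitsB>
struct InterleaveMath
{
    static constexpr int kThreadblockK = 128 * 8 / BitsA;         // K elements of A covered by one 128B line
    static constexpr int kElementsPerCacheLine = 128 * 8 / BitsB; // B elements held in one 128B line
    static constexpr int kColumnsInterleaved = kElementsPerCacheLine / kThreadblockK;
};

int main()
{
    // fp16 activations with int8 weights: ThreadblockK = 64, 2 columns interleaved.
    std::printf("int8 : K=%d, interleave=%d\n", InterleaveMath<16, 8>::kThreadblockK,
        InterleaveMath<16, 8>::kColumnsInterleaved);
    // fp16 activations with int4 weights: ThreadblockK = 64, 4 columns interleaved,
    // consistent with the "for int4, ThreadBlockK MUST be 64" note in this file.
    std::printf("int4 : K=%d, interleave=%d\n", InterleaveMath<16, 4>::kThreadblockK,
        InterleaveMath<16, 4>::kColumnsInterleaved);
    return 0;
}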
+template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB +{ + static constexpr int ThreadblockK = 64; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; + // for fast accumulation + // using Operator = cutlass::arch::OpMultiplyAddFastAccum; +}; + +// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, +// which signals that we want to dequantize after loading from smem. +template + struct LayoutDetailsB < TypeA, + uint8_t, Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template + struct LayoutDetailsB < TypeA, + uint4b_t, Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB= 90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh new file mode 100644 index 0000000000..aac2cb3579 --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh @@ -0,0 +1,185 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +template +struct DefaultGemm_TensorOpSm80_OperandA; + +template +struct DefaultGemm_TensorOpSm80_OperandB; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +// +// F16: 128-by-128-by-32 (small k-block) +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + 
cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template +CUTE_DEVICE auto util_convert_type(cute::Tensor const& tensor) +{ + using From_type = typename Engine::value_type; + constexpr int numel = decltype(cute::size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast const*>(tensor.data())); + return cute::make_tensor(cute::make_rmem_ptr(&frag), tensor.layout()); +} + +template +CUTE_DEVICE void util_copy( + TiledCopy const& tiled_copy, cute::Tensor const& S, cute::Tensor& D) +{ + CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{}); + CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{}); + CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D)); + CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D)); + CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D)); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(S); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < cute::size<2>(S); ++k) + { + cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k)); + } + } +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h new file mode 100644 index 0000000000..b708f7c28b --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -0,0 +1,553 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. +// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. +template +using void_t = void; + +template +struct use_dq_gemm : platform::false_type +{ +}; + +template +struct use_dq_gemm> : platform::true_type +{ +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeFCGemm +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor + = GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + int problem_count; + int threadblock_count; + int group_size; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + bool C_is_broadcast; + + int64_t const* total_tokens_including_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0) + , threadblock_count(0) + , ptr_A(nullptr) + , ptr_B(nullptr) + , weight_scales(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , total_tokens_including_expert(nullptr) + , gemm_n(0) + , gemm_k(0) + , host_problem_sizes(nullptr) + , C_is_broadcast{true} + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, + ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C, + bool C_is_broadcast, ElementC* ptr_D, int64_t const* total_tokens_including_expert, int64_t gemm_n, + int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr) + : problem_count(problem_count) + , threadblock_count(threadblock_count) + , group_size(group_size) + , output_op(output_op) + , ptr_A(const_cast(ptr_A)) + , ptr_B(const_cast(ptr_B)) + , weight_scales(const_cast(weight_scales)) + , ptr_C(const_cast(ptr_C)) + , C_is_broadcast{C_is_broadcast} + , ptr_D(ptr_D) + , total_tokens_including_expert(total_tokens_including_expert) + , gemm_n(gemm_n) + , gemm_k(gemm_k) + , host_problem_sizes(nullptr) + { + if (platform::is_same::value || platform::is_same::value) + { + assert(weight_scales); + } + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int group_size; + bool C_is_broadcast; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr) + , ptr_B(nullptr) + , weight_scales(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , C_is_broadcast(true) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor( + args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count) + , threadblock_count(args.threadblock_count) + , group_size(args.group_size) + , 
output_op(args.output_op) + , ptr_A(args.ptr_A) + , ptr_B(args.ptr_B) + , weight_scales(args.weight_scales) + , ptr_C(args.ptr_C) + , ptr_D(args.ptr_D) + , C_is_broadcast(args.C_is_broadcast) + { + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + { + + problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n, + args.gemm_k, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + C_is_broadcast = args.C_is_broadcast; + } + }; + + /// Shared memory storage structure + union SharedStorage + { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + if (platform::is_same::value || platform::is_same::value) + { + if (args.weight_scales == nullptr) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t"); + return Status::kInvalid; + } + } + else if (args.weight_scales != nullptr) + { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); + return Status::kInvalid; + } + else if (args.group_size != args.gemm_k) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)"); + return Status::kInvalid; + } + // Handle the case the input is too short + else if (args.gemm_n < Mma::IteratorB::AccessType::kElements) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) + { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert(platform::is_same::value && kInterleave == 1 + || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // + // Problem visitor. 
+ // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + int loop = 0; + while (problem_visitor.next_tile()) + { + loop++; + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump + = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B + = platform::is_same::value ? gemm_n : gemm_k * kInterleave; + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + auto CreateMMA = [&]() + { + if constexpr (use_dq_gemm::value) + return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + else + return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + }; + Mma mma = CreateMMA(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. 
+ __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); + + if constexpr (use_dq_gemm::value) + { + const MatrixCoord scale_extent = {1, problem_size.n()}; + typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), + weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale); + + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + else + { + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + } + + // + // Epilogue + // + + ElementC* ptr_C = reinterpret_cast(params.ptr_C) + + (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n; + ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + // lora need to set as layout_C(gemm_n) + LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + if constexpr (platform::is_same>::value) + { + EpilogueOutputOp output_op(params.output_op, problem_idx); + epilogue(output_op, iterator_D, accumulators, iterator_C); + } + else + { + EpilogueOutputOp output_op(params.output_op); + epilogue(output_op, iterator_D, accumulators, iterator_C); + } + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) + { + if constexpr (platform::is_same::value) + { + run_kernel_(params, shared_storage); + } + else + { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900) + constexpr bool isFp8 = platform::is_same::value + || platform::is_same::value; + if constexpr (isFp8) + { + run_kernel(params, shared_storage); + } + else + { // reuse sm80 kernel for other types, align with dispatchToArch + run_kernel(params, shared_storage); + } +#elif (__CUDA_ARCH__ >= 900) + run_kernel(params, shared_storage); +#else + static_assert( + false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h new file mode 100644 index 0000000000..796dc2fe78 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -0,0 +1,344 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Base scheduler for grouped problems, using MoE +*/ + +#pragma once + +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct BaseMoeProblemVisitor +{ + using ThreadblockShape = ThreadblockShape_; + + struct ProblemInfo + { + static int32_t const kNoPrefetchEntry = -1; + int32_t problem_idx; + int32_t problem_start; + + CUTLASS_DEVICE + ProblemInfo() + : problem_idx(kNoPrefetchEntry) + , problem_start(kNoPrefetchEntry) + { + } + + CUTLASS_DEVICE + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) + : problem_idx(problem_idx_) + , problem_start(problem_start_) + { + } + }; + + struct Params + { + int64_t const* last_row_for_problem; + int64_t gemm_n; + int64_t gemm_k; + int32_t problem_count; + void const* workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : last_row_for_problem(nullptr) + , gemm_n(0) + , gemm_k(0) + , problem_count(0) + , workspace(nullptr) + , tile_count(0) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, + void const* workspace = nullptr, int32_t tile_count = 0) + : last_row_for_problem(last_row_for_problem) + , gemm_n(gemm_n) + , gemm_k(gemm_k) + , problem_count(problem_count) + , workspace(workspace) + , tile_count(tile_count) + { + } + }; + + Params const& params; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) + : params(params_) + , tile_idx(block_idx) + , problem_tile_start(0) 
+ , problem_idx(0) + { + } + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem) + { + + return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const + { + return tile_idx; + } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const + { + return problem_idx; + } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const + { + return tile_idx - problem_tile_start; + } + + CUTLASS_DEVICE + void advance(int32_t grid_size) + { + tile_idx += grid_size; + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) + { + ProblemSizeHelper::possibly_transpose_problem(problem); + } + + /// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const + { + return problem_size(problem_idx); + } + + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size(int idx) const + { + const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t gemm_m = current_problem_row - prev_problem_row; + GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(cutlass::gemm::GemmCoord const& grid) + { + return ProblemSizeHelper::tile_count(grid); + } + + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count) + { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) + { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); + } + + return total_tiles; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeProblemVisitor; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ProblemVisitor that performs all scheduling on device +// +template +struct MoeProblemVisitor : public BaseMoeProblemVisitor +{ + using Base = BaseMoeProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage + { + }; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage& shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, block_idx) + , problem_ending_tile(0) + , shared_storage(shared_storage_) + { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() + { + // Check whether the tile to compute is within the range of the current problem. 
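+        // Each lane of the warp caches the (inclusive-prefix-summed) ending tile of one
+        // problem from the group of kThreadsPerWarp problems most recently fetched, so the
+        // shuffle below reads the current problem's ending tile from the lane that owns it
+        // (problem_idx % kThreadsPerWarp).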
+ int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); + if (this->tile_idx < problem_tile_end) + { + return true; + } + + // Check whether the tile to compute is within the current group of problems fetched by the warp. + // The last tile for this group is the final tile of the problem held by the final thread in the warp. + int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + // Keep the starting problem for this group in `problem_idx`. This is done to reduce + // register pressure. The starting problem for this group is simply the first problem + // in the group most recently fetched by the warp. + int32_t& group_problem_start = this->problem_idx; + group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; + + // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce + // register pressure. + int32_t& group_tile_start = this->problem_tile_start; + + // Each thread in the warp processes a separate problem to advance until + // reaching a problem whose starting tile is less less than tile_idx. + while (group_tile_end <= this->tile_idx) + { + group_problem_start += kThreadsPerWarp; + if (group_problem_start > this->params.problem_count) + { + return false; + } + + // Since `group_tile_start` is a reference to `this->problem_tile_start`, this + // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` + // is also set here is used later in `next_tile`. + group_tile_start = group_tile_end; + + int lane_idx = threadIdx.x % kThreadsPerWarp; + int32_t lane_problem = group_problem_start + lane_idx; + + // Compute the number of tiles in the problem assigned to each thread. + problem_ending_tile = 0; + if (lane_problem < this->params.problem_count) + { + cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + problem_ending_tile = this->tile_count(grid); + } + + // Compute a warp-wide inclusive prefix sum to compute the ending tile index of + // each thread's problem. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kThreadsPerWarp; i <<= 1) + { + int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); + if (lane_idx >= i) + { + problem_ending_tile += val; + } + } + + // The total tile count for this group is now in the final position of the prefix sum + int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + problem_ending_tile += group_tile_start; + group_tile_end += tiles_in_group; + } + + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); + + this->problem_idx = group_problem_start + problem_idx_in_group; + + // The starting tile for this problem is the ending tile of the previous problem. In cases + // where `problem_idx_in_group` is the first problem in the group, we do not need to reset + // `problem_tile_start`, because it is set to the previous group's ending tile in the while + // loop above. 
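+        // Illustration: with tile counts {3, 5, 2} for the first problems of a group that
+        // starts at tile 0, lanes 0..2 hold ending tiles {3, 8, 10}. For tile_idx == 6 only
+        // lane 0 satisfies problem_ending_tile <= tile_idx, so the ballot/popc above gives
+        // problem_idx_in_group == 1, and the shuffle below fetches its starting tile (3)
+        // from lane 0.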
+ if (problem_idx_in_group > 0) + { + this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); + } + + return true; + } + + static size_t get_workspace_size( + cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count) + { + return 0; + } + + static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count, void* host_workspace_ptr) + { + } +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp new file mode 100644 index 0000000000..e3d31a2c5b --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp @@ -0,0 +1,646 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/tensor.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/workspace.h" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalGated + && CollectiveMainloop_::isGated>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using Activation = typename CollectiveMainloop::Activation; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock + = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; + + // Kernel level shared memory storage + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + 
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> + { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static Params to_underlying_arguments(Arguments const& args, void* workspace) + { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + auto problem_shape = args.problem_shape; + // if constexpr (detail::IF_SWAP_AB::value) { + // // swap M/N + // get<0>(problem_shape) = get<1>(args.problem_shape); + // get<1>(problem_shape) = get<0>(args.problem_shape); + // } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) + { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used + // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means + // subtile will not be used, therefore separate reduction will not be enabled. 
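+        // The same constant is recomputed in get_workspace_size() and initialize_workspace()
+        // below, so the scheduler workspace is sized consistently across all three entry points.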
+ constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, + ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + + return {args.mode, problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, + scheduler, workspace}; + } + + static bool can_implement(Arguments const& args) + { + bool implementable = (args.mode == GemmUniversalMode::kGemm) + or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t get_workspace_size(Arguments const& args) + { + size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) + { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + status = TileScheduler::template initialize_workspace(args.scheduler, + workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, + NumEpilogueSubTiles); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const& params) + { + // Given device SM count, set grid size s.t. 
we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) + { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 get_block_shape() + { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) + { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); + static_assert(size<0>(TileShape{}) >= 128, + "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); + + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ + enum class WarpGroupRole + { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole + { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % size(TiledMma{}); + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) + { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = 
warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = size(TiledMma{}); + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = []() + { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) + { + cute::cluster_arrive_relaxed(); + return []() { cute::cluster_wait(); }; + } + else + { + __syncthreads(); + return []() {}; // do nothing + } + }(); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + TileScheduler scheduler{params.scheduler}; + auto work_tile_info = scheduler.get_current_work(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, 
shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 3, + "Output of load_init must have at least three elements (A, B, Aux)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) + { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) + { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) + { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) + { + work_tile_info = fetch_next_work(work_tile_info, scheduler); + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the + // work. + auto work_k_tile_count + = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter + = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, + load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster, + shared_storage.tensors.mainloop); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) + { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) + { + while (work_tile_info.is_valid()) + { + if (!TileScheduler::requires_separate_reduction(params.scheduler)) + { + load_order_barrier.wait(); + } + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state 
= collective_epilogue.load(epi_load_pipeline, + epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, + shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx()); + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + cutlass::arch::warpgroup_reg_alloc(); + + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it + bool do_store_tail = false; + float scale_d0 = params.mainloop.scale_d0; + float scale_d1 = params.mainloop.scale_d1; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto work_k_tile_count + = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + + // Allocate the accumulators for the (M,N) blk_shape + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. + auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) + { + collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, + accumulators1, work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop, + params.mainloop); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx); + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx); + + Activation elt_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators0); i++) + { + accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); + } + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) + { + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] + = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, + epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, + tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx()); + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = 
epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + if (do_store_tail) + { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state); + } + } // Consumer Warp Groups End +#endif + } + +private: + // Kernel helper function to get next work unit + CUTLASS_DEVICE + typename TileScheduler::WorkTileInfo fetch_next_work( + typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const + { + // Check whether we should continue on with the current work unit. If this is the case, + // the work unit will have been updated in continue_current_work to reflect the new + // tile to be computed. + if (scheduler.continue_current_work(work_tile_info)) + { + return work_tile_info; + } + + // Get next work tile + scheduler.advance_to_next_work(); + return scheduler.get_current_work(); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp new file mode 100644 index 0000000000..39886f2431 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp @@ -0,0 +1,621 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" + +#include "cute/util/debug.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalGated + && CollectiveMainloop_::isGated>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using Activation = typename CollectiveMainloop::Activation; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(!cute::is_same_v, + "Ping-pong kernel does not currently support stream-K scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 2; + static constexpr uint32_t MaxThreadsPerBlock + = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; + + // Order Sequence barrier with two 
stages: one for Mainloop and one for Epilogue + static constexpr uint32_t StagesPerMathWarpGroup = 2; + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier; + + // Kernel level shared memory storage + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> + { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static Params to_underlying_arguments(Arguments const& args, void* workspace) + { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + (void) workspace; + auto problem_shape = args.problem_shape; + // if constexpr (detail::IF_SWAP_AB::value) { + // // swap M/N + // get<0>(problem_shape) = get<1>(args.problem_shape); + // get<1>(problem_shape) = get<0>(args.problem_shape); + // } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) + { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + return {args.mode, problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, 
args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)}; + } + + static bool can_implement(Arguments const& args) + { + bool implementable = (args.mode == GemmUniversalMode::kGemm) + or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t get_workspace_size(Arguments const& args) + { + size_t workspace_size = 0; + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) + { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = TileScheduler::template initialize_workspace(args.scheduler, + workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const& params) + { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) + { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 get_block_shape() + { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) + { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. 
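+// __CUDA_ARCH_FEAT_SM90_ALL is defined only when this translation unit targets sm_90a, so on
+// any other architecture the kernel body below is compiled out and only the diagnostic printf
+// remains.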
+#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + enum class WarpGroupRole + { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole + { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) + { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, 
epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); + params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier( + shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&]() + { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) + { + cute::cluster_arrive_relaxed(); + return []() { cute::cluster_wait(); }; + } + else + { + __syncthreads(); + return []() {}; // do nothing + } + }(); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 3, + "Output of load_init must have at least three elements (A, B, Aux)"); + + // Extract out partitioned A and B. 
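+        // gAux_xkl is the third tensor returned by load_init (per the static_assert above) and
+        // carries the extra operand of the gated GEMM; it is tiled and loaded by the mainloop
+        // alongside A and B.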
+ Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + TileScheduler scheduler{params.scheduler}; + + if (warp_group_role == WarpGroupRole::Consumer1) + { + // Advance 2nd Math WG to the next work tile for the startup + scheduler.advance_to_next_work(); + // Advance 2nd Math WG pipeline states to the end of 1st Math WG + mainloop_pipe_consumer_state.advance(k_tile_count); + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + } + auto work_tile_info = scheduler.get_current_work(); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) + { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) + { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + + collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, + load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster, + shared_storage.tensors.mainloop); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) + { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) + { + load_order_barrier.wait(); + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state + = collective_epilogue.load(epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL, + blk_shape, blk_coord, tiled_mma, lane_idx, shared_storage.tensors.epilogue); + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == 
WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + cutlass::arch::warpgroup_reg_alloc(); + + float scale_d0 = params.mainloop.scale_d0; + float scale_d1 = params.mainloop.scale_d1; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Allocate the accumulators for the (M,N) blk_shape + Tensor accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + Tensor accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + // Order two Math WG's MMA one after the other, helps hide Epilogue + math_wg_order_barrier.wait(); + + collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1, + k_tile_count, warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop); + + // Cue for next Math WG's MMA to start + math_wg_order_barrier.arrive(); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count); + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); + + Activation elt_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators0); i++) + { + accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); + } + + // Order two Math WG's Epilogue one after the other + math_wg_order_barrier.wait(); + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] + = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, + epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, + tiled_mma, warp_group_thread_idx, shared_storage.tensors.epilogue); + + // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels + // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives + // to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer. 
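For reference, the elementwise combine the consumer warp groups apply to the two accumulator fragments reduces to d[i] = (acc0[i] * scale_d0) * act(scale_d1 * acc1[i]). A host-side sketch of that loop, using SiLU as an assumed stand-in for the Activation template parameter and illustrative function names:

#include <cmath>
#include <cstddef>
#include <vector>

inline float silu(float x) { return x / (1.0f + std::exp(-x)); }

// d[i] = (acc0[i] * scale_d0) * silu(scale_d1 * acc1[i]), mirroring the fragment loop above.
std::vector<float> combine_dual_accumulators(const std::vector<float>& acc0,
                                             const std::vector<float>& acc1,
                                             float scale_d0, float scale_d1) {
  std::vector<float> out(acc0.size());
  for (std::size_t i = 0; i < acc0.size(); ++i) {
    out[i] = (acc0[i] * scale_d0) * silu(scale_d1 * acc1[i]);
  }
  return out;
}
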
+ auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] + = collective_epilogue.store_tail(epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next); + + // Update starting load/store pipeline states for the next tile + // state has already been incremented by 1 tile in collective calls, advance once again for ping pong + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + + // Cue for next Math WG's Epilogue to start + math_wg_order_barrier.arrive(); + + // Get next work tile + scheduler.advance_to_next_work(NumMmaWarpGroups); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + } // Consumer Warp Groups End +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h new file mode 100644 index 0000000000..5e3531f093 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h @@ -0,0 +1,494 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SplitkGemmGrouped +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + using ElementFinalOutput = typename MapArguments::ElementA; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
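The kTransposed mapping above lets a row-major-output kernel service column-major problems by computing C^T = B^T * A^T; at run time this amounts to exchanging the A/B pointers and leading dimensions per problem, as the kernel body later does. A minimal sketch of that swap, with OperandView and maybe_swap_operands as illustrative names:

// Exchange operand pointers and leading dimensions when computing the transposed problem.
struct OperandView {
  const void* ptr_A;
  const void* ptr_B;
  long long lda;
  long long ldb;
};

inline OperandView maybe_swap_operands(OperandView in, bool transposed) {
  if (transposed) {
    // C = A * B with column-major C  ==  C^T = B^T * A^T with row-major C^T
    return OperandView{in.ptr_B, in.ptr_A, in.ldb, in.lda};
  }
  return in;
}
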
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor + = GemmGroupedProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + GemmCoord* problem_sizes; + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // splitK + int split_k_slices; + int64_t* splitk_buffer_offsets; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0) + , threadblock_count(0) + , ptr_A(nullptr) + , ptr_B(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , lda(nullptr) + , ldb(nullptr) + , ldc(nullptr) + , ldd(nullptr) + , host_problem_sizes(nullptr) + , split_k_slices(1) + , splitk_buffer_offsets(nullptr) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, + typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C, + ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, + typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, + typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, + int64_t* splitk_buffer_offsets) + : problem_sizes(problem_sizes) + , problem_count(problem_count) + , threadblock_count(threadblock_count) + , output_op(output_op) + , ptr_A(ptr_A) + , ptr_B(ptr_B) + , ptr_C(ptr_C) + , ptr_D(ptr_D) + , lda(lda) + , ldb(ldb) + , ldc(ldc) + , ldd(ldd) + , host_problem_sizes(host_problem_sizes) + , split_k_slices(split_k_slices) + , splitk_buffer_offsets(splitk_buffer_offsets) + { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + ElementC* ptr_C_split; + ElementC* ptr_D_split; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // + // Methods + // + + // splitk + GemmCoord grid_tiled_shape; + int swizzle_log_tile; + int gemm_k_size; + GemmCoord* host_problem_sizes; + int split_k_slices; + int64_t* splitk_buffer_offsets; + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr) + 
, ptr_B(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , ptr_C_split(nullptr) + , ptr_D_split(nullptr) + , lda(nullptr) + , ldb(nullptr) + , ldc(nullptr) + , ldd(nullptr) + , swizzle_log_tile(0) + , gemm_k_size(0) + , host_problem_sizes(nullptr) + , split_k_slices(1) + , splitk_buffer_offsets(nullptr) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count) + , host_problem_sizes(args.host_problem_sizes) + , threadblock_count(args.threadblock_count) + , output_op(args.output_op) + , ptr_A(args.ptr_A) + , ptr_B(args.ptr_B) + , ptr_C(args.ptr_C) + , ptr_D(args.ptr_D) + , ptr_C_split((ElementC*) workspace) + , ptr_D_split((ElementC*) workspace) + , lda(args.lda) + , ldb(args.ldb) + , ldc(args.ldc) + , ldd(args.ldd) + , split_k_slices(args.split_k_slices) + , splitk_buffer_offsets(args.splitk_buffer_offsets) + { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0], + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices); + swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); + + // only support same k + int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK; + int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + { + + problem_visitor = + typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + ptr_C_split = workspace; + ptr_D_split = workspace; + + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldd = args.ldd; + } + }; + + /// Shared memory storage structure + struct SharedStorage + { + union + { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + SplitkGemmGrouped() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { + + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + // + // Problem visitor. 
+ // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) + { + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + // Load element pointers. Exchange pointers and strides if working on the transpose + ElementA* ptr_A + = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB* ptr_B + = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k; + if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) + { + problem_size_k = problem_size.k(); + } + else + { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + // Wait for all threads to finish their epilogue phases from the previous tile. 
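The K-slice bookkeeping above (gemm_k_size, problem_size_k, gemm_k_iterations) gives each split-K slice a contiguous run of whole kK tiles and lets the last slice absorb the remainder. A host-side sketch of that arithmetic, with kTileK standing in for Mma::Shape::kK and slices for grid_tiled_shape.k():

#include <utility>

// [k_begin, k_end) owned by slice slice_idx; the last slice takes the tail of K.
std::pair<int, int> split_k_range(int problem_k, int kTileK, int slices, int slice_idx) {
  int full_k_iterations = problem_k / kTileK;              // whole kK tiles in the problem
  int gemm_k_size = (full_k_iterations / slices) * kTileK; // K elements per slice
  int k_begin = slice_idx * gemm_k_size;
  int k_end = (slice_idx + 1 == slices) ? problem_k
                                        : (slice_idx + 1) * gemm_k_size;
  return {k_begin, k_end};
}

// Mainloop iterations for one slice, matching gemm_k_iterations in the kernel body.
int split_k_iterations(int k_begin, int k_end, int kTileK) {
  return (k_end - k_begin + kTileK - 1) / kTileK;
}
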
+ __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = params.ptr_C_split; + ElementC* ptr_D = params.ptr_D_split; + + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // assume identity swizzle + MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); + + iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); + iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 0000000000..ed5e3e4daf --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,125 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much effort +// to write shared memory iterators that are probably needed for volta to function +// properly. 
As a result, we allow converters both after the LDG (for volta) and after +// the LDS for Turing+. +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters +{ +}; + +// Dequantize after LDG, so set transforms accordingly +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters +{ + using TransformAfterLDG + = FastInterleavedAndBiasedNumericArrayConverter; + + using TransformAfterLDS = NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters +{ + using TransformAfterLDG = NumericArrayConverter; + + using TransformAfterLDS + = FastInterleavedAndBiasedNumericArrayConverter; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 0000000000..17c6346553 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,302 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/arch/mma.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsMultistage; + +// Fine grained iterators +template +struct DefaultScaleIteratorsMultistage> +{ + using IteratorScale + = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsMultistage> +{ + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + +private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + +public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value + || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename 
OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && layout::IsColumnMajorTileInterleave::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value + || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize 
after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + +private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape + = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 0000000000..345cd2eec9 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,284 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/arch/mma.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsPipelined; + +// Fine grained iterators +template +struct DefaultScaleIteratorsPipelined> +{ +private: + using SmemScaleType = half_t; + +public: + using IteratorScale + = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale + = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, + SmemScaleType, Layout, 0, Alignment>; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsPipelined> +{ + static_assert((MmaShape::kN % Alignment) == 0, ""); + +private: + // ThreadMap for scale iterator + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + using SmemScaleType = half_t; + +public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale + = cutlass::transform::threadblock::PredicatedTileIterator, SmemScaleType, + Layout, 0, IteratorScaleThreadMap, Alignment>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + 
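Conceptually, every DqMma variant in these headers performs the same arithmetic before the tensor-core MMA: widen the quantized B value to half precision and apply a per-column (or fine-grained) scale. The sketch below shows only that arithmetic for the uint8 case, assuming the +128 storage bias commonly used by the interleaved converters; the packing and interleaving handled by FastInterleavedAndBiasedNumericArrayConverter are omitted, and all names here are illustrative.

#include <cstddef>
#include <cstdint>
#include <vector>

// Per-column weight dequantization: b[r][c] = scale[c] * (int(b_quant[r][c]) - 128).
// The +128 bias is an assumption of this sketch, not something defined in this header.
std::vector<float> dequantize_per_column(const std::vector<uint8_t>& b_quant,
                                         const std::vector<float>& column_scale,
                                         int rows, int cols) {
  std::vector<float> b(static_cast<std::size_t>(rows) * cols);
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      std::size_t idx = static_cast<std::size_t>(r) * cols + c;
      int centered = static_cast<int>(b_quant[idx]) - 128;  // remove assumed storage bias
      b[idx] = column_scale[c] * static_cast<float>(centered);
    }
  }
  return b;
}
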
static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + +private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = 
LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape + = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap + = transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 0000000000..ad6c7496e1 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,351 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of 
elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#ifdef ENABLE_FP8 +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile 
size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#endif + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 0000000000..77af81005a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,353 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + +private: + // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaElementA = typename platform::conditional::type; + using MmaElementB = typename platform::conditional::type; + +public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, + AccessTypeA, GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + AccessTypeB, GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & 
int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory 
clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 0000000000..1fb7f7eb28 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,257 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// +// SFINAE trick so I can keep the same loop code for Volta and dispatch to the +// correct warp level mma. On volta, all data is stored to shared memory as FP16. +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C, + int const warp_tileB_k_offset) +{ + warp_mma(D, A, B, C); +} + +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B, + typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset) +{ + warp_mma(D, A, B, C, warp_tileB_k_offset); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// The dequantizing op to be performed. + WeightOnlyQuantOp DequantOp, + /// Used for partial specialization, + typename Enable = bool> +class DqMmaBase +{ +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. + static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
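    // A quick sizing example for the scale/zero staging constants defined above
    // (tile numbers are hypothetical, chosen only to make the arithmetic concrete):
    // with Shape::kN == 128 and Stages == 3,
    //   - a fine-grained op with zeros keeps ScalebiasStages == 3, so the shared
    //     storage further down stages roughly 3 * 128 scale elements plus another
    //     3 * 128 zero/bias elements;
    //   - a per-column, scale-only op collapses to ScalebiasStages == 1 and
    //     BiasElementsPerStage == 0, i.e. a single row of 128 scales and no zeros.
    // A compile-time check equivalent to the definitions above would be:
    //   static_assert(BiasElementsPerStage == 0 || BiasElementsPerStage == Shape::kN, "");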
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad + = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage + { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA + = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB + = MatrixShape; + + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. + using ShapeZero = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() + { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() + { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() + { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() + { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + +protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx) + , warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 0000000000..3c4036dd8c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
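// The primary template declared below is never defined in this header: the two
// partial specializations that supply the actual main loops live in
// dq_mma_multistage_finegrained.h and dq_mma_multistage_percol.h (both included
// at the end of this file) and appear to be selected on isFinegrained(QuantOp_).
// A rough usage sketch; the enumerator names are assumptions in the spirit of
// weight_only_quant_op.h, not identifiers taken from this patch:
//   // streams one row of scales/zeros into shared memory per pipeline stage:
//   using GroupwiseMma = DqMmaMultistage<..., WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>;
//   // loads a single row of per-column scales once in the prologue:
//   using PerColumnMma = DqMmaMultistage<..., WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>;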
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = void> +class DqMmaMultistage; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h" diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h new file mode 100644 index 0000000000..f81961dee3 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -0,0 +1,708 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
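// A small sketch of what this fine-grained specialization computes per warp
// fragment (b_q, b_dq, g, n are illustrative names, not identifiers from this
// file): after TransformBAfterLDS widens the packed B data, the warp
// dequantizer rescales each column n with the scale of its k-group g and, when
// hasZero(QuantOp) holds, folds in a per-(group, column) zero/bias term:
//   b_dq[n] = float(b_q[n]) * scale[g][n]               // scale-only groups
//   b_dq[n] = float(b_q[n]) * scale[g][n] + zero[g][n]  // scale-and-zero groups
// (whether the zero term lands before or after the scale depends on how the
// checkpoint pre-processes zeros). The per-column variant in
// dq_mma_multistage_percol.h applies a single scale[n] row instead.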
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applied immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + /// Internal structure exposed for introspection. 
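    // Detail, defined just below, spreads the cp.async traffic of one pipeline
    // stage across the warp-level MMA iterations. A worked example with invented
    // numbers: if IteratorA needs AsyncCopyIterationsPerStageA == 8 copies per
    // stage and Base::kWarpGemmIterations == 3, then
    //   kAccessesPerGroupA == (8 + 3 - 1) / 3 == 3
    // copies are issued per warp_mma_k step; the final group is partially masked
    // by the (group_start_A + j < AsyncCopyIterationsPerStageA) bound in
    // copy_tiles_and_advance, so exactly 8 real copies overlap the math.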
+ struct Detail + { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA + = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB + = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in 
units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) + { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr + = reinterpret_cast(this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr + = reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) + { + cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); + } + + if (iterator_scale.group_size_ == 64) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if (iterator_scale.group_size_ == 128) + { + if constexpr (Shape::kK == 128) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if constexpr (Shape::kK == 64) + { + if (iterator_scale.row_groupsize64_ & 0x1) + { + iterator_scale.add_tile_offset({1, 0}); + } + } + else + { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) + { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) + { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + 
cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) + { + + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) + { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. 
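        // kClearLastStage, handled just below, writes an explicit zero fragment
        // over the final shared-memory stage, whereas the kZfill branches used
        // above lean on cp_async_zfill to zero each predicated-off transfer as it
        // is issued; either way, tiles outside the GEMM footprint never feed
        // garbage into the accumulators. The zeroing idiom used below boils down
        // to (smem_ptr is an illustrative name):
        //   typename IteratorA::AccessType zero;  zero.clear();
        //   *reinterpret_cast<typename IteratorA::AccessType*>(smem_ptr) = zero;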
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) + { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) + { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
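                // The two-entry warp_frag_A / warp_frag_B arrays implement a
                // simple double buffer: iteration warp_mma_k computes with index
                // (warp_mma_k % 2) while prefetching index ((warp_mma_k + 1) % 2)
                // below. Because several K iterations of A can share one load of
                // the quantized B tile, the B reload is gated on
                // warp_tileB_k_compute_offset == kNumKIterationsPerWarpBLoad - 1.
                // A hypothetical cadence, with numbers invented for illustration:
                // Base::kWarpGemmIterations == 4 and kNumKIterationsPerWarpBLoad == 2
                // give kWarpGemmIterationsForB == 2, so B is fetched from shared
                // memory only on warp_mma_k == 1 and warp_mma_k == 3, while the
                // dequantize of the currently held fragment still runs every
                // iteration.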
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); + + using FragmentOperandB = cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + + using Converter + = cutlass::NumericArrayConverter; + + FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); + run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) + { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // This is the first group of a given stage, so we issue the loads for the B scales immediately. + if (group_start_iteration_B == 0) + { + copy_scales_and_advance(iterator_scale); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) + { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } + else + { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } + else + { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h new file mode 100644 index 0000000000..83efdc5cb0 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h @@ -0,0 +1,647 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
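// Unlike the fine-grained main loop, this per-column specialization does not
// stream scales stage by stage. Its prologue performs a single
//   iterator_scale.load(tb_frag_scales);
//   this->smem_iterator_scale_.store(tb_frag_scales);
// pair covering one {1, Shape::kN} row of scales, the warp dequantizer loads
// that row once before the main loop, and no zero/bias fragment is kept. The
// group_size constructor argument is accepted only for interface parity with
// the fine-grained variant; the comment on the constructor notes it is unused
// here.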
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
+ struct Detail + { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA + = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB + = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< Group size for quantization. 
Not used by this main loop since it assumes per-column + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) + { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) + { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, 
iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) + { + + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group as + // the first load of A. + FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) + { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. 
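+ // (For reference: under kClearLastStage, the copies below reuse the current
+ // shared-memory write iterators, which at this point address the last,
+ // never-filled pipeline stage, and store zeroed AccessType fragments there so
+ // that warp-level loads from that stage read zeros rather than stale data.)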
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) + { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) + { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
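+ // Worked example with hypothetical sizes: if Base::kWarpGemmIterations == 4
+ // and Base::kNumKIterationsPerWarpBLoad == 2, the A fragments ping-pong
+ // between warp_frag_A[0] and warp_frag_A[1] on every warp_mma_k, while each
+ // B fragment is reused for two consecutive warp_mma_k values and the next
+ // one is fetched only when warp_mma_k is 1 or 3.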
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + + using FragmentOperandB = cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + + using Converter + = cutlass::NumericArrayConverter; + + FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); + run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) + { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) + { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } + else + { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + smem_read_stage_idx = 0; + } + else + { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h new file mode 100644 index 0000000000..bd3e38971b --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -0,0 +1,106 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Used for partial specialization + typename Enable = void> +class DqMmaPipelined; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h" +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h" diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h new file mode 100644 index 0000000000..50bdd0d85b --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h @@ -0,0 +1,486 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
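+/// This header holds the fine-grained (group-wise scale and zero-point)
+/// specialization of DqMmaPipelined; the per-column variant lives in
+/// dq_mma_pipelined_percol.h, and both are pulled in by dq_mma_pipelined.h.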
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_> +class DqMmaPipelined> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + + static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == 
Shape::kN, ""); + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + using WarpFragmentZero = typename Dequantizer::FragmentZero; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< The group size for quantization + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) + { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale) + { + using TransformScale = NumericArrayConverter; + + FragmentScale tb_frag_scales; + FragmentScale tb_frag_zeros; + tb_frag_scales.clear(); + tb_frag_zeros.clear(); + + TransformScale transformScale; + + using FragmentElement = typename FragmentScale::Element; + + auto gmem_scale_ptr = iterator_scale.get_scale(); + auto gmem_zero_ptr = iterator_scale.get_zero(); + + arch::global_load(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) + { + arch::global_load( + tb_frag_zeros, 
gmem_zero_ptr, iterator_scale.valid()); + } + + typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales); + typename TransformScale::result_type tb_frag_zeros_fp16; + if (gmem_zero_ptr != nullptr) + tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); + + auto frag_scale_ptr_fp16 = reinterpret_cast(&tb_frag_scales_fp16); + auto frag_zero_ptr_fp16 = reinterpret_cast(&tb_frag_zeros_fp16); + auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); + auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); + + if (iterator_scale.valid()) + { + auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); + arch::shared_store(smem_offset, frag_scale_ptr_fp16); + + if (gmem_zero_ptr != nullptr) + { + smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); + arch::shared_store(smem_offset, frag_zero_ptr_fp16); + } + } + + if (iterator_scale.group_size_ == 64) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if (iterator_scale.group_size_ == 128) + { + if constexpr (Shape::kK == 128) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if constexpr (Shape::kK == 64) + { + if (iterator_scale.row_groupsize64_ & 0x1) + { + iterator_scale.add_tile_offset({1, 0}); + } + } + else + { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) + { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA + = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
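+ // (Illustrative: with bf16 operands on an sm_75-class target, these
+ // converters rewrite the fragments as fp16 before the shared-memory store;
+ // when source and destination element types already match they are
+ // effectively pass-throughs.)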
+ TransformA transformA; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + copy_scales_and_advance(iterator_scale); + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + WarpFragmentScale warp_frag_scales; + WarpFragmentZero warp_frag_zero; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + iterator_scale.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) + { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) + { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + } + else + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
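+ // (Schedule sketch for this two-stage pipeline: the A and B fragments for the
+ // next k-tile are fetched from global memory at warp_mma_k == 0 and written to
+ // shared memory on the last warp_mma_k of the same outer iteration, the
+ // scale/zero copy issued at warp_mma_k == 0 stores straight to shared memory,
+ // and the warp-level B fragment below is refreshed once every
+ // Base::kNumKIterationsPerWarpBLoad compute steps.)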
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) + { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + copy_scales_and_advance(iterator_scale); + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + iterator_scale.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero); + run_warp_mma( + warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + + // Load the scales needed for the next tile iteration + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + // Update internal pointer to the set of scales in shared memory + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h new file mode 100644 index 0000000000..316ea9f80a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h @@ -0,0 +1,399 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_> +class DqMmaPipelined> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + 
using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation + ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this + ///< argument is not added, it does not affect compilation for sm>=80. 
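+ ///< (The fine-grained specialization in dq_mma_pipelined_finegrained.h does
+ ///< consume group_size; this per-column variant ignores it.)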
+ int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) + { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) + { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA + = NumericArrayConverter; + + using TransformScale = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
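+ // (Per-column note: one row of Shape::kN scales covers the entire K extent,
+ // so the scales are loaded and converted once in this prologue and the
+ // warp-level scale fragment is never refreshed inside the mainloop.)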
+ TransformA transformA; + TransformScale transformScale; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + FragmentScale tb_frag_scales; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + WarpFragmentScale warp_frag_scales; + + tb_frag_A.clear(); + tb_frag_B.clear(); + tb_frag_scales.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + iterator_scale.load(tb_frag_scales); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + warp_dequantizer_.load(warp_frag_scales); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) + { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) + { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) + { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma( + warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 0000000000..350b247de2 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements, + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp +{ + +private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. + static constexpr int LoadInstructionK = 128 / sizeof_bits::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = GemmShape; + +public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 0000000000..7c5088894b --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,306 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" +#include "cutlass/arch/mma_sm89.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 +{ +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value + && platform::is_same::value) + || (platform::is_same::value + && platform::is_same::value + && ArchTag::kMinComputeCapability >= 80) + || (platform::is_same::value + && platform::is_same::value + && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + + static_assert(platform::is_same::value + || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80) + || (platform::is_same::value && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert( + SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); + static_assert( + SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + /// Iterates over the A operand in memory + using IteratorA + = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, + LayoutB, MatrixShape, Policy::OpDelta::kRow, + kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using 
TransformedFragmentB = Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + +public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, + int const warp_tileB_k_offset) const + { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " + "B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) + { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) + { + + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) + { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } + else + { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) + { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) + { + + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) + { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } + else + { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 0000000000..1d5cd5d898 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,463 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/weight_only_quant_op.h" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_, + /// + typename Enable = void> +class MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 80 + && platform::is_same::value>::type> +{ + +public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
+ static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) + { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) + { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) + { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. 
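The dequantizer above multiplies each loaded B fragment by its per-column scale (the __hmul2 path), and the zero-point variant that follows folds the offset in with a fused multiply-add (the __hfma2 path). As a plain host-side sketch of that per-element arithmetic (illustrative only, not part of the TensorRT-LLM header; the function names are made up):

#include <cstddef>

// Scale-only dequantization: w = w_q * scale (mirrors the __hmul2 path).
inline void reference_dequantize(float* frag, std::size_t n, float scale)
{
    for (std::size_t i = 0; i < n; ++i)
        frag[i] *= scale;
}

// Scale-and-zero dequantization: w = w_q * scale + zero (mirrors the __hfma2 path).
inline void reference_dequantize(float* frag, std::size_t n, float scale, float zero)
{
    for (std::size_t i = 0; i < n; ++i)
        frag[i] = frag[i] * scale + zero;
}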
+ arch::device_breakpoint(); +#endif + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) + { + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. 
+ CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) + { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + +private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Turing & Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 75 + && platform::is_same::value>::type> +{ + +public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) + { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) + { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) + { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) + { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB + = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) + { + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < 
MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) + { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB + = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + + if constexpr (hasZero(QuantOp)) + { + plus plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] + = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) + { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + +private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h new file mode 100644 index 0000000000..4acef2d180 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace cutlass_extensions +{ +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. 
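The note above states the naming convention used by the tile-config enums that follow: every shape is written M x N x K, first for the threadblock (CTA) tile and then for the warp tile. As a small illustration (a hypothetical helper, not part of this header), CtaShape128x128x64_WarpShape64x32x64 decodes to:

struct IllustrativeTileShape
{
    int m, n, k;  // MxNxK
};

// CtaShape128x128x64_WarpShape64x32x64:
constexpr IllustrativeTileShape kExampleCtaTile{128, 128, 64};  // threadblock tile
constexpr IllustrativeTileShape kExampleWarpTile{64, 32, 64};   // per-warp tile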
+enum class CutlassTileConfig +{ + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=16 + CtaShape16x128x64_WarpShape16x32x64, + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, + + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64, + + // TensorCore config CTA_N = 64, CTA_K = 128 + CtaShape128x64x128_WarpShape64x32x128, + + // TensorCore config CTA_N = 256, CTA_K = 64 + CtaShape16x256x64_WarpShape16x64x64, + + // TensorCore config CTA_N = 256, CTA_K = 128 + CtaShape16x256x128_WarpShape16x64x128 + +}; + +enum class SplitKStyle +{ + NO_SPLIT_K, + SPLIT_K_SERIAL, + STREAM_K, // Sm80+ + // SPLIT_K_PARALLEL // Not supported yet +}; + +enum class CutlassTileConfigSM90 +{ + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // CTA configs for M=64 + CtaShape64x16x128B, + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // CTA configs for M=128 + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, + + // CTA configs for M=128 + CtaShape256x128x128B, +}; + +enum class MainloopScheduleType +{ + AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this + // defaults to the "legacy" main loop schedule. +}; + +enum class EpilogueScheduleType +{ + AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For + // architectures older than hopper, the epilogue is always performed by the same thread block as the main loop. 
+}; + +enum class ClusterShape +{ + ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1, + ClusterShape_1x8x1, + ClusterShape_8x1x1 +}; + +struct CutlassGemmConfig +{ + enum CandidateConfigTypeParam : int + { + NONE = 0, + WEIGHT_ONLY = 1u << 0, + SIMT_ONLY = 1u << 1, + INT8_ONLY = 1u << 2, + HOPPER = 1u << 3, + GROUPED_GEMM = 1u << 4, + FP8_ONLY = 1u << 5, + }; + + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; + + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + bool is_sm90 = false; + + CutlassGemmConfig() {} + + CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) + : tile_config(tile_config) + , split_k_style(split_k_style) + , split_k_factor(split_k_factor) + , stages(stages) + , is_sm90(false) + { + } + + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm90(tile_config_sm90) + , mainloop_schedule(mainloop_schedule) + , epilogue_schedule(epilogue_schedule) + , cluster_shape(cluster_shape) + , is_sm90(true) + { + } + + std::string toString() const + { + std::stringstream tactic; + tactic << "Cutlass GEMM Tactic"; + if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic) + { + assert(is_sm90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=TMA" + << "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape + << "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule; + } + else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) + { + assert(!is_sm90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=compatible" + << "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages + << "\n\tsplit k: " << (int) split_k_factor; + } + else + { + tactic << "\n\tundefined"; + } + tactic << "\n"; + return tactic.str(); + } +}; + +inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) +{ + // clang-format off + if (config.is_sm90) + { + out << "tile_config_sm90_enum: " << int(config.tile_config_sm90) + << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) + << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) + << ", cluster_shape_enum: " << int(config.cluster_shape); + } + else + { + out << "tile_config_enum: " << int(config.tile_config) + << ", split_k_style_enum: " << int(config.split_k_style) + << ", split_k_factor: " << config.split_k_factor + << ", stages: " << config.stages; + } + // clang-format on + return out; +} + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 0000000000..44ba79680e --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/numeric_types.h" + +namespace cutlass +{ + +// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low +// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. +// This converter will uninterleave the data and subtract the bias while converting to the result type. 
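The comment above defines the storage format these fast converters undo: inside each 32-bit register the even logical elements sit in the low half, the odd elements in the high half, and every value is stored unsigned with a bias of 2^(b-1) already added. A host-side reference of the uninterleave-and-unbias step for one register of 8-bit data (illustrative only; the function name is made up):

#include <cstdint>

// Decode one 32-bit register of four biased 8-bit values: bytes {0,1} carry the even
// logical elements {e0,e2}, bytes {2,3} the odd elements {e1,e3}; subtracting the
// 2^(8-1) = 128 bias recovers the signed values.
inline void reference_uninterleave_and_unbias_u8x4(uint32_t reg, float out[4])
{
    const uint32_t b0 = reg & 0xFFu;
    const uint32_t b1 = (reg >> 8) & 0xFFu;
    const uint32_t b2 = (reg >> 16) & 0xFFu;
    const uint32_t b3 = (reg >> 24) & 0xFFu;
    out[0] = static_cast<float>(b0) - 128.0f;  // e0
    out[1] = static_cast<float>(b2) - 128.0f;  // e1
    out[2] = static_cast<float>(b1) - 128.0f;  // e2
    out[3] = static_cast<float>(b3) - 128.0f;  // e3
}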
+template +struct FastInterleavedAndBiasedNumericArrayConverter +{ +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) + { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. 
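The bf16 path above rests on an fp32 integer trick: 0x4B000000 is the bit pattern of 8388608.0f (2^23), where the unit in the last place is exactly 1.0, so dropping a byte b into the low 8 bits yields the float 8388608 + b, and subtracting 8388736 (= 2^23 + 128) leaves the unbiased value b - 128 before truncation to bfloat16. A minimal host-side check of that identity (illustrative only, standalone function):

#include <cstdint>
#include <cstring>

// Returns b - 128 for any byte b, using the same 2^23 magic-number construction as above.
inline float decode_biased_byte_via_fp32(uint8_t b)
{
    uint32_t bits = 0x4B000000u | b;
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f - 8388736.0f;  // 8388736 = 2^23 + 128
}
// decode_biased_byte_via_fp32(0) == -128.0f, decode_biased_byte_via_fp32(255) == 127.0f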
+ CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) + { + bf16_result_ptr[ii] + = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. 
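For the 4-bit converter that begins above, the even/odd packing described earlier works out to a concrete nibble layout: nibble i of the 32-bit register holds logical element 2*i for i < 4 and element 2*(i-4)+1 for i >= 4, each carrying a +8 (2^3) bias. A host-side reference decode of one such register (illustrative only; the PTX path below produces the same values directly as fp16 without materializing integers):

#include <cstdint>

// Decode one register of eight biased 4-bit values into signed integers, in logical order.
inline void reference_decode_u4x8(uint32_t reg, int out[8])
{
    for (int i = 0; i < 8; ++i)
    {
        const int nibble  = (reg >> (4 * i)) & 0xF;                 // storage order e0,e2,e4,e6,e1,e3,e5,e7
        const int logical = (i < 4) ? (2 * i) : (2 * (i - 4) + 1);  // undo the even/odd interleave
        out[logical] = nibble - 8;                                  // remove the 2^(4-1) bias
    }
}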
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. 
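The fp16 magic constants used above (1032, 1/16 and -72) can be checked with ordinary arithmetic: ORing a bottom nibble v into 0x6400 produces the fp16 value 1024 + v (the ulp at 1024 is exactly 1), so subtracting 1032 gives v - 8; a top nibble sits in bits 4-7 and produces 1024 + 16*v, and the fused multiply-add by 1/16 with addend -72 gives (1024 + 16*v)/16 - 72 = v - 8 as well. Both paths therefore recover the signed int4 value without a second shift. A host-side restatement (illustrative only):

// Plain-float check of the fp16 4-bit decode; both return v - 8.
inline float decode_bottom_nibble_fp16_path(int v) { return (1024.0f + static_cast<float>(v)) - 1032.0f; }
inline float decode_top_nibble_fp16_path(int v)    { return (1024.0f + 16.0f * static_cast<float>(v)) / 16.0f - 72.0f; }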
+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) + { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) + { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h new file mode 100644 index 0000000000..5a0cd29570 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h @@ -0,0 +1,66 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass +{ +namespace layout +{ + +template +struct ColumnMajorTileInterleave +{ + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave +{ + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> +{ + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 0000000000..6095925e37 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM + quantization. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace transform +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator +{ +public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + using Fragment = cutlass::Array; + + // For compatibility with existing iterator interface + struct Params + { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_(layout.stride(0)) + { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + +private: + /// 
Internal pointer type permits fast address arithmetic + using BytePointer = char*; + +private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params) + , pointer_scale_(reinterpret_cast(const_cast(pointer_scale))) + , pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) + { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset + = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) + { + pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); + } + + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + int const thread_row = thread_id / THREADS_PER_ROW; + int const thread_col = thread_id % THREADS_PER_ROW; + + const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) + { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + } + + // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on + // a given iteration. The same threads will be responsible for issues reads since the number of scales + // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ + // outside of the constructor. 
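The constructor above reduces scale/zero addressing to a handful of byte offsets: a threadblock row term that rescales the row coordinate (expressed in groups of 64, per the row_groupsize64_ comment) by the actual group size, a threadblock column term, and per-thread row/column terms derived from thread_id, THREADS_PER_ROW and kAlignment. A standalone restatement of that arithmetic (hypothetical helper and parameter names, illustrative only):

#include <cstdint>

// Byte offset applied to the scale (and zero) base pointer for one thread.
inline int64_t scale_zero_byte_offset(int64_t tb_row_groupsize64, int64_t tb_col,
                                      int64_t stride_elems, int group_size, int elem_bits,
                                      int thread_id, int threads_per_row, int alignment)
{
    const int64_t tb_row_bytes  = tb_row_groupsize64 / (group_size / 64) * stride_elems * elem_bits / 8;
    const int64_t tb_col_bytes  = tb_col * elem_bits / 8;
    const int64_t thr_row_bytes = (thread_id / threads_per_row) * stride_elems * elem_bits / 8;
    const int64_t thr_col_bytes = (thread_id % threads_per_row) * alignment * elem_bits / 8;
    return tb_row_bytes + tb_col_bytes + thr_row_bytes + thr_col_bytes;
}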
+ int const global_row = threadblock_offset.row() + thread_row; + int const global_col = threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator( + params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) + { + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) + { + pointer_zero_ += row_byte_offset + col_byte_offset; + } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) + { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const + { + return is_valid_; + } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const + { + return reinterpret_cast(pointer_scale_); + } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const + { + return reinterpret_cast(pointer_zero_); + } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp new file mode 100644 index 0000000000..b430380b01 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp @@ -0,0 +1,181 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" + +using namespace cute; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {}) + : indices_(indices) + { + } + + template + CUTE_HOST_DEVICE constexpr auto operator()(I i) const + { + return indices_[i]; + } + + CUTE_HOST_DEVICE friend void print(IndexedGather const& s) + { + cute::print("Indexed{"); + print(s.indices_); + print("}"); + } + + Iter indices_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride) + : func_(func) + , stride_(stride) + { + } + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) + { + return s.func_(i) * s.stride_; + } + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i) + { + return s.func_(i) * s.stride_; + } + + CUTE_HOST_DEVICE friend void print(CustomStride const& s) + { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride + auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout( + repeat_like(stride, _1{}), replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func) +{ + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); +} + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) + { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s, d); }); + } + else if 
constexpr (is_scaled_basis::value) + { + if constexpr (Stride::mode() == I) + { + return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); + } + else + { + return make_layout(shape, stride); + } + } + else + { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr auto upcast( + ComposedLayout, Offset, Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace cute diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h new file mode 100644 index 0000000000..64774428e9 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
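The gather extension above boils down to one addressing rule: IndexedGather maps a logical index i to indices[i], CustomStride multiplies that remapped index by the real stride it replaced, and make_gather_tensor installs the pair on the gathered mode of the layout. A plain C++ sketch of the same rule with the CuTe layout algebra stripped out (the view struct and its members are illustrative only, not part of the patch):

#include <cstdint>
#include <vector>

// Row-major M x N source matrix viewed through a row-index array:
// element (i, j) of the view is data[indices[i] * row_stride + j],
// which is what the IndexedGather/CustomStride pair computes for the
// mode whose stride was replaced.
struct GatheredRowsView
{
    float const* data;
    int64_t row_stride;          // elements between consecutive source rows
    std::vector<int> indices;    // logical row -> source row

    float operator()(int i, int j) const
    {
        return data[static_cast<int64_t>(indices[i]) * row_stride + j];
    }
};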
+*/ + +#pragma once + +namespace cutlass +{ + +enum class WeightOnlyQuantOp +{ + UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS +}; + +constexpr bool isFinegrained(WeightOnlyQuantOp op) +{ + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +} + +constexpr bool hasZero(WeightOnlyQuantOp op) +{ + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +} + +} // namespace cutlass diff --git a/sgl-kernel/THIRDPARTYNOTICES.txt b/sgl-kernel/THIRDPARTYNOTICES.txt index c930aa5dd3..fcae14df3a 100644 --- a/sgl-kernel/THIRDPARTYNOTICES.txt +++ b/sgl-kernel/THIRDPARTYNOTICES.txt @@ -223,3 +223,208 @@ BSD 3-Clause "New" License 3rdparty/cutlass include/flashinfer/attention/hopper/block_sparse_gather.cuh + +Notice for NVIDIA/TensorRT-LLM +------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
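The isFinegrained and hasZero helpers in weight_only_quant_op.h above are constexpr, so callers can branch on the quantization mode at compile time. A minimal usage sketch, assuming the header lands on the include path configured in setup.py (dispatch_weight_only is an illustrative name, not part of the patch):

#include <cstdio>
#include "cutlass_extensions/weight_only_quant_op.h"

template <cutlass::WeightOnlyQuantOp Op>
void dispatch_weight_only()
{
    if constexpr (cutlass::isFinegrained(Op))
    {
        // Group-wise scales; zeros are present only for FINEGRAINED_SCALE_AND_ZEROS.
        std::printf("fine-grained, has zeros: %d\n", int(cutlass::hasZero(Op)));
    }
    else
    {
        // PER_COLUMN_SCALE_ONLY (or UNDEFINED) takes the coarse path.
        std::printf("per-column scale only\n");
    }
}

// e.g. dispatch_weight_only<cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>();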
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 90c3cbc1d3..5029914031 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -39,6 +39,8 @@ def _get_version(): cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" turbomind = root / "3rdparty" / "turbomind" +tensorrt_llm_parent = root / "3rdparty" +tensorrt_llm = root / "3rdparty" / "tensorrt_llm" include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", @@ -51,6 +53,8 @@ def _get_version(): "cublasLt", turbomind.resolve(), turbomind.resolve() / "src", + tensorrt_llm_parent.resolve(), + tensorrt_llm.resolve() / "cutlass_extensions" / "include", ] nvcc_flags = [ From e81d7f11dede2b9b3f82de00a433eccc3d47c25e Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 30 Jan 2025 23:49:14 +0800 Subject: [PATCH 05/52] add tensorrt_llm moe_gemm as 3rdparty (#3217) --- .../tensorrt_llm/common/cudaBf16Wrapper.h | 21 + .../tensorrt_llm/common/cudaDriverWrapper.cpp | 187 ---- .../tensorrt_llm/common/cudaDriverWrapper.h | 138 --- .../launchers/fused_moe_gemm_launcher_sm80.h | 25 + .../fused_moe_gemm_launcher_sm80.inl | 96 ++ .../launchers/moe_gemm_launcher_sm90.h | 37 + .../launchers/moe_gemm_launcher_sm90.inl | 348 ++++++++ .../moe_gemm/moe_gemm_hopper_input.cu | 131 +++ .../moe_gemm/moe_gemm_kernels.h | 230 +++++ .../moe_gemm/moe_gemm_kernels_bf16_bf16.cu | 24 + .../moe_gemm/moe_gemm_kernels_bf16_uint4.cu | 24 + .../moe_gemm/moe_gemm_kernels_bf16_uint8.cu | 24 + .../moe_gemm/moe_gemm_kernels_fp16_fp16.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp16_uint4.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp16_uint8.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp32_fp32.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp8_fp8.cu | 28 + .../moe_gemm/moe_gemm_kernels_template.h | 823 ++++++++++++++++++ .../moe_gemm/moe_gemm_kernels_template_sm90.h | 222 +++++ .../moe_gemm/moe_sm90_traits.h | 44 + 20 files changed, 2165 insertions(+), 325 deletions(-) create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h create mode 100644 
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h new file mode 100644 index 0000000000..fb2a89af5c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef ENABLE_BF16 +#include +#endif diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp deleted file mode 100644 index 7eca46a1ca..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#define CUDA_LIB_NAME "cuda" - -#if defined(_WIN32) -#include -#define dllOpen(name) LoadLibrary("nv" name ".dll") -#define dllClose(handle) FreeLibrary(static_cast(handle)) -#define dllGetSym(handle, name) static_cast(GetProcAddress(static_cast(handle), name)) -#else // For non-Windows platforms -#include -#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY) -#define dllClose(handle) dlclose(handle) -#define dllGetSym(handle, name) dlsym(handle, name) -#endif // defined(_WIN32) - -#include "cudaDriverWrapper.h" -#include "tensorrt_llm/common/assert.h" -#include -#include - -namespace tensorrt_llm::common -{ - -std::shared_ptr CUDADriverWrapper::getInstance() -{ - static std::mutex mutex; - static std::weak_ptr instance; - std::shared_ptr result = instance.lock(); - if (result) - { - return result; - } - - std::lock_guard lock(mutex); - result = instance.lock(); - if (!result) - { - result = std::shared_ptr(new CUDADriverWrapper()); - instance = result; - } - return result; -} - -CUDADriverWrapper::CUDADriverWrapper() - : handle(dllOpen(CUDA_LIB_NAME)) -{ - - TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); - - auto load_sym = [](void* handle, char const* name) - { - void* ret = dllGetSym(handle, name); - return ret; - }; - - *reinterpret_cast(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName"); - *reinterpret_cast(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage"); - *reinterpret_cast(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute"); - *reinterpret_cast(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete"); - *reinterpret_cast(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload"); - *reinterpret_cast(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy"); - *reinterpret_cast(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData"); - *reinterpret_cast(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2"); - *reinterpret_cast(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction"); - *reinterpret_cast(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2"); - *reinterpret_cast(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2"); - *reinterpret_cast(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2"); - *reinterpret_cast(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel"); - *reinterpret_cast(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel"); - *reinterpret_cast(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled"); - *reinterpret_cast(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2"); -} - -CUDADriverWrapper::~CUDADriverWrapper() -{ - dllClose(handle); -} - -CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const -{ - return (*_cuGetErrorName)(error, pStr); -} - -CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const -{ - return (*_cuGetErrorMessage)(error, pStr); -} - -CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const -{ - return (*_cuFuncSetAttribute)(hfunc, attrib, value); -} - -CUresult 
CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const -{ - return (*_cuLinkComplete)(state, cubinOut, sizeOut); -} - -CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const -{ - return (*_cuModuleUnload)(hmod); -} - -CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const -{ - return (*_cuLinkDestroy)(state); -} - -CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const -{ - return (*_cuModuleLoadData)(module, image); -} - -CUresult CUDADriverWrapper::cuLinkCreate( - unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const -{ - return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); -} - -CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const -{ - return (*_cuModuleGetFunction)(hfunc, hmod, name); -} - -CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const -{ - return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); -} - -CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, - unsigned int numOptions, CUjit_option* options, void** optionValues) const -{ - return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); -} - -CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, - char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const -{ - return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); -} - -CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const -{ - return (*_cuLaunchCooperativeKernel)( - f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams); -} - -CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const -{ - return (*_cuLaunchKernel)( - f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra); -} - -CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, - cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const -{ - return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides, - boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill); -} - -CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const -{ - return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount); -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h deleted file mode 100644 index 
c4d470a85f..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef CUDA_DRIVER_WRAPPER_H -#define CUDA_DRIVER_WRAPPER_H - -#include "tensorrt_llm/common/assert.h" -#include -#include -#include -#include - -namespace tensorrt_llm::common -{ - -class CUDADriverWrapper -{ -public: - static std::shared_ptr getInstance(); - - ~CUDADriverWrapper(); - CUDADriverWrapper(CUDADriverWrapper const&) = delete; - CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete; - CUDADriverWrapper(CUDADriverWrapper&&) = delete; - CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete; - - CUresult cuGetErrorName(CUresult error, char const** pStr) const; - - CUresult cuGetErrorMessage(CUresult error, char const** pStr) const; - - CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; - - CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const; - - CUresult cuModuleUnload(CUmodule hmod) const; - - CUresult cuLinkDestroy(CUlinkState state) const; - - CUresult cuModuleLoadData(CUmodule* module, void const* image) const; - - CUresult cuLinkCreate( - unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; - - CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const; - - CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const; - - CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, - CUjit_option* options, void** optionValues) const; - - CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name, - unsigned int numOptions, CUjit_option* options, void** optionValues) const; - - CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const; - - CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, - unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, - CUstream hStream, void** kernelParams, void** extra) const; - - CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, - void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim, - cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, - CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const; - - CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const; - -private: - void* handle; - 
CUDADriverWrapper(); - - CUresult (*_cuGetErrorName)(CUresult, char const**); - CUresult (*_cuGetErrorMessage)(CUresult, char const**); - CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); - CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); - CUresult (*_cuModuleUnload)(CUmodule); - CUresult (*_cuLinkDestroy)(CUlinkState); - CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); - CUresult (*_cuModuleLoadData)(CUmodule*, void const*); - CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*); - CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*); - CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**); - CUresult (*_cuLinkAddData)( - CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**); - CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, - unsigned int, unsigned int, unsigned int, CUstream, void**); - CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, - unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, - CUstream hStream, void** kernelParams, void** extra); - CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, - cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); - CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount); -}; - -template -void checkDriver( - T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line) -{ - if (result) - { - char const* errorName = nullptr; - char const* errorMsg = nullptr; - wrap.cuGetErrorName(result, &errorName); - wrap.cuGetErrorMessage(result, &errorMsg); - throw TllmException( - file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg)); - } -} - -} // namespace tensorrt_llm::common - -/* - * Macros compliant with TensorRT coding conventions - */ -#define TLLM_CU_CHECK(stat) \ - do \ - { \ - tensorrt_llm::common::checkDriver( \ - (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \ - } while (0) - -#endif // CUDA_DRIVER_WRAPPER_H diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h new file mode 100644 index 0000000000..f4eed277c1 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template +void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, + ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert, + int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, + int* kernel_occupancy); +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl new file mode 100644 index 0000000000..126e761ec9 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" + +#include +#include +#include + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template +void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, + ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert, + int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, + int* kernel_occupancy) +{ + constexpr auto activation_type = fused_moe::EpilogueRouting(true); + using GemmType = fused_moe::Fused_Moe_Kernel_sm80; + + // make sure GPU has enough resources.. + if (kernel_occupancy != nullptr) + { + constexpr int smem_size = GemmType::kSmemSize; + + if (smem_size > (48 << 10)) + { + cudaFuncAttributes attr{}; + int device = 0; + int max_smem_per_block = 0; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global)); + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) + { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + // smem_size) wouldn't work. 
In that case, we return an occupancy of 0. This will cause the + // heuristic to ignore this configuration. + *kernel_occupancy = 0; + return; + } + } + + int max_active_blocks = -1; + tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, fused_moe::run_global, GemmType::kThreadCount, smem_size)); + *kernel_occupancy = max_active_blocks; + return; + } + int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks()); + int const threadblock_count = multi_processor_count * occupancy; + TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel"); + using Arguments = typename GemmType::Arguments; + Arguments args{{const_cast(A), const_cast(B), const_cast(biases), + reinterpret_cast(C), total_tokens_including_expert, static_cast(gemm_n), + static_cast(gemm_k), num_experts, bias_is_broadcast}, + num_experts, threadblock_count}; + auto params = GemmType::to_underlying_arguments(args); + if (GemmType::kSmemSize >= (48 << 10)) + { + cudaError_t result = cudaFuncSetAttribute( + fused_moe::run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize); + TLLM_CHECK_WITH_INFO(result == cudaSuccess, + "Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel"); + } + dim3 grid(params.threadblock_count, 1, 1); + dim3 block(GemmType::kThreadCount); + fused_moe::run_global<<>>(params); + auto result = cudaGetLastError(); + TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result)); +} +} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h new file mode 100644 index 0000000000..91527fadb6 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
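The sm80 launcher above uses the standard CUDA recipe for kernels that need more than the default 48 KiB of dynamic shared memory: opt the kernel in via cudaFuncSetAttribute, then query cudaOccupancyMaxActiveBlocksPerMultiprocessor and treat an occupancy of zero as "skip this configuration". A condensed sketch of that recipe with a placeholder kernel (demo_kernel, query_occupancy, and the thread count are illustrative, not part of the patch):

#include <cuda_runtime.h>

__global__ void demo_kernel() {}   // stands in for fused_moe::run_global<GemmType>

// Max active blocks per SM for the given dynamic shared memory size; returns 0
// when the configuration cannot run, mirroring the launcher's kernel_occupancy
// path so a selection heuristic can discard the tile shape.
int query_occupancy(int block_threads, int smem_bytes)
{
    if (smem_bytes > (48 << 10)
        && cudaFuncSetAttribute(demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes)
               != cudaSuccess)
    {
        return 0;
    }
    int max_active_blocks = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, demo_kernel, block_threads, smem_bytes);
    return max_active_blocks;
}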
+ */ +#pragma once + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +// Keep in sync with the signature generated by generate_kernels.py +template +void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts, + int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl new file mode 100644 index 0000000000..cca60a9816 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl @@ -0,0 +1,348 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/tensor_ref.h" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion; + +// Hopper helper class for defining all the cutlass helper types +template +struct 
HopperGroupedGemmInfo +{ + using Arch = cutlass::arch::Sm90; + + // TODO Update once mixed input support is added + static_assert(cutlass::platform::is_same::value, + "CUTLASS does not currently have specialised SM90 support for quantized operations"); + +#ifdef ENABLE_FP8 + constexpr static bool IsFP8 + = cutlass::platform::is_same::value || cutlass::platform::is_same::value; +#else + constexpr static bool IsFP8 = false; +#endif + +#ifdef ENABLE_BF16 + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value || IsFP8, + "Specialized for bfloat16, half, float, fp8"); +#else + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || IsFP8, + "Specialized for half, float, fp8"); +#endif + + static_assert(cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + "Unexpected quantization type"); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType = typename TllmToCutlassTypeAdapter::type; + + using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter::type; + // For legacy reasons we convert unsigned 8-bit to signed + using CutlassWeightTypeMaybeUint8 + = std::conditional_t, cutlass::int4b_t, + CutlassWeightTypeMaybeUint4>; + using CutlassWeightType + = std::conditional_t, int8_t, CutlassWeightTypeMaybeUint8>; + + using ElementA = ElementType; + using ElementB = CutlassWeightType; + + using ElementD = typename TllmToCutlassTypeAdapter>::type; + using ElementFinalOutput = typename TllmToCutlassTypeAdapter::type; + + // using ElementC = std::conditional_t; + // using ElementCNoVoid = std::conditional_t; + using ElementC = void; + using ElementCNoVoid = ElementD; + + using ElementAccumulator = float; + + using ElementBias = ElementFinalOutput; + using ElementRouterScales = float; + + // A matrix configuration - this is transposed and swapped with B + using LayoutA = HopperGroupedGemmInput::LayoutA; + constexpr static int AlignmentA + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units + // of elements (up to 16 bytes) + + // B matrix configuration - this is transposed and swapped with A + using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand + constexpr static int AlignmentB + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units + // of elements (up to 16 bytes) + + // C matrix configuration + using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand + using StrideC = HopperGroupedGemmInput::StrideC; + // Note we use ElementType here deliberately, so we don't break when BIAS is disabled + constexpr static int AlignmentC + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units + // of elements (up to 16 bytes) + + // D matrix configuration + using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD; + using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD; + constexpr static int AlignmentD + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix + // in units of elements (up to 16 bytes) + + static_assert(cutlass::platform::is_same::value, + "Hopper Grouped GEMM specialisation doesn't support fused activation"); + + using EpilogueOp + = 
cutlass::epilogue::fusion::LinearCombination; + + // TODO Add mode for fused activation once CUTLASS adds support + // using EpilogueSchedule = cutlass::platform::conditional_t< + // cutlass::platform::is_same::value, + // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, + // cutlass::epilogue::?????????????????? /// <<<<<< what supports activations + // >; + using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; + + // Epilogue For Default Finalize + using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< // + Arch, cutlass::arch::OpClassTensorOp, // + TileShape, ClusterShape, // + cutlass::epilogue::collective::EpilogueTileAuto, // + ElementAccumulator, ElementAccumulator, // + ElementC, LayoutC*, AlignmentC, // + ElementD, LayoutD*, AlignmentD, // + EpilogueSchedule>::CollectiveOp; + + // Epilogue For Fused Finalize + using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< // + TileShape, // + ElementCNoVoid, StrideC*, // + ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, // + ElementAccumulator, // + ElementAccumulator, // + ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, // + ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales // + >::CollectiveOp; + + using CollectiveEpilogue + = std::conditional_t; + + using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>; + + using KernelSchedule + = std::conditional_t; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< // + Arch, cutlass::arch::OpClassTensorOp, // + CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here + ElementType, LayoutA*, AlignmentA, // + ElementAccumulator, // + TileShape, ClusterShape, // + StageCountAutoCarveout, KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + + using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter; +}; + +// Hopper specialised version +template +void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts, + int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size) +{ +#ifdef COMPILE_HOPPER_TMA_GEMMS + using namespace cute; + if constexpr (!should_filter_sm90_gemm_problem_shape_v) + { + using GemmInfo + = HopperGroupedGemmInfo; + + using ElementAccumulator = typename GemmInfo::ElementAccumulator; + using ElementA = typename GemmInfo::ElementA; + using ElementB = typename GemmInfo::ElementB; + using ElementC = typename GemmInfo::ElementC; + using ElementCNoVoid = typename GemmInfo::ElementCNoVoid; + using ElementD = typename GemmInfo::ElementD; + + using CollectiveMainloop = typename GemmInfo::CollectiveMainloop; + using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue; + using GemmKernel = typename GemmInfo::GemmKernel; + using GemmGrouped = typename GemmInfo::GemmGrouped; + + if (kernel_occupancy != nullptr) + { + *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = multi_processor_count; + + GemmGrouped gemm; + + if (workspace_size != nullptr) + { + // Make a mock problem shape with just the minimal information actually required to get the workspace size + // This makes some assumptions about CUTLASS's 
implementation which is suboptimal. We have a check later to + // catch future cutlass updates causing silent breakages, but that is not fool proof. + // The alternative is to wait until we have data and then dynamically allocate the workspace + typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr}; + + typename GemmGrouped::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info}; + *workspace_size = gemm.get_workspace_size(args); + return; + } + + using MainloopArguments = typename CollectiveMainloop::Arguments; + TLLM_CHECK(hopper_input.stride_a); + TLLM_CHECK(hopper_input.stride_b); + TLLM_CHECK(hopper_input.ptr_a); + TLLM_CHECK(hopper_input.ptr_b); + + MainloopArguments const mainloop_params = {reinterpret_cast(hopper_input.ptr_b), + hopper_input.stride_b, reinterpret_cast(hopper_input.ptr_a), hopper_input.stride_a}; + + typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{ + ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)}; + epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + // TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS + auto make_epi_args = [&]() + { + if constexpr (FUSION == EpilogueFusion::NONE) + { + auto epi_params = hopper_input.default_epilogue; + return EpilogueArguments{epilogue_scalars, reinterpret_cast(hopper_input.ptr_c), + hopper_input.stride_c, reinterpret_cast(epi_params.ptr_d), epi_params.stride_d}; + } + else if constexpr (FUSION == EpilogueFusion::FINALIZE) + { + // Parameters for fused finalize + auto epi_params = hopper_input.fused_finalize_epilogue; + return EpilogueArguments{ + epilogue_scalars, // Parameters to underlying epilogue + reinterpret_cast(hopper_input.ptr_c), hopper_input.stride_c, // C params + reinterpret_cast(epi_params.ptr_final_output), + epi_params.stride_final_output, // D (output) params + reinterpret_cast(epi_params.ptr_bias), + epi_params.stride_bias, // Bias params + epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales + epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales + epi_params.ptr_source_token_index, // Index of the source token to sum into + epi_params.num_rows_in_final_output // Number of tokens in the output buffer + }; + } + else + { + static_assert( + sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher"); + } + }; + EpilogueArguments const epilogue_params = make_epi_args(); + + typename GemmKernel::TileScheduler::Arguments scheduler_args{ + 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; + + typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info, + mainloop_params, epilogue_params, hw_info, scheduler_args}; + + size_t calculated_ws_size = gemm.get_workspace_size(args); + TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size, + "Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size); + + auto can_implement = gemm.can_implement(args); + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, + "Grouped GEMM kernel will fail for params. 
Error: " + std::string(cutlassGetStatusString(can_implement))); + + auto init_status = gemm.initialize(args, hopper_input.gemm_workspace); + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, + "Failed to initialize cutlass SM90 grouped gemm. Error: " + + std::string(cutlassGetStatusString(init_status))); + + auto run_status = gemm.run(stream); + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, + "Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status))); + sync_check_cuda_error(); + } + else + { + TLLM_THROW("Configuration was disabled by FAST_BUILD"); + } + +#else // COMPILE_HOPPER_TMA_GEMMS + TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py."); +#endif // COMPILE_HOPPER_TMA_GEMMS +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu new file mode 100644 index 0000000000..9862460dd6 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu @@ -0,0 +1,131 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/conv/convolution.h" +// Order matters here, packed_stride.hpp is missing cute and convolution includes +#include "cutlass/util/packed_stride.hpp" + +#include "tensorrt_llm/common/logger.h" + +namespace tensorrt_llm +{ +std::array HopperGroupedGemmInput::workspaceBuffers(int num_experts) +{ + size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts; + size_t stride_a_size = sizeof(StrideA) * num_experts; + size_t stride_b_size = sizeof(StrideB) * num_experts; + size_t stride_c_size = sizeof(StrideC) * num_experts; + size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts; + + size_t ptr_buf_size = sizeof(void*) * num_experts; + size_t scale_buf_size = sizeof(float*) * num_experts; + + return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size, + ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size}; +} + +size_t HopperGroupedGemmInput::workspaceSize(int num_experts) +{ + auto buffers = workspaceBuffers(num_experts); + return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size()); +} + +void HopperGroupedGemmInput::configureWorkspace( + int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size) +{ + auto buffers = workspaceBuffers(num_experts); + std::array pointers{}; + TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers"); + for (int i = 0; i < buffers.size(); i++) + { + pointers[i] = start_ptr; + start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]); + } + + shape_info.num_groups = num_experts; + shape_info.problem_shapes = reinterpret_cast(pointers[0]); + shape_info.host_problem_shapes = nullptr; + stride_a = reinterpret_cast(pointers[1]); + stride_b = reinterpret_cast(pointers[2]); + stride_c = reinterpret_cast(pointers[3]); + default_epilogue.stride_d = reinterpret_cast(pointers[4]); + + ptr_a = reinterpret_cast(pointers[5]); + ptr_b = reinterpret_cast(pointers[6]); + ptr_c = reinterpret_cast(pointers[7]); + default_epilogue.ptr_d = reinterpret_cast(pointers[8]); + + alpha_scale_ptr_array = reinterpret_cast(pointers[9]); + + this->gemm_workspace = reinterpret_cast(gemm_workspace); + this->gemm_workspace_size = gemm_workspace_size; +} + +void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales, + int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, + int num_output_tokens) +{ + fused_finalize_epilogue.ptr_final_output = final_output; + fused_finalize_epilogue.ptr_router_scales = router_scales; + fused_finalize_epilogue.ptr_bias = bias; + fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset; + fused_finalize_epilogue.ptr_source_token_index = source_token_index; + + fused_finalize_epilogue.stride_final_output + = cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{}, + transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1))); + fused_finalize_epilogue.stride_bias + = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size)); + fused_finalize_epilogue.stride_router_scales = {}; + + fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens; +} + +std::string HopperGroupedGemmInput::toString() const +{ + 
std::stringstream ss; + ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n"; + if (isValid()) + { + ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n"; + ss << "Epilogue Fusion: " << (int) fusion; + if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE) + { + ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output; + ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; + ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias; + ss << " with Stride: " << fused_finalize_epilogue.stride_bias; + ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales; + ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; + ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset; + ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index; + } + else + { + ss << ", Ptr D: " << default_epilogue.ptr_d; + } + ss << '\n'; + ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n"; + } + return ss.str(); +} +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h new file mode 100644 index 0000000000..0616c06365 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h @@ -0,0 +1,230 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/workspace.h" +#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" +#include +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/layout/layout.h" + +namespace tensorrt_llm +{ +template +constexpr auto transpose_stride(T const& t) +{ + return cute::prepend(cute::prepend(cute::take<2, cute::rank_v>(t), cute::get<0>(t)), cute::get<1>(t)); +} + +struct HopperGroupedGemmInput +{ + template + using TransposeStride = decltype(transpose_stride(T{})); + template + using TransposeLayoutTag = std::conditional_t, + cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; + + static_assert(std::is_same_v>); + static_assert(std::is_same_v>); + + // Layout for A and B is transposed and then swapped in the implementation + // This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM + using LayoutA = TransposeLayoutTag; // Layout type for A matrix operand + using LayoutB = TransposeLayoutTag; // Layout type for B matrix operand + using LayoutC = TransposeLayoutTag; // Layout type for C matrix operand + + using StrideA + = std::remove_pointer_t>; // Use B because they will be swapped + using StrideB + = std::remove_pointer_t>; // Use A because they will be swapped + using StrideC = std::remove_pointer_t>; + + template + constexpr static bool IsFP8_v = std::is_same_v || std::is_same_v; + + // Currently this should always just be T + template + using OutputTypeAdaptor_t = std::conditional_t, nv_bfloat16, T>; + + using ProblemShape = cutlass::gemm::GroupProblemShape>; + + ProblemShape shape_info{}; + StrideA* stride_a = nullptr; + StrideB* stride_b = nullptr; + + void const** ptr_a = nullptr; + void const** ptr_b = nullptr; + + // C is currently the same in both epilogues + StrideC* stride_c = nullptr; + void const** ptr_c = nullptr; + + struct DefaultEpilogue + { + using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand + using StrideD = std::remove_pointer_t>; + + StrideD* stride_d = nullptr; + void** ptr_d = nullptr; + }; + + struct FusedFinalizeEpilogue + { + using StrideFinalOutput = DefaultEpilogue::StrideD; + using StrideBias = TransposeStride>; + using StrideRouterScales = TransposeStride>; + + void* ptr_final_output = nullptr; + StrideFinalOutput stride_final_output{}; + + void const* ptr_bias = nullptr; + StrideBias stride_bias{}; + + float const* ptr_router_scales = nullptr; + StrideRouterScales stride_router_scales{}; + + int64_t const* ptr_expert_first_token_offset = nullptr; + int const* ptr_source_token_index = nullptr; + + size_t num_rows_in_final_output = 0; + }; + + DefaultEpilogue default_epilogue; + FusedFinalizeEpilogue fused_finalize_epilogue; + + enum class EpilogueFusion + { + NONE, + ACTIVATION, + GATED_ACTIVATION, + FINALIZE + }; + EpilogueFusion fusion = EpilogueFusion::NONE; + + float const** alpha_scale_ptr_array = nullptr; + + uint8_t* gemm_workspace = nullptr; + size_t gemm_workspace_size = 0; + + static std::array workspaceBuffers(int num_experts); + + static size_t workspaceSize(int num_experts); + + void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size); + + bool isValid() const + { + return stride_a != nullptr && ptr_a != nullptr; + } + + void setFinalizeFusionParams(void* final_output, float const* router_scales, + int64_t const* 
expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, + int num_output_tokens); + + std::string toString() const; +}; + +// Note update moe.py to match +enum class ActivationType +{ + Gelu = 0, + Relu, + Silu, + Swiglu, + Geglu, + Identity, + InvalidType +}; + +constexpr bool isGatedActivation(ActivationType activation_type) +{ + return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; +} + +template +class MoeGemmRunner +{ +public: + MoeGemmRunner(); + +#if defined(ENABLE_FP8) + static constexpr bool use_fp8 = std::is_same_v || std::is_same_v; +#else + static constexpr bool use_fp8 = false; +#endif + + void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, + ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf); + + void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C, + int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, + cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf); + + std::vector getConfigs() const; + static std::vector getConfigs(int sm); + static std::vector getHopperConfigs(int sm); + static std::vector getAmpereConfigs(int sm); + + [[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const; + [[nodiscard]] bool supportsHopperSpecialisation() const; + [[nodiscard]] bool isFusedGatedActivation( + cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const; + [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const; + + size_t getMaxWorkspaceSize(int num_experts) const; + + [[nodiscard]] int getSM() const; + +private: + template + void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, + ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array, + cudaStream_t stream, int* occupancy = nullptr); + + template + void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, + bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf); + +private: + int sm_{}; + int multi_processor_count_{}; + mutable int num_experts_ = 0; + mutable size_t gemm_workspace_size_ = 0; + size_t calcMaxWorkspaceSize(int num_experts) const; +}; + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu 
b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu new file mode 100644 index 0000000000..3aa96502d3 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu new file mode 100644 index 0000000000..fbb5270455 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu new file mode 100644 index 0000000000..78f1a93a6a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu new file mode 100644 index 0000000000..69c4b6a15a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu new file mode 100644 index 0000000000..4ffa5485f0 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu new file mode 100644 index 0000000000..424b817b87 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu new file mode 100644 index 0000000000..f317023565 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu new file mode 100644 index 0000000000..c6b8fe7872 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_FP8 +template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>; +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; +#endif +// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>; +#endif +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h new file mode 100644 index 0000000000..2a337e6ca4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h @@ -0,0 +1,823 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/tensor_ref.h" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#ifdef __GNUC__ // Restore GCC-specific diagnostics +#pragma GCC diagnostic pop +#endif + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" + +#include "moe_gemm_kernels_template_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" +#include + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace kernels::cutlass_kernels +{ + +// ============================= Variable batched Gemm things =========================== +template +void genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr) +{ +#if defined(ENABLE_FP8) + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for fp8, bfloat16, half, float"); +#elif defined(ENABLE_BF16) + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float"); +#else + static_assert(cutlass::platform::is_same::value 
|| cutlass::platform::is_same::value, + "Specialized for half, float"); +#endif + + static_assert(cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + ""); + + static_assert(!cutlass::platform::is_same::value, + "Sm90 architecture should use specialised kernels"); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType = typename TllmToCutlassTypeAdapter::type; + using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter::type; + using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; + if (!use_fused_moe) + { + // We need separate config for each architecture since we will target different tensorcore instructions. For + // float, we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; + + typename EpilogueOp::Params epilogue_op( + ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); + +#if defined(ENABLE_FP8) + if constexpr ((std::is_same_v + || std::is_same_v) &&std::is_same_v) + { + TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array, + "weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 " + "Ada"); + epilogue_op.alpha_ptr_array = alpha_scale_ptr_array; + } +#endif + + // Finally, set up the kernel. + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) + { + *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); + int const threadblock_count = multi_processor_count * occupancy; + + int const group_size = gemm_k; + typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, + reinterpret_cast(A), reinterpret_cast(B), + reinterpret_cast(weight_scales), + reinterpret_cast(biases), bias_is_broadcast, + reinterpret_cast(C), total_tokens_including_expert, gemm_n, gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, + "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); + + auto init_status = gemm.initialize(args); + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, + "Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status))); + + auto run_status = gemm.run(stream); + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, + "Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status))); + } + else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2 + && (std::is_same_v + || std::is_same_v) ) // use fused moe gemm + // kernel.. 
(only support + // fp16 or bf16) + { + sm80_generic_fused_moe_gemm_kernelLauncher(reinterpret_cast(A), + reinterpret_cast(B), reinterpret_cast(biases), + bias_is_broadcast, reinterpret_cast(C), total_tokens_including_expert, num_rows, gemm_n, + gemm_k, num_experts, multi_processor_count, stream, kernel_occupancy); + } +} + +} // namespace kernels::cutlass_kernels + +template +static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales, GemmOutputType const* biases, + bool bias_is_broadcast, GemmOutputType* C, int64_t const* total_tokens_including_expert, int64_t num_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + int* occupancy = nullptr) +{ + + static_assert(!std::is_same_v, "Use TMA specialised functions for arch SM90"); +#if defined(ENABLE_FP8) + constexpr bool isFp8 = std::is_same_v || std::is_same_v; +#else + constexpr bool isFp8 = false; +#endif + + if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80) + && (!isFp8 || std::is_same_v) ) + { + kernels::cutlass_kernels::genericMoeGemmKernelLauncher(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + else + { + TLLM_THROW( + "Cutlass gemm. Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages); + } +} + +template +void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.stages) + { + case 2: + dispatch(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case 3: + dispatch(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case 4: + dispatch(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break; + } +} + +// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. +// This overload is only enabled when T == WeightType. 
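+// (When ENABLE_FP8 is defined, the FP8 element types are excluded from this overload; they are handled by the dedicated FP8 overload further below.)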
+template ::value +#if defined(ENABLE_FP8) + && !std::is_same::value && !std::is_same::value +#endif + && std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; + } +} + +// Tensorop GEMM overload +// Overload for quantize MoE GEMMs. 
We disable some warp configs here since they will not be used and we can improve +// compile time +template ::value && !std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break; + } +} + +// This overload will handle tensorop gemms. 
+// This overload is only enabled when T == WeightType and T == __nv_fp8_e4m3 or __nv_fp8_e5m2 +#if defined(ENABLE_FP8) +template ::value || std::is_same::value) + && std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 64, 128>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config 
should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; + } +} +#endif + +// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. +template ::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Unsupported config for float MoE gemm."); break; + } +} + +template +std::vector +MoeGemmRunner::getConfigs() const +{ + return getConfigs(sm_); +} + +template +std::vector MoeGemmRunner::getConfigs( + int sm) +{ + std::vector candidate_configs = getHopperConfigs(sm); + std::vector ampere_configs = getAmpereConfigs(sm); + std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs)); + + return candidate_configs; +} + +template +std::vector +MoeGemmRunner::getAmpereConfigs(int sm) +{ + using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; + static constexpr auto weight_only_flag + = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; + static constexpr auto simt_only_flag + = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; + static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; + int const max_split_k = 1; + int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; + int const enable_hopper = CutlassGemmConfig::NONE; + + auto config_type_param = static_cast( + weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); + + if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) + { + return {}; + } + + std::vector ampere_configs + = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + return ampere_configs; +} + +template +std::vector +MoeGemmRunner::getHopperConfigs(int sm) +{ + using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; + static constexpr auto weight_only_flag + = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; + static constexpr auto simt_only_flag + = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; + int const max_split_k = 1; + int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; + int const enable_hopper = CutlassGemmConfig::HOPPER; + static constexpr auto fp8_only_flag = use_fp8 ? 
CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; + auto config_type_param = static_cast( + weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); + + if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + { + return {}; + } + + std::vector hopper_configs + = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + return hopper_configs; +} + +template +bool MoeGemmRunner::isHopperSpecialised( + cutlass_extensions::CutlassGemmConfig gemm_config) const +{ + bool config_is_sm90 = gemm_config.is_sm90; + return supportsHopperSpecialisation() && config_is_sm90; +} + +template +bool MoeGemmRunner::supportsHopperSpecialisation() const +{ + return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation(); +} + +template +int MoeGemmRunner::getSM() const +{ + return this->sm_; +} + +// currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction +template +bool MoeGemmRunner::supportsFusedGatedActivation( + bool is_gated_activation, int gemm_n, int gemm_k) const +{ + constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true; + return is_gated_activation && std::is_same_v && !std::is_same_v && !use_fp8 + && (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; +} + +template +bool MoeGemmRunner::isFusedGatedActivation( + cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const +{ + return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_sm90; +} + +template +MoeGemmRunner::MoeGemmRunner() +{ + int device{-1}; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + sm_ = tensorrt_llm::common::getSMVersion(); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +template +void MoeGemmRunner::dispatchToArch(T const* A, + WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, + void* C_void, int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy) +{ + static_assert(std::is_same_v, + "Separate Scale/Bias type is not supported. This is assumed to be the gemm output type"); + + // For now we always cast this to output type. 
+ // In the future this will vary based on what fusions are applied for FP8 + auto* C = reinterpret_cast(C_void); + + TLLM_CHECK_WITH_INFO( + sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation"); + TLLM_CHECK_WITH_INFO( + sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture"); + + if (sm_ >= 75 && sm_ < 80) + { + dispatchMoeGemmToCutlass(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + else if (sm_ >= 80 && sm_ < 90) + { + if constexpr (use_fp8) + { +#if defined(ENABLE_FP8) + static_assert(!std::is_same_v && !std::is_same_v, + "FP8 GEMM Output not supported"); +#endif + + TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, + occupancy); + } + else + { + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, + occupancy); + } + } + else if (sm_ >= 90) + { + if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + { + + // We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens + // SM80 is faster. We check here to see which is selected + if (gemm_config.is_sm90) + { + TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr, + "Input biases and hopper input disagree if bias is enabled"); + TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config"); + + // Select the appropriate fusion function + auto select_function = [&]() + { + switch (hopper_input.fusion) + { + case HopperGroupedGemmInput::EpilogueFusion::FINALIZE: + return &dispatchMoeGemmSelectTileShapeSM90; + case HopperGroupedGemmInput::EpilogueFusion::NONE: + return &dispatchMoeGemmSelectTileShapeSM90; + case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION: + case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION: + default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion); + }; + }; + auto selected_func = select_function(); + selected_func( + hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr); + return; + } + + // Fallthrough to SM80 impl below + } + + // Do Ampere case instead + if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) + { + TLLM_CHECK_WITH_INFO(!hopper_input.isValid(), + "Non-specialised Hopper implementation is being rerouted to fallback implementation so input " + "information is not required"); + TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90, + "GEMM config is for SM90 configuration, but this configuration is not valid for Hopper"); + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, + occupancy); + } + else + { + TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80
kernels"); + } + } + else + { + TLLM_THROW("Arch unsupported for MoE GEMM"); + } +} + +template +size_t MoeGemmRunner::getMaxWorkspaceSize(int num_experts) const +{ + if (num_experts != num_experts_) + { + TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_); + num_experts_ = num_experts; + gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts); + } + return gemm_workspace_size_; +} + +template +size_t MoeGemmRunner::calcMaxWorkspaceSize(int num_experts) const +{ + if (!supportsHopperSpecialisation()) + { + return 0; + } + if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + { + auto configs = getHopperConfigs(sm_); + size_t max_size = 0; + bool has_config = false; + for (auto conf : configs) + { +#define CALC_SIZE_FUSION(FUSION) \ + do \ + { \ + try \ + { \ + size_t size = calcMaxWorkspaceSizeSM90( \ + num_experts, conf, multi_processor_count_); \ + max_size = std::max(max_size, size); \ + has_config = true; \ + } \ + catch (tensorrt_llm::common::TllmException const& e) \ + { \ + TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size"); \ + } \ + } while (0) + + CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::NONE); + CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::FINALIZE); + } + TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size"); + return max_size; + } + else + { + TLLM_THROW("Attempting to calculate Hopper GEMM workspace size with unsupported weight combination"); + return 0; + } +} + +template +template +void MoeGemmRunner::runGemm(T const* A, WeightType const* B, + ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C, + int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, + cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) +{ + dispatchToArch(A, B, weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, + hopper_input, total_rows, gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array, + stream, nullptr); +} + +template +void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* B, + ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C, + int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) +{ + switch (activation_type) + { + case ActivationType::Relu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Gelu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Silu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); 
+ break; + case ActivationType::Identity: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Swiglu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Geglu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; + default: TLLM_THROW("Invalid activation type."); break; + } +} + +template +void MoeGemmRunner::moeGemm(T const* A, WeightType const* B, + ScaleBiasType const* weight_scales, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf) +{ + runGemm(A, B, weight_scales, nullptr, true, C, total_tokens_including_expert, + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, + chosen_conf); +} + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h new file mode 100644 index 0000000000..3efb42f41e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h @@ -0,0 +1,222 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/tensor_ref.h" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion; + +template +void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count, + cudaStream_t stream, int* occupancy, size_t* workspace_size) +{ + static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation(), + "Invalid hopper configuration invoked, fallback to Sm80"); + + TLLM_CHECK_WITH_INFO( + workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information"); + + // auto func = hopper_input.ptr_c ? + // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper + // : + // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper; + // TODO(dastokes) Re-enable bias when CUTLASS supports it + auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher; + func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); +} + +/* + 1x1x1 cluster shape is supported for any tile shape. + + 2x1x1 cluster shape is only supported when the M tile is at least 128. + + 1x2x1 cluster shape is only supported when the N tile is at least 128. + + 2x2x1 cluster shape is only supported when both the M and N tiles are at least 128. + + We make the above restrictions to improve compilation speed in TRT-LLM by pruning kernels + that may not be very useful in practice.
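+
+ For example, a 128x16 tile is instantiated only with the 1x1x1 and 2x1x1 cluster shapes, since 1x2x1 and 2x2x1 require an N tile of at least 128.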
+ */ +template +constexpr bool are_tile_shapes_supported() +{ + using namespace cute; + [[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{}); + [[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{}); + constexpr int cga_m = get<0>(ClusterShape{}); + constexpr int cga_n = get<1>(ClusterShape{}); + + if constexpr (cga_m == _1{} && cga_n == _1{}) + { + return true; + } + else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{}) + { + return true; + } + else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{}) + { + return true; + } + else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{}) + { + return true; + } + else + { + return false; + } +} + +template +void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, + size_t* workspace_size) +{ + using namespace cute; + switch (gemm_config.cluster_shape) + { +#define SHAPE_CASE(M, N, K) \ + case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \ + { \ + using ClusterShape = Shape<_##M, _##N, _##K>; \ + if constexpr (are_tile_shapes_supported()) \ + { \ + dispatchMoeGemmSelectBiasSM90( \ + hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \ + break; \ + } \ + else \ + { \ + TLLM_THROW("Unsupported tile and cluster shape combination"); \ + } \ + } + + SHAPE_CASE(1, 1, 1) + SHAPE_CASE(1, 2, 1) + + SHAPE_CASE(2, 1, 1) + SHAPE_CASE(2, 2, 1) + +#undef SHAPE_CASE + default: TLLM_THROW("Unsupported config for MoE gemm."); + } +} // namespace tensorrt_llm + +template +void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, + size_t* workspace_size) +{ + using namespace cute; + + switch (gemm_config.tile_config_sm90) + { +#define SHAPE_CASE(M, N, K) \ + case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \ + { \ + constexpr int KtileBytes = K / sizeof(T); \ + using KTileDim = Int; \ + using TileShape = Shape<_##M, _##N, KTileDim>; \ + dispatchMoeGemmSelectClusterShapeSM90( \ + hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \ + break; \ + } + + SHAPE_CASE(128, 16, 128) + SHAPE_CASE(128, 32, 128) + SHAPE_CASE(128, 64, 128) + SHAPE_CASE(128, 128, 128) + SHAPE_CASE(128, 256, 128) + SHAPE_CASE(256, 128, 128) + +#undef SHAPE_CASE + case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Unsupported config for MoE gemm."); break; + } +} + +template +size_t calcMaxWorkspaceSizeSM90( + int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count) +{ + size_t count; + // Most of the values are ignored for WS size calculation. 
We reuse the function to reduce the template bloat + dispatchMoeGemmSelectTileShapeSM90( + HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count); + return count; +} + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h new file mode 100644 index 0000000000..959d0ea088 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h @@ -0,0 +1,44 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/mma_sm90.h" +#include "cutlass_extensions/epilogue_helpers.h" + +namespace tensorrt_llm::kernels::cutlass_kernels +{ + +// Hopper arch +template +constexpr bool isValidHopperMOESpecialisation() +{ +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + return cutlass::platform::is_same::value + && cutlass::platform::is_same::value; +#else + return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled +#endif +} + +// Hopper arch +template +constexpr bool isValidAmpereMOESpecialisation() +{ + return true; // Default to true +} + +} // namespace tensorrt_llm::kernels::cutlass_kernels From 9602c2aac76d2655d4d9aa657e60accde1cfb51f Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 31 Jan 2025 00:39:47 +0800 Subject: [PATCH 06/52] keep the parts needed for moe_kernels (#3218) --- .../tensorrt_llm/common/CMakeLists.txt | 22 - .../3rdparty/tensorrt_llm/common/assert.cpp | 0 .../3rdparty/tensorrt_llm/common/assert.h | 92 ++ .../tensorrt_llm/common/cudaDriverWrapper.cpp | 187 ++++ .../tensorrt_llm/common/cudaDriverWrapper.h | 138 +++ .../tensorrt_llm/common/cudaFp8Utils.h | 239 +++++ .../tensorrt_llm/common/cudaProfilerUtils.cpp | 84 -- .../3rdparty/tensorrt_llm/common/cudaUtils.h | 641 +++++++++++++ .../common/customAllReduceUtils.h | 36 - .../3rdparty/tensorrt_llm/common/envUtils.cpp | 214 ----- .../3rdparty/tensorrt_llm/common/envUtils.h | 60 -- .../3rdparty/tensorrt_llm/common/logger.h | 190 ++++ .../3rdparty/tensorrt_llm/common/mathUtils.h | 37 - .../tensorrt_llm/common/memoryUtils.cu | 906 ------------------ .../tensorrt_llm/common/memoryUtils.h | 292 ------ .../3rdparty/tensorrt_llm/common/mpiUtils.cpp | 588 ------------ .../3rdparty/tensorrt_llm/common/nvtxUtils.h | 46 - .../3rdparty/tensorrt_llm/common/opUtils.cpp | 323 ------- .../3rdparty/tensorrt_llm/common/opUtils.h | 215 ----- .../tensorrt_llm/common/quantization.h | 358 +++++++ .../3rdparty/tensorrt_llm/common/stlUtils.h | 123 --- .../tensorrt_llm/common/stringUtils.h | 113 +++ .../tensorrt_llm/common/timestampUtils.cpp | 42 - .../{timestampUtils.h => tllmException.h} | 27 +- 24 files changed, 1983 insertions(+), 2990 deletions(-) delete mode 
100644 sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt mode change 100755 => 100644 sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/assert.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/logger.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp rename sgl-kernel/3rdparty/tensorrt_llm/common/{timestampUtils.h => tllmException.h} (50%) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt b/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt deleted file mode 100644 index e479b298db..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt +++ /dev/null @@ -1,22 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & -# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not -# use this file except in compliance with the License. You may obtain a copy of -# the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations under -# the License. -# -file(GLOB SRCS *.cpp) -file(GLOB CU_SRCS *.cu) - -add_library(common_src OBJECT ${SRCS} ${CU_SRCS}) -set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp old mode 100755 new mode 100644 diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h new file mode 100644 index 0000000000..7f51dbf1b4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/stringUtils.h" +#include "tensorrt_llm/common/tllmException.h" + +#include + +namespace tensorrt_llm::common +{ +[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "") +{ + throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str())); +} + +} // namespace tensorrt_llm::common + +class DebugConfig +{ +public: + static bool isCheckDebugEnabled(); +}; + +#if defined(_WIN32) +#define TLLM_LIKELY(x) (__assume((x) == 1), (x)) +#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x)) +#else +#define TLLM_LIKELY(x) __builtin_expect((x), 1) +#define TLLM_UNLIKELY(x) __builtin_expect((x), 0) +#endif + +#define TLLM_CHECK(val) \ + do \ + { \ + TLLM_LIKELY(static_cast(val)) ? ((void) 0) \ + : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \ + } while (0) + +#define TLLM_CHECK_WITH_INFO(val, info, ...) \ + do \ + { \ + TLLM_LIKELY(static_cast(val)) \ + ? ((void) 0) \ + : tensorrt_llm::common::throwRuntimeError( \ + __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \ + } while (0) + +#define TLLM_CHECK_DEBUG(val) \ + do \ + { \ + if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \ + { \ + TLLM_LIKELY(static_cast(val)) ? ((void) 0) \ + : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \ + } \ + } while (0) + +#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \ + do \ + { \ + if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \ + { \ + TLLM_LIKELY(static_cast(val)) \ + ? ((void) 0) \ + : tensorrt_llm::common::throwRuntimeError( \ + __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \ + } \ + } while (0) + +#define TLLM_THROW(...) \ + do \ + { \ + throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \ + } while (0) + +#define TLLM_WRAP(ex) \ + NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what()) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp new file mode 100644 index 0000000000..7eca46a1ca --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
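For illustration only, a minimal usage sketch (not part of this patch) of how calling code might use the assertion macros added in assert.h above; the function name and its parameters are hypothetical:

#include "tensorrt_llm/common/assert.h"

// Hypothetical caller: validates its inputs with the TLLM_* macros defined above.
// TLLM_CHECK and TLLM_CHECK_WITH_INFO throw a TllmException on failure; the
// info string is printf-style, formatted through fmtstr().
void configureBeamSearch(int beamWidth, int maxBeamWidth)
{
    TLLM_CHECK(beamWidth > 0);
    TLLM_CHECK_WITH_INFO(
        beamWidth <= maxBeamWidth, "beamWidth (%d) must not exceed maxBeamWidth (%d)", beamWidth, maxBeamWidth);
    if (maxBeamWidth <= 0)
    {
        TLLM_THROW("maxBeamWidth must be positive, got %d", maxBeamWidth);
    }
}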
+ */ + +#define CUDA_LIB_NAME "cuda" + +#if defined(_WIN32) +#include +#define dllOpen(name) LoadLibrary("nv" name ".dll") +#define dllClose(handle) FreeLibrary(static_cast(handle)) +#define dllGetSym(handle, name) static_cast(GetProcAddress(static_cast(handle), name)) +#else // For non-Windows platforms +#include +#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY) +#define dllClose(handle) dlclose(handle) +#define dllGetSym(handle, name) dlsym(handle, name) +#endif // defined(_WIN32) + +#include "cudaDriverWrapper.h" +#include "tensorrt_llm/common/assert.h" +#include +#include + +namespace tensorrt_llm::common +{ + +std::shared_ptr CUDADriverWrapper::getInstance() +{ + static std::mutex mutex; + static std::weak_ptr instance; + std::shared_ptr result = instance.lock(); + if (result) + { + return result; + } + + std::lock_guard lock(mutex); + result = instance.lock(); + if (!result) + { + result = std::shared_ptr(new CUDADriverWrapper()); + instance = result; + } + return result; +} + +CUDADriverWrapper::CUDADriverWrapper() + : handle(dllOpen(CUDA_LIB_NAME)) +{ + + TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); + + auto load_sym = [](void* handle, char const* name) + { + void* ret = dllGetSym(handle, name); + return ret; + }; + + *reinterpret_cast(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName"); + *reinterpret_cast(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage"); + *reinterpret_cast(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute"); + *reinterpret_cast(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete"); + *reinterpret_cast(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload"); + *reinterpret_cast(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy"); + *reinterpret_cast(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData"); + *reinterpret_cast(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2"); + *reinterpret_cast(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction"); + *reinterpret_cast(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2"); + *reinterpret_cast(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2"); + *reinterpret_cast(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2"); + *reinterpret_cast(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel"); + *reinterpret_cast(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel"); + *reinterpret_cast(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled"); + *reinterpret_cast(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2"); +} + +CUDADriverWrapper::~CUDADriverWrapper() +{ + dllClose(handle); +} + +CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const +{ + return (*_cuGetErrorName)(error, pStr); +} + +CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const +{ + return (*_cuGetErrorMessage)(error, pStr); +} + +CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const +{ + return (*_cuFuncSetAttribute)(hfunc, attrib, value); +} + +CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const +{ + return (*_cuLinkComplete)(state, cubinOut, sizeOut); +} + +CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const +{ + return (*_cuModuleUnload)(hmod); +} + +CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const +{ + return (*_cuLinkDestroy)(state); +} + +CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* 
module, void const* image) const +{ + return (*_cuModuleLoadData)(module, image); +} + +CUresult CUDADriverWrapper::cuLinkCreate( + unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const +{ + return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); +} + +CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const +{ + return (*_cuModuleGetFunction)(hfunc, hmod, name); +} + +CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const +{ + return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); +} + +CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, + unsigned int numOptions, CUjit_option* options, void** optionValues) const +{ + return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); +} + +CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, + char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const +{ + return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); +} + +CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const +{ + return (*_cuLaunchCooperativeKernel)( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams); +} + +CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const +{ + return (*_cuLaunchKernel)( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra); +} + +CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, + cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const +{ + return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides, + boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill); +} + +CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const +{ + return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h new file mode 100644 index 0000000000..c4d470a85f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CUDA_DRIVER_WRAPPER_H +#define CUDA_DRIVER_WRAPPER_H + +#include "tensorrt_llm/common/assert.h" +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +class CUDADriverWrapper +{ +public: + static std::shared_ptr getInstance(); + + ~CUDADriverWrapper(); + CUDADriverWrapper(CUDADriverWrapper const&) = delete; + CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete; + CUDADriverWrapper(CUDADriverWrapper&&) = delete; + CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete; + + CUresult cuGetErrorName(CUresult error, char const** pStr) const; + + CUresult cuGetErrorMessage(CUresult error, char const** pStr) const; + + CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; + + CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const; + + CUresult cuModuleUnload(CUmodule hmod) const; + + CUresult cuLinkDestroy(CUlinkState state) const; + + CUresult cuModuleLoadData(CUmodule* module, void const* image) const; + + CUresult cuLinkCreate( + unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; + + CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const; + + CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const; + + CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, + CUjit_option* options, void** optionValues) const; + + CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name, + unsigned int numOptions, CUjit_option* options, void** optionValues) const; + + CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const; + + CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void** kernelParams, void** extra) const; + + CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, + void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim, + cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const; + + CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const; + +private: + void* handle; + CUDADriverWrapper(); + + CUresult (*_cuGetErrorName)(CUresult, char const**); + CUresult (*_cuGetErrorMessage)(CUresult, char const**); + CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); + CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); + CUresult (*_cuModuleUnload)(CUmodule); + CUresult 
(*_cuLinkDestroy)(CUlinkState); + CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); + CUresult (*_cuModuleLoadData)(CUmodule*, void const*); + CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*); + CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*); + CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**); + CUresult (*_cuLinkAddData)( + CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**); + CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void**); + CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void** kernelParams, void** extra); + CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, + cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); + CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount); +}; + +template +void checkDriver( + T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line) +{ + if (result) + { + char const* errorName = nullptr; + char const* errorMsg = nullptr; + wrap.cuGetErrorName(result, &errorName); + wrap.cuGetErrorMessage(result, &errorMsg); + throw TllmException( + file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg)); + } +} + +} // namespace tensorrt_llm::common + +/* + * Macros compliant with TensorRT coding conventions + */ +#define TLLM_CU_CHECK(stat) \ + do \ + { \ + tensorrt_llm::common::checkDriver( \ + (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \ + } while (0) + +#endif // CUDA_DRIVER_WRAPPER_H diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h new file mode 100644 index 0000000000..aa93b55a57 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
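For illustration only, a minimal sketch (not part of this patch) of how the lazily-bound driver wrapper and the TLLM_CU_CHECK macro declared above are typically used together; the helper function and the cubin image argument are hypothetical, and the sketch assumes cudaDriverWrapper.h pulls in the CUDA driver types:

#include <cuda.h>

#include "tensorrt_llm/common/cudaDriverWrapper.h"

// Hypothetical helper: loads a compiled cubin image through the dlopen-based
// wrapper and looks up one kernel inside it. Every driver call is routed through
// TLLM_CU_CHECK so a non-zero CUresult becomes a TllmException.
CUfunction loadKernelFromCubin(void const* cubinImage, char const* kernelName)
{
    auto driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
    CUmodule module{};
    CUfunction function{};
    TLLM_CU_CHECK(driver->cuModuleLoadData(&module, cubinImage));
    TLLM_CU_CHECK(driver->cuModuleGetFunction(&function, module, kernelName));
    // The module stays loaded here; a real caller would eventually release it
    // with driver->cuModuleUnload(module).
    return function;
}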
+ */ + +#pragma once + +#ifdef ENABLE_FP8 +#include +#include +#include + +#define FP8_MHA +#define FUSE_GEMM_ACT +#define FP8_GEMM_OUTPUT_QUANT_DISABLE + +#ifdef FUSE_GEMM_ACT +#define USE_QGMMA +#endif + +namespace tensorrt_llm +{ +namespace common +{ + +constexpr float FP8_E4M3_MAX = 448.0f; + +enum QuantizeMode +{ + PER_CHANNEL, + PER_TENSOR, + PER_CHANNEL_WEIGHT_PER_TENSOR_ACT, + PER_TOKEN, +}; + +// Packed Data Type +typedef struct __CUDA_ALIGN__(32) +{ + float array[8]; +} float8; + +typedef struct __CUDA_ALIGN__(16) +{ + half array[8]; +} half8; + +typedef struct __CUDA_ALIGN__(8) +{ + half2 array[2]; +} half2_2; + +typedef struct __CUDA_ALIGN__(8) +{ + half array[4]; +} half_4; + +#ifdef ENABLE_BF16 +typedef struct __CUDA_ALIGN__(4) +{ + __nv_bfloat16 array[2]; +} __nv_bfloat16_2; + +typedef struct __CUDA_ALIGN__(8) +{ + __nv_bfloat162 x, y; +} __nv_bfloat162_2_xy; + +typedef struct __CUDA_ALIGN__(8) +{ + __nv_bfloat16 array[4]; +} __nv_bfloat164; + +typedef struct __CUDA_ALIGN__(8) +{ + __nv_bfloat162 array[2]; +} __nv_bfloat162_2; + +typedef struct __CUDA_ALIGN__(16) +{ + __nv_bfloat16 array[8]; +} __nv_bfloat168; + +typedef struct __CUDA_ALIGN__(16) +{ + __nv_bfloat162 array[4]; +} __nv_bfloat162_4; + +typedef struct __CUDA_ALIGN__(32) +{ + __nv_bfloat16 array[16]; +} __nv_bfloat1616; +#endif + +#ifdef ENABLE_FP8 +typedef struct __CUDA_ALIGN__(2) +{ + __nv_fp8_e4m3 array[2]; +} __nv_fp8_2_e4m3; + +typedef struct __CUDA_ALIGN__(4) +{ + __nv_fp8_e4m3 array[4]; +} __nv_fp8_4_e4m3; + +typedef struct __CUDA_ALIGN__(4) +{ + __nv_fp8x2_e4m3 array[2]; +} __nv_fp8x2_x2_e4m3; + +typedef struct __CUDA_ALIGN__(8) +{ + __nv_fp8_e4m3 array[8]; +} __nv_fp8_8_e4m3; + +typedef struct __CUDA_ALIGN__(8) +{ + __nv_fp8x2_e4m3 array[4]; +} __nv_fp8x2_x4_e4m3; + +typedef struct __CUDA_ALIGN__(16) +{ + __nv_fp8_e4m3 array[16]; +} __nv_fp8x16_e4m3; +#endif + +// only BF16 and FP8 +template +struct PackType +{ + using type = float; +}; + +#ifdef ENABLE_BF16 +template <> +struct PackType<__nv_bfloat16, 2> +{ + using type = __nv_bfloat16_2; +}; + +template <> +struct PackType<__nv_bfloat16, 4> +{ + using type = __nv_bfloat164; +}; + +template <> +struct PackType<__nv_bfloat16, 8> +{ + using type = __nv_bfloat168; +}; +#endif + +#ifdef ENABLE_FP8 +template <> +struct PackType<__nv_fp8_e4m3, 2> +{ + using type = __nv_fp8_2_e4m3; +}; + +template <> +struct PackType<__nv_fp8_e4m3, 4> +{ + using type = __nv_fp8_4_e4m3; +}; + +template <> +struct PackType<__nv_fp8_e4m3, 8> +{ + using type = __nv_fp8_8_e4m3; +}; +#endif + +__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in) +{ + const char4 tmp_val = reinterpret_cast(in)[0]; + *out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); + *out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]); +} + +__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in) +{ + const char2 tmp_val = reinterpret_cast(in)[0]; + __nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); + return out; +} + +__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in) +{ + const char4 tmp_val = reinterpret_cast(in)[0]; + *out1 = half2((float) 
reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); + *out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]); +} + +__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in) +{ + const char2 tmp_val = reinterpret_cast(in)[0]; + half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); + return out; +} + +template +void invokeQuantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream); + +template +void invokeDequantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream); + +template +void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream); + +template +void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t k, const int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream); + +template +void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* weights, const int64_t numel, + const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream); + +} // namespace common +} // namespace tensorrt_llm +#endif // ENABLE_FP8 diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp deleted file mode 100644 index 5576fe782f..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/cudaProfilerUtils.h" -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/common/stringUtils.h" -#include -#include - -namespace -{ - -std::tuple, std::unordered_set> populateIterationIndexesImpl( - std::string const& envVarName) -{ - auto envVarVal = std::getenv(envVarName.c_str()); - auto envVarValStr = std::string{envVarVal != nullptr ? 
envVarVal : ""}; - auto values = tensorrt_llm::common::str2set(envVarValStr, ','); - std::unordered_set startSet; - std::unordered_set endSet; - for (std::string const& value : values) - { - size_t dashIdx = value.find("-"); - if (dashIdx != std::string::npos) - { - int32_t start = std::stoi(value.substr(0, dashIdx)); - startSet.insert(start); - int32_t end = std::stoi(value.substr(dashIdx + 1)); - endSet.insert(end); - } - else - { - int32_t start_end = std::stoi(value); - startSet.insert(start_end); - endSet.insert(start_end); - } - } - - return std::make_pair(startSet, endSet); -} - -} // namespace - -namespace tensorrt_llm::common -{ - -std::pair, std::unordered_set> populateIterationIndexes( - std::string const& envVarName, std::optional const& legacyEnvVarName) -{ - auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName); - - // If empty, try to use legacy env var name - if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty()) - { - std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value()); - - if (!profileIterIdxs.empty() || !stopIterIdxs.empty()) - { - TLLM_LOG_WARNING( - "Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. " - "Please " - "use %s " - "instead.", - legacyEnvVarName.value().c_str(), envVarName.c_str()); - } - } - - return std::make_pair(profileIterIdxs, stopIterIdxs); -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h new file mode 100644 index 0000000000..13ee3367e9 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h @@ -0,0 +1,641 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "tensorrt_llm/common/cudaBf16Wrapper.h" +#include "tensorrt_llm/common/cudaDriverWrapper.h" +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/tllmException.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef _WIN32 // Linux +#include +#endif // not WIN32 +#include +#ifdef _WIN32 // Windows +#include +#undef ERROR // A Windows header file defines ERROR as 0, but it's used in our logger.h enum. Logging breaks without + // this undef. 
+#endif // WIN32 + +namespace tensorrt_llm::common +{ + +// workspace for cublas gemm : 32MB +#define CUBLAS_WORKSPACE_SIZE 33554432 + +typedef struct __align__(4) +{ + half x, y, z, w; +} + +half4; + +/* **************************** type definition ***************************** */ + +enum CublasDataType +{ + FLOAT_DATATYPE = 0, + HALF_DATATYPE = 1, + BFLOAT16_DATATYPE = 2, + INT8_DATATYPE = 3, + FP8_DATATYPE = 4 +}; + +enum TRTLLMCudaDataType +{ + FP32 = 0, + FP16 = 1, + BF16 = 2, + INT8 = 3, + FP8 = 4 +}; + +enum class OperationType +{ + FP32, + FP16, + BF16, + INT8, + FP8 +}; + +/* **************************** debug tools ********************************* */ +static char const* _cudaGetErrorEnum(cudaError_t error) +{ + return cudaGetErrorString(error); +} + +static char const* _cudaGetErrorEnum(cublasStatus_t error) +{ + switch (error) + { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return ""; +} + +template +void check(T result, char const* const func, char const* const file, int const line) +{ + if (result) + { + throw TllmException( + file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(result))); + } +} + +template +void checkEx(T result, std::initializer_list const& validReturns, char const* const func, char const* const file, + int const line) +{ + if (std::all_of(std::begin(validReturns), std::end(validReturns), [&result](T const& t) { return t != result; })) + { + throw TllmException( + file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(result))); + } +} + +#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) +#define check_cuda_error_2(val, file, line) check((val), #val, file, line) + +inline std::optional isCudaLaunchBlocking() +{ + static bool firstCall = true; + static std::optional result = std::nullopt; + + if (firstCall) + { + char const* env = std::getenv("CUDA_LAUNCH_BLOCKING"); + if (env != nullptr && std::string(env) == "1") + { + result = true; + } + else if (env != nullptr && std::string(env) == "0") + { + result = false; + } + firstCall = false; + } + + return result; +} + +inline bool doCheckError() +{ + auto const cudaLaunchBlocking = isCudaLaunchBlocking(); +#ifndef NDEBUG + bool const checkError = cudaLaunchBlocking.value_or(true); +#else + bool const checkError = cudaLaunchBlocking.value_or(false); +#endif + + return checkError; +} + +inline void syncAndCheck(char const* const file, int const line) +{ + if (doCheckError()) + { + cudaDeviceSynchronize(); + check(cudaGetLastError(), "cudaGetLastError", file, line); + } +} + +#define sync_check_cuda_error() tensorrt_llm::common::syncAndCheck(__FILE__, __LINE__) + +#define PRINT_FUNC_NAME_() \ + do \ + { \ + std::cout << "[TensorRT-LLM][CALL] " << __FUNCTION__ << " " << std::endl; \ + } while (0) + +// 
clang-format off +template struct packed_type; +template <> struct packed_type { using type = float; }; // we don't need to pack float by default +template <> struct packed_type { using type = half2; }; + +#ifdef ENABLE_BF16 +template<> +struct packed_type<__nv_bfloat16> { + using type = __nv_bfloat162; +}; +#endif + +#ifdef ENABLE_FP8 +template<> +struct packed_type<__nv_fp8_e4m3> { + using type = __nv_fp8x2_e4m3; +}; +#endif + +template struct num_elems; +template <> struct num_elems { static constexpr int value = 1; }; +template <> struct num_elems { static constexpr int value = 2; }; +template <> struct num_elems { static constexpr int value = 4; }; +template <> struct num_elems { static constexpr int value = 1; }; +template <> struct num_elems { static constexpr int value = 2; }; +#ifdef ENABLE_BF16 +template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; }; +template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; }; +#endif +#ifdef ENABLE_FP8 +template <> struct num_elems<__nv_fp8_e4m3> { static constexpr int value = 1; }; +template <> struct num_elems<__nv_fp8x2_e4m3> { static constexpr int value = 2; }; +#endif + +template struct packed_as; +template struct packed_as { using type = T; }; +template<> struct packed_as { using type = half2; }; +template<> struct packed_as { using type = float2; }; +template<> struct packed_as { using type = int16_t; }; +template<> struct packed_as { using type = int2; }; +template<> struct packed_as { using type = half; }; +template<> struct packed_as { using type = float; }; +#ifdef ENABLE_BF16 +template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; }; +template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; }; +#endif +#ifdef ENABLE_FP8 +template<> struct packed_as<__nv_fp8_e4m3, 2> { using type = __nv_fp8x2_e4m3; }; +template<> struct packed_as<__nv_fp8x2_e4m3, 1> { using type = __nv_fp8_e4m3; }; +template<> struct packed_as<__nv_fp8_e5m2, 2> { using type = __nv_fp8x2_e5m2; }; +template<> struct packed_as<__nv_fp8x2_e5m2, 1> { using type = __nv_fp8_e5m2; }; +#endif + +inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } +inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); } +inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); } + +inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } +inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); } +inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); } + +// clang-format on + +template +struct CudaDataType +{ +}; + +template <> +struct CudaDataType +{ + static constexpr cudaDataType_t value = cudaDataType::CUDA_R_32F; +}; + +template <> +struct CudaDataType +{ + static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16F; +}; + +#ifdef ENABLE_BF16 +template <> +struct CudaDataType<__nv_bfloat16> +{ + static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16BF; +}; +#endif + +inline int getSMVersion() +{ + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); + check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline int 
getDevice() +{ + int current_dev_id = 0; + check_cuda_error(cudaGetDevice(¤t_dev_id)); + return current_dev_id; +} + +inline int getDeviceCount() +{ + int count = 0; + check_cuda_error(cudaGetDeviceCount(&count)); + return count; +} + +/// @brief Identifies the memory type of the given pointer. +template +cudaMemoryType getPtrCudaMemoryType(T* ptr) +{ + cudaPointerAttributes attributes{}; + check_cuda_error(cudaPointerGetAttributes(&attributes, ptr)); + return attributes.type; +} + +/// Get the memory info +/// \return The free and total amount of memory in bytes +inline std::tuple getDeviceMemoryInfo(bool const useUvm) +{ + if (useUvm) + { + size_t freeSysMem = 0; + size_t totalSysMem = 0; +#ifndef _WIN32 // Linux + struct sysinfo info + { + }; + + sysinfo(&info); + totalSysMem = info.totalram * info.mem_unit; + freeSysMem = info.freeram * info.mem_unit; +#else // Windows + MEMORYSTATUSEX memInfo; + memInfo.dwLength = sizeof(memInfo); + GlobalMemoryStatusEx(&memInfo); + totalSysMem = memInfo.ullTotalPhys; + freeSysMem = memInfo.ullAvailPhys; +#endif // WIN32 + + TLLM_LOG_INFO("Using UVM based system memory for KV cache, total memory %0.2f GB, available memory %0.2f GB", + ((double) totalSysMem / 1e9), ((double) freeSysMem / 1e9)); + return {freeSysMem, totalSysMem}; + } + + size_t free = 0; + size_t total = 0; + check_cuda_error(cudaMemGetInfo(&free, &total)); + TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB", + ((double) total / 1e9), ((double) free / 1e9)); + return {free, total}; +} + +/// @brief Gets the memory allocation granularity for the current device. +/// +/// @return size_t The size of the smallest difference in memory size supported by the current device. +inline size_t getAllocationGranularity() +{ + auto const currentDevice = getDevice(); + ::CUmemAllocationProp prop = {}; + + prop.type = ::CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = ::CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = currentDevice; + prop.requestedHandleTypes = ::CU_MEM_HANDLE_TYPE_NONE; + + // Get the minimum granularity supported for allocation with cuMemCreate() + size_t granularity = 0; + TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + return granularity; +} + +inline int getMultiProcessorCount() +{ + int device_id = 0; + int multi_processor_count = 0; + check_cuda_error(cudaGetDevice(&device_id)); + check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id)); + return multi_processor_count; +} + +inline int getMaxSharedMemoryPerBlockOptin() +{ + int device_id = 0; + int max_shared_memory_per_block = 0; + check_cuda_error(cudaGetDevice(&device_id)); + check_cuda_error( + cudaDeviceGetAttribute(&max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id)); + return max_shared_memory_per_block; +} + +template +inline size_t divUp(const T1& a, const T2& n) +{ + auto const tmp_a = static_cast(a); + auto const tmp_n = static_cast(n); + return (tmp_a + tmp_n - 1) / tmp_n; +} + +inline int roundUp(int a, int n) +{ + return divUp(a, n) * n; +} + +template ::value>, + typename = std::enable_if_t::value>> +auto constexpr ceilDiv(T numerator, U denominator) +{ + return (numerator + denominator - 1) / denominator; +} + +template +void printAbsMean(T const* buf, uint64_t size, cudaStream_t stream, std::string name = "") +{ + if (buf == nullptr) + { + TLLM_LOG_WARNING("%s is an nullptr, skip!", name.c_str()); + return; + } + 
cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + T* h_tmp = new T[size]; + cudaMemcpyAsync(h_tmp, buf, sizeof(T) * size, cudaMemcpyDeviceToHost, stream); + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + double sum = 0.0f; + uint64_t zero_count = 0; + float max_val = -1e10; + bool find_inf = false; + for (uint64_t i = 0; i < size; i++) + { + if (std::isinf((float) (h_tmp[i]))) + { + find_inf = true; + continue; + } + sum += abs((double) h_tmp[i]); + if ((float) h_tmp[i] == 0.0f) + { + zero_count++; + } + max_val = max_val > abs(float(h_tmp[i])) ? max_val : abs(float(h_tmp[i])); + } + TLLM_LOG_INFO("%20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s", name.c_str(), size, sum / size, + sum, max_val, find_inf ? "true" : "false"); + delete[] h_tmp; + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +} + +template +void printToStream(T const* result, int const size, FILE* strm) +{ + bool const split_rows = (strm == stdout); + if (result == nullptr) + { + TLLM_LOG_WARNING("It is an nullptr, skip! \n"); + return; + } + T* tmp = reinterpret_cast(malloc(sizeof(T) * size)); + check_cuda_error(cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost)); + for (int i = 0; i < size; ++i) + { + fprintf(strm, "%f, ", static_cast(tmp[i])); + if (split_rows && ((i + 1) % 10) == 0) + fprintf(strm, "\n"); + } + if (!split_rows || (size % 10) != 0) + { + fprintf(strm, "\n"); + } + free(tmp); +} + +template +void printToScreen(T const* result, int const size) +{ + printToStream(result, size, stdout); +} + +template +void print2dToStream(T const* result, int const r, int const c, int const stride, FILE* strm) +{ + if (result == nullptr) + { + TLLM_LOG_WARNING("It is an nullptr, skip! \n"); + return; + } + for (int ri = 0; ri < r; ++ri) + { + T const* ptr = result + ri * stride; + printToStream(ptr, c, strm); + } + fprintf(strm, "\n"); +} + +template +void print2dToScreen(T const* result, int const r, int const c, int const stride) +{ + print2dToStream(result, r, c, stride, stdout); +} + +template +void print2dToFile(std::string fname, T const* result, int const r, int const c, int const stride) +{ + FILE* fp = fopen(fname.c_str(), "wt"); + if (fp != nullptr) + { + print2dToStream(result, r, c, stride, fp); + fclose(fp); + } +} + +inline void print_float_(float x) +{ + printf("%7.3f ", x); +} + +inline void print_element_(float x) +{ + print_float_(x); +} + +inline void print_element_(half x) +{ + print_float_((float) x); +} + +#ifdef ENABLE_BF16 +inline void print_element_(__nv_bfloat16 x) +{ + print_float_((float) x); +} +#endif + +#ifdef ENABLE_FP8 +inline void print_element_(__nv_fp8_e4m3 x) +{ + print_float_((float) x); +} +#endif + +inline void print_element_(uint32_t ul) +{ + printf("%7" PRIu32, ul); +} + +inline void print_element_(uint64_t ull) +{ + printf("%7" PRIu64, ull); +} + +inline void print_element_(int32_t il) +{ + printf("%7" PRId32, il); +} + +inline void print_element_(int64_t ill) +{ + printf("%7" PRId64, ill); +} + +template +inline void printMatrix(T const* ptr, int m, int k, int stride, bool is_device_ptr) +{ + T* tmp; + if (is_device_ptr) + { + // k < stride ; stride = col-dimension. 
+ tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } + else + { + tmp = const_cast(ptr); + } + + for (int ii = -1; ii < m; ++ii) + { + if (ii >= 0) + { + printf("%07d ", ii); + } + else + { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) + { + if (ii >= 0) + { + print_element_(tmp[ii * stride + jj]); + } + else + { + printf("%7d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) + { + free(tmp); + } +} + +template void printMatrix(float const* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(half const* ptr, int m, int k, int stride, bool is_device_ptr); +#ifdef ENABLE_BF16 +template void printMatrix(__nv_bfloat16 const* ptr, int m, int k, int stride, bool is_device_ptr); +#endif +#ifdef ENABLE_FP8 +template void printMatrix(__nv_fp8_e4m3 const* ptr, int m, int k, int stride, bool is_device_ptr); +#endif +template void printMatrix(uint32_t const* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(uint64_t const* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(int const* ptr, int m, int k, int stride, bool is_device_ptr); + +} // namespace tensorrt_llm::common + +/* + * Macros compliant with TensorRT coding conventions + */ +#define TLLM_CUDA_CHECK(stat) \ + do \ + { \ + tensorrt_llm::common::check((stat), #stat, __FILE__, __LINE__); \ + } while (0) + +// We use singleton memory pool and the order of destructors depends on the compiler implementation. We find that the +// cudaFree/cudaFreeHost is called after cudaruntime destruction on Windows. There will be an cudaErrorCudartUnloading +// error. However, it is safe to ignore this error because the cuda runtime is already exited, we are no more worried +// about the memory leaks. +#define TLLM_CUDA_CHECK_FREE_RESOURCE(stat) \ + do \ + { \ + tensorrt_llm::common::checkEx((stat), {cudaSuccess, cudaErrorCudartUnloading}, #stat, __FILE__, __LINE__); \ + } while (0) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h deleted file mode 100644 index d7bf43b407..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
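For illustration only, a minimal sketch (not part of this patch) of the intended split between TLLM_CUDA_CHECK and TLLM_CUDA_CHECK_FREE_RESOURCE described in the comment above; the RAII buffer type is hypothetical:

#include <cuda_runtime.h>

#include "tensorrt_llm/common/cudaUtils.h"

// Hypothetical RAII scratch buffer: allocation failures throw, while the free in
// the destructor tolerates cudaErrorCudartUnloading, matching the rationale in
// the comment above about destructors that may run after CUDA runtime shutdown.
struct DeviceScratch
{
    void* ptr{nullptr};

    explicit DeviceScratch(size_t bytes)
    {
        TLLM_CUDA_CHECK(cudaMalloc(&ptr, bytes));
    }

    ~DeviceScratch()
    {
        TLLM_CUDA_CHECK_FREE_RESOURCE(cudaFree(ptr));
    }
};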
- */ - -#pragma once - -#include - -namespace tensorrt_llm::utils::customAllReduceUtils -{ - -constexpr size_t NUM_POINTERS_PER_RANK = 7; - -// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py -inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept -{ - if (worldSize <= 2) - { - return 16 * 1000 * 1000; - } - return 8 * 1000 * 1000; -} - -} // namespace tensorrt_llm::utils::customAllReduceUtils diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp deleted file mode 100644 index 64d3d44acb..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp +++ /dev/null @@ -1,214 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "envUtils.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/common/logger.h" -#include - -namespace tensorrt_llm::common -{ - -std::optional getIntEnv(char const* name) -{ - char const* const env = std::getenv(name); - if (env == nullptr) - { - return std::nullopt; - } - int32_t const val = std::stoi(env); - if (val <= 0) - { - return std::nullopt; - } - return {val}; -}; - -// Returns true if the env variable exists and is set to "1" -static bool getBoolEnv(char const* name) -{ - char const* env = std::getenv(name); - return env && env[0] == '1' && env[1] == '\0'; -} - -// XQA kernels (optimized kernels for generation phase). -bool forceXQAKernels() -{ - static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0); - return forceXQA; -} - -std::optional getEnvEnableXQAJIT() -{ - static bool init = false; - static bool exists = false; - static bool enableXQAJIT = false; - if (!init) - { - init = true; - char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT"); - if (enable_xqa_jit_var) - { - exists = true; - if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0') - { - enableXQAJIT = true; - } - } - } - if (exists) - { - return enableXQAJIT; - } - else - { - return std::nullopt; - } -} - -// Tune the number of blocks per sequence for accuracy/performance purpose. 
-bool getEnvMmhaMultiblockDebug() -{ - static bool init = false; - static bool forceMmhaMaxSeqLenTile = false; - if (!init) - { - init = true; - char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG"); - if (enable_mmha_debug_var) - { - if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0') - { - forceMmhaMaxSeqLenTile = true; - } - } - } - return forceMmhaMaxSeqLenTile; -} - -int getEnvMmhaBlocksPerSequence() -{ - static bool init = false; - static int mmhaBlocksPerSequence = 0; - if (!init) - { - init = true; - char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE"); - if (mmhaBlocksPerSequenceEnv) - { - mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv); - if (mmhaBlocksPerSequence <= 0) - { - TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!"); - } - } - } - return mmhaBlocksPerSequence; -} - -int getEnvMmhaKernelBlockSize() -{ - static bool init = false; - static int mmhaKernelBlockSize = 0; - if (!init) - { - init = true; - char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE"); - if (mmhaKernelBlockSizeEnv) - { - mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv); - if (mmhaKernelBlockSize <= 0) - { - TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!"); - } - } - } - return mmhaKernelBlockSize; -} - -bool getEnvEnablePDL() -{ - static bool init = false; - static bool enablePDL = false; - if (!init) - { - init = true; - // PDL only available when arch >= 90 - if (getSMVersion() >= 90) - { - // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1` - enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL"); - } - } - return enablePDL; -} - -bool getEnvUseUCXKvCache() -{ - static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE"); - return useUCXKVCache; -} - -std::string getEnvUCXInterface() -{ - static bool init = false; - static std::string ucxInterface; - if (!init) - { - init = true; - { - char const* ucx_interface = std::getenv("TRTLLM_UCX_INTERFACE"); - if (ucx_interface) - { - ucxInterface = ucx_interface; - } - } - } - return ucxInterface; -} - -bool getEnvDisaggLayerwise() -{ - static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE"); - return disaggLayerwise; -} - -bool getEnvParallelCacheSend() -{ - static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND"); - return parallelCacheSend; -} - -bool getEnvRequestKVCacheSerial() -{ - static bool const requestKVCacheSerial = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_SERIAL"); - return requestKVCacheSerial; -} - -bool getEnvDisableKVCacheTransferOverlap() -{ - static bool const disableKVCacheTransferOverlap = getBoolEnv("TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP"); - return disableKVCacheTransferOverlap; -} - -bool getEnvDisableReceiveKVCacheParallel() -{ - static bool const disableReceiveParallel = getBoolEnv("TRTLLM_DISABLE_KVCACHE_RECEIVE_PARALLEL"); - return disableReceiveParallel; -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h deleted file mode 100644 index 027c7cfbb3..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h +++ /dev/null @@ -1,60 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include - -namespace tensorrt_llm::common -{ -// Useful when you want to inject some debug code controllable with env var. -std::optional getIntEnv(char const* name); - -// XQA kernels (optimized kernels for generation phase). -bool forceXQAKernels(); - -// Whether XQA JIT is enabled. -// -// Returns the value of TRTLLM_ENABLE_XQA_JIT env var. If such env var doesn't exist, std::nullopt is returned. -std::optional getEnvEnableXQAJIT(); - -// Tune the number of blocks per sequence for accuracy/performance purpose. -bool getEnvMmhaMultiblockDebug(); - -int getEnvMmhaBlocksPerSequence(); - -int getEnvMmhaKernelBlockSize(); - -// Whether PDL is enabled. -bool getEnvEnablePDL(); - -bool getEnvUseUCXKvCache(); - -std::string getEnvUCXInterface(); - -bool getEnvDisaggLayerwise(); - -bool getEnvParallelCacheSend(); - -bool getEnvRequestKVCacheSerial(); - -bool getEnvDisableKVCacheTransferOverlap(); - -bool getEnvDisableReceiveKVCacheParallel(); - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h new file mode 100644 index 0000000000..df84e22638 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/stringUtils.h" + +namespace tensorrt_llm::common +{ + +class Logger +{ + +// On Windows, the file wingdi.h is included which has +// #define ERROR 0 +// This breaks everywhere ERROR is used in the Level enum +#ifdef _WIN32 +#undef ERROR +#endif // _WIN32 + +public: + enum Level + { + TRACE = 0, + DEBUG = 10, + INFO = 20, + WARNING = 30, + ERROR = 40 + }; + + static Logger* getLogger(); + + Logger(Logger const&) = delete; + void operator=(Logger const&) = delete; + +#if defined(_MSC_VER) + template + void log(Level level, char const* format, Args const&... args); + + template + void log(Level level, int rank, char const* format, Args const&... args); +#else + template + void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0))); + + template + void log(Level level, int rank, char const* format, Args const&... 
args) __attribute__((format(printf, 4, 0))); +#endif + + template + void log(Level level, std::string const& format, Args const&... args) + { + return log(level, format.c_str(), args...); + } + + template + void log(Level const level, int const rank, std::string const& format, Args const&... args) + { + return log(level, rank, format.c_str(), args...); + } + + void log(std::exception const& ex, Level level = Level::ERROR); + + Level getLevel() const + { + return level_; + } + + void setLevel(Level const level) + { + level_ = level; + log(INFO, "Set logger level to %s", getLevelName(level)); + } + + bool isEnabled(Level const level) const + { + return level_ <= level; + } + +private: + static auto constexpr kPREFIX = "[TensorRT-LLM]"; + +#ifndef NDEBUG + Level const DEFAULT_LOG_LEVEL = DEBUG; +#else + Level const DEFAULT_LOG_LEVEL = INFO; +#endif + Level level_ = DEFAULT_LOG_LEVEL; + + Logger(); // NOLINT(modernize-use-equals-delete) + + static inline char const* getLevelName(Level const level) + { + switch (level) + { + case TRACE: return "TRACE"; + case DEBUG: return "DEBUG"; + case INFO: return "INFO"; + case WARNING: return "WARNING"; + case ERROR: return "ERROR"; + } + + TLLM_THROW("Unknown log level: %d", level); + } + + static inline std::string getPrefix(Level const level) + { + return fmtstr("%s[%s] ", kPREFIX, getLevelName(level)); + } + + static inline std::string getPrefix(Level const level, int const rank) + { + return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank); + } +}; + +template +void Logger::log(Logger::Level level, char const* format, Args const&... args) +{ + if (isEnabled(level)) + { + auto const fmt = getPrefix(level) + format; + auto& out = level_ < WARNING ? std::cout : std::cerr; + if constexpr (sizeof...(args) > 0) + { + out << fmtstr(fmt.c_str(), args...); + } + else + { + out << fmt; + } + out << std::endl; + } +} + +template +void Logger::log(Logger::Level const level, int const rank, char const* format, Args const&... args) +{ + if (isEnabled(level)) + { + auto const fmt = getPrefix(level, rank) + format; + auto& out = level_ < WARNING ? std::cout : std::cerr; + if constexpr (sizeof...(args) > 0) + { + out << fmtstr(fmt.c_str(), args...); + } + else + { + out << fmt; + } + out << std::endl; + } +} + +#define TLLM_LOG(level, ...) \ + do \ + { \ + auto* const logger = tensorrt_llm::common::Logger::getLogger(); \ + if (logger->isEnabled(level)) \ + { \ + logger->log(level, __VA_ARGS__); \ + } \ + } while (0) + +#define TLLM_LOG_TRACE(...) TLLM_LOG(tensorrt_llm::common::Logger::TRACE, __VA_ARGS__) +#define TLLM_LOG_DEBUG(...) TLLM_LOG(tensorrt_llm::common::Logger::DEBUG, __VA_ARGS__) +#define TLLM_LOG_INFO(...) TLLM_LOG(tensorrt_llm::common::Logger::INFO, __VA_ARGS__) +#define TLLM_LOG_WARNING(...) TLLM_LOG(tensorrt_llm::common::Logger::WARNING, __VA_ARGS__) +#define TLLM_LOG_ERROR(...) TLLM_LOG(tensorrt_llm::common::Logger::ERROR, __VA_ARGS__) +#define TLLM_LOG_EXCEPTION(ex, ...) tensorrt_llm::common::Logger::getLogger()->log(ex, ##__VA_ARGS__) +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h deleted file mode 100644 index 1bad3a2c15..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
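The logger.h added above gates every call on a level threshold before any formatting happens and routes WARNING and ERROR output to stderr while lower levels go to stdout. A rough standalone sketch of that level-gated idea; MiniLogger is a stand-in for illustration, not the class introduced by this patch.

#include <cstdarg>
#include <cstdio>
#include <iostream>

enum class Level { TRACE = 0, DEBUG = 10, INFO = 20, WARNING = 30, ERROR = 40 };

// Minimal level-gated logger: format only when the message passes the
// threshold, and send WARNING/ERROR to stderr, everything else to stdout.
class MiniLogger
{
public:
    explicit MiniLogger(Level threshold) : threshold_(threshold) {}

    void log(Level level, char const* fmt, ...)
    {
        if (level < threshold_)  // cheap early-out, mirrors Logger::isEnabled()
        {
            return;
        }
        char buf[1024];
        va_list args;
        va_start(args, fmt);
        std::vsnprintf(buf, sizeof(buf), fmt, args);
        va_end(args);
        auto& out = level < Level::WARNING ? std::cout : std::cerr;
        out << "[example][" << name(level) << "] " << buf << std::endl;
    }

private:
    static char const* name(Level level)
    {
        switch (level)
        {
        case Level::TRACE: return "TRACE";
        case Level::DEBUG: return "DEBUG";
        case Level::INFO: return "INFO";
        case Level::WARNING: return "WARNING";
        case Level::ERROR: return "ERROR";
        }
        return "UNKNOWN";
    }

    Level threshold_;
};

int main()
{
    MiniLogger logger(Level::INFO);
    logger.log(Level::DEBUG, "dropped, below threshold");
    logger.log(Level::WARNING, "workspace size %zu exceeds budget", static_cast<size_t>(1024));
    return 0;
}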
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace tensorrt_llm -{ -namespace common -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ __host__ T divUp(T m, T n) -{ - return (m + n - 1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu deleted file mode 100644 index d13217b203..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu +++ /dev/null @@ -1,906 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaTypeUtils.cuh" -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/common/memoryUtils.h" - -#include -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ - -template -void deviceMalloc(T** ptr, size_t size, bool is_random_initialize) -{ - check_cuda_error(cudaMalloc((void**) (ptr), sizeof(T) * size)); - if (is_random_initialize) - { - cudaRandomUniform(*ptr, size); - } -} - -template void deviceMalloc(float** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(half** ptr, size_t size, bool is_random_initialize); -#ifdef ENABLE_BF16 -template void deviceMalloc(__nv_bfloat16** ptr, size_t size, bool is_random_initialize); -#endif -template void deviceMalloc(uint16_t** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(int** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(bool** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(char** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(int8_t** ptr, size_t size, bool is_random_initialize); -#ifdef ENABLE_FP8 -template void deviceMalloc(__nv_fp8_e4m3** ptr, size_t size, bool is_random_initialize); -#endif - -template -void deviceMemSetZero(T* ptr, size_t size) -{ - check_cuda_error(cudaMemset(static_cast(ptr), 0, sizeof(T) * size)); -} - -template void deviceMemSetZero(float* ptr, size_t size); -template void deviceMemSetZero(half* ptr, size_t size); -template void deviceMemSetZero(int* ptr, size_t size); -template void deviceMemSetZero(uint32_t* ptr, size_t size); -template void deviceMemSetZero(bool* ptr, size_t size); -#ifdef ENABLE_FP8 -template void 
deviceMemSetZero(__nv_fp8_e4m3* ptr, size_t size); -#endif -#ifdef ENABLE_BF16 -template void deviceMemSetZero(__nv_bfloat16* ptr, size_t size); -#endif - -template -void deviceFree(T*& ptr) -{ - if (ptr != NULL) - { - check_cuda_error(cudaFree(ptr)); - ptr = NULL; - } -} - -template void deviceFree(float*& ptr); -template void deviceFree(half*& ptr); -#ifdef ENABLE_BF16 -template void deviceFree(__nv_bfloat16*& ptr); -#endif -template void deviceFree(unsigned short*& ptr); -template void deviceFree(int*& ptr); -template void deviceFree(bool*& ptr); -template void deviceFree(char*& ptr); -template void deviceFree(int8_t*& ptr); -#ifdef ENABLE_FP8 -template void deviceFree(__nv_fp8_e4m3*& ptr); -#endif - -template -void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream) -{ - T* arr = new T[size]; - std::fill(arr, arr + size, value); - check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream)); - delete[] arr; -} - -template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream); -template void deviceFill(half* devptr, size_t size, half value, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void deviceFill(__nv_bfloat16* devptr, size_t size, __nv_bfloat16 value, cudaStream_t stream); -#endif -template void deviceFill(int* devptr, size_t size, int value, cudaStream_t stream); -template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream); - -template -void cudaD2Hcpy(T* tgt, T const* src, const size_t size) -{ - check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost)); -} - -template void cudaD2Hcpy(float* tgt, float const* src, size_t size); -template void cudaD2Hcpy(half* tgt, half const* src, size_t size); -#ifdef ENABLE_BF16 -template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); -#endif -template void cudaD2Hcpy(int* tgt, int const* src, size_t size); -template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size); -#ifdef ENABLE_FP8 -template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); -#endif -template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); -template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size); -template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size); - -template -void cudaH2Dcpy(T* tgt, T const* src, const size_t size) -{ - check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice)); -} - -template void cudaH2Dcpy(float* tgt, float const* src, size_t size); -template void cudaH2Dcpy(half* tgt, half const* src, size_t size); -#ifdef ENABLE_BF16 -template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); -#endif -template void cudaH2Dcpy(int* tgt, int const* src, size_t size); -template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size); -#ifdef ENABLE_FP8 -template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); -#endif -template void cudaH2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); -template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size); -template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size); - -template -void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) -{ - check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream)); -} - -template void cudaD2Dcpy(float* tgt, float const* src, size_t size, 
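memoryUtils.cu keeps the template bodies inside the .cu translation unit and emits only the explicitly instantiated dtypes, which is why every helper above is followed by a block of per-type instantiation lines. A hedged single-file sketch of that idiom, using a host-side fill in place of the CUDA helpers:

#include <cstddef>
#include <iostream>
#include <vector>

// In the real code the template body lives in the .cu file and headers only
// declare the function; this sketch keeps everything in one file.
template <typename T>
void hostFill(T* dst, std::size_t n, T value)
{
    for (std::size_t i = 0; i < n; ++i)
    {
        dst[i] = value;
    }
}

// Explicit instantiations: only these types get object code, mirroring the
// per-dtype instantiation lists in the deleted file.
template void hostFill<float>(float*, std::size_t, float);
template void hostFill<int>(int*, std::size_t, int);

int main()
{
    std::vector<float> buf(4);
    hostFill(buf.data(), buf.size(), 1.5f);
    std::cout << buf[0] << " " << buf[3] << "\n"; // 1.5 1.5
    return 0;
}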
cudaStream_t stream); -template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); -#endif -template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream); -template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); -template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); -#ifdef ENABLE_FP8 -template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream); -#endif -template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); - -template -__global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (T_OUT) ((float) (src[tid])); - } -} - -template -void invokeCudaCast(T_OUT* dst, T_IN const* const src, const size_t size, cudaStream_t stream) -{ - cudaCast<<<256, 256, 0, stream>>>(dst, src, size); -} - -template void invokeCudaCast(float* dst, half const* const src, const size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeCudaCast(float* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(__nv_bfloat16* dst, float const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(__nv_bfloat16* dst, half const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(half* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); -#endif -#ifdef ENABLE_FP8 -template void invokeCudaCast(float* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast( - __nv_bfloat16* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(half* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(__nv_fp8_e4m3* dst, float const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast( - __nv_fp8_e4m3* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const size_t size, cudaStream_t stream); -#endif - -template -void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) -{ - if (stream != NULL) - { - check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDefault, stream)); - } - else - { - check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDefault)); - } -} - -template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); -#endif -template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(uint32_t* tgt, 
uint32_t const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream); - -template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void cudaAutoCpy(__nv_bfloat16 const** tgt, __nv_bfloat16 const* const* src, size_t size, cudaStream_t stream); -#endif -template void cudaAutoCpy(int const** tgt, int const* const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(bool const** tgt, bool const* const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(int8_t const** tgt, int8_t const* const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy( - unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream); - -template -__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset) -{ - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - curandState_t local_state; - curand_init((unsigned long long int) 1337, idx + seq_offset, 0, &local_state); - for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) - { - buffer[index] = (T) (curand_uniform(&local_state) * 0.2f - 0.1f); - } -} - -template <> -__global__ void cuda_random_uniform_kernel(int* buffer, const size_t size, int const seq_offset) -{ - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - curandState_t local_state; - curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); - for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) - { - buffer[index] = curand(&local_state); - } -} - -template <> -__global__ void cuda_random_uniform_kernel(bool* buffer, const size_t size, int const seq_offset) -{ - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - curandState_t local_state; - curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); - for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) - { - buffer[index] = (curand(&local_state) % 2 == 0); - } -} - -template <> -__global__ void cuda_random_uniform_kernel(char* buffer, const size_t size, int const seq_offset) -{ - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - curandState_t local_state; - curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); - for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) - { - buffer[index] = curand(&local_state) % 0xFF; - } -} - -template -void cudaRandomUniform(T* buffer, const size_t size) -{ - static int seq_offset = 0; - cuda_random_uniform_kernel<<<256, 256>>>(buffer, size, seq_offset); - seq_offset += 256 * 256; -} - -template void cudaRandomUniform(float* buffer, const size_t size); -template void cudaRandomUniform(half* buffer, const size_t size); -#ifdef ENABLE_BF16 -template void cudaRandomUniform(__nv_bfloat16* buffer, const size_t size); -#endif -template void cudaRandomUniform(int* buffer, const size_t size); -template void cudaRandomUniform(bool* buffer, const size_t size); -template void cudaRandomUniform(char* buffer, const size_t size); -#ifdef ENABLE_FP8 -template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t 
size); -#endif - -// loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or -// the product of the elements in shape is 0, this function will return an empty vector. -template -std::vector loadWeightFromBinHelper(std::vector shape, std::string filename) -{ - if (shape.size() > 2) - { - printf("[ERROR] shape should have less than two dims \n"); - return std::vector(); - } - size_t dim0 = shape[0], dim1 = 1; - if (shape.size() == 2) - { - dim1 = shape[1]; - } - size_t size = dim0 * dim1; - if (size == 0) - { - TLLM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); - return std::vector(); - } - - std::vector host_array(size); - std::ifstream in(filename, std::ios::in | std::ios::binary); - if (!in.is_open()) - { - TLLM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str()); - return std::vector(); - } - - size_t loaded_data_size = sizeof(T) * size; - in.seekg(0, in.end); - in.seekg(0, in.beg); - - TLLM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename); - in.read((char*) host_array.data(), loaded_data_size); - - size_t in_get_size = in.gcount(); - if (in_get_size != loaded_data_size) - { - TLLM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n", filename.c_str(), - in_get_size, loaded_data_size); - return std::vector(); - } - in.close(); - // If we succeed, return an array with values. - return host_array; -} - -template -int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) -{ - std::vector host_array = loadWeightFromBinHelper(shape, filename); - - if (host_array.empty()) - { - return 0; - } - - if (std::is_same::value == true) - { - cudaH2Dcpy(ptr, (T*) host_array.data(), host_array.size()); - } - else - { - T_IN* ptr_2 = nullptr; - deviceMalloc(&ptr_2, host_array.size(), false); - cudaH2Dcpy(ptr_2, host_array.data(), host_array.size()); - invokeCudaD2DcpyConvert(ptr, ptr_2, host_array.size()); - deviceFree(ptr_2); - } - return 0; -} - -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(int8_t* ptr, std::vector shape, std::string filename); -#ifdef ENABLE_BF16 -template int loadWeightFromBinFunc<__nv_bfloat16, float>( - __nv_bfloat16* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc<__nv_bfloat16, half>( - __nv_bfloat16* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>( - __nv_bfloat16* ptr, std::vector shape, std::string filename); -#endif // ENABLE_BF16 -template int loadWeightFromBinFunc(int* ptr, std::vector shape, std::string filename); -#ifdef ENABLE_FP8 -template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>( - __nv_fp8_e4m3* ptr, std::vector shape, std::string filename); -#endif // ENABLE_FP8 - -template -int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type) -{ - switch (model_file_type) - { - case TRTLLMCudaDataType::FP32: 
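loadWeightFromBinHelper above validates the requested shape, reads the raw bytes, and treats a short read (a gcount() mismatch) as failure by returning an empty vector. A hedged host-only sketch of that load-and-verify flow; readBinary and the weights.bin path are placeholders, not files or APIs from this patch.

#include <cstddef>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Read exactly `count` elements of T from a raw binary file; return an empty
// vector if the file is missing or shorter than requested, mirroring the
// gcount() check in loadWeightFromBinHelper.
template <typename T>
std::vector<T> readBinary(std::string const& path, std::size_t count)
{
    std::vector<T> host(count);
    std::ifstream in(path, std::ios::in | std::ios::binary);
    if (!in.is_open())
    {
        std::cerr << "cannot open " << path << "\n";
        return {};
    }
    std::size_t const want = sizeof(T) * count;
    in.read(reinterpret_cast<char*>(host.data()), static_cast<std::streamsize>(want));
    if (static_cast<std::size_t>(in.gcount()) != want)
    {
        std::cerr << path << " holds " << in.gcount() << " bytes, expected " << want << "\n";
        return {};
    }
    return host;
}

int main()
{
    // "weights.bin" is a placeholder path, not a file shipped with the patch.
    auto w = readBinary<float>("weights.bin", 1024);
    std::cout << (w.empty() ? "load failed" : "load ok") << "\n";
    return 0;
}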
loadWeightFromBinFunc(ptr, shape, filename); break; - case TRTLLMCudaDataType::FP16: loadWeightFromBinFunc(ptr, shape, filename); break; - case TRTLLMCudaDataType::INT8: loadWeightFromBinFunc(ptr, shape, filename); break; -#ifdef ENABLE_BF16 - case TRTLLMCudaDataType::BF16: loadWeightFromBinFunc(ptr, shape, filename); break; -#endif -#ifdef ENABLE_FP8 - case TRTLLMCudaDataType::FP8: loadWeightFromBinFunc(ptr, shape, filename); break; -#endif - default: TLLM_LOG_ERROR("Does not support TRTLLMCudaDataType=%d", model_file_type); TLLM_CHECK(false); - } - return 0; -} - -template <> -int loadWeightFromBin(int* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type) -{ - loadWeightFromBinFunc(ptr, shape, filename); - return 0; -} - -template int loadWeightFromBin( - float* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); -template int loadWeightFromBin( - half* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); -template int loadWeightFromBin( - int8_t* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); -#ifdef ENABLE_BF16 -template int loadWeightFromBin( - __nv_bfloat16* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); -#endif -#ifdef ENABLE_FP8 -template int loadWeightFromBin( - __nv_fp8_e4m3* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); -#endif -template int loadWeightFromBin( - int* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); - -template -__global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = cuda_cast(src[tid]); - } -} - -template -void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cudaStream_t stream) -{ - cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size); -} - -template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream); - -#ifdef ENABLE_BF16 -template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t 
size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream); -#endif // ENABLE_BF16 - -template -__global__ void cudaD2DScaleCpyConvert( - T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size) -{ - float const scale_value = invert_scale ? 1.0f / scale[0] : scale[0]; - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = cuda_cast(cuda_cast(src[tid]) * scale_value); - } -} - -template -void invokeCudaD2DScaleCpyConvert( - T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream) -{ - cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size); -} - -// clang-format off -template void invokeCudaD2DScaleCpyConvert(float* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const float* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -template void invokeCudaD2DScaleCpyConvert(half* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const half* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeCudaD2DScaleCpyConvert(__nv_bfloat16* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const __nv_bfloat16* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -#endif // ENABLE_BF16 -#ifdef ENABLE_FP8 -template void invokeCudaD2DScaleCpyConvert(float* tgt, const __nv_fp8_e4m3* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -#endif // ENABLE_FP8 -// clang-format on - -void invokeCudaD2DcpyHalf2Float(float* dst, half* src, const size_t size, cudaStream_t stream) -{ - invokeCudaD2DcpyConvert(dst, src, size, stream); -} - -void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaStream_t stream) -{ - invokeCudaD2DcpyConvert(dst, src, size, stream); -} - -template -void saveToBinary(T const* ptr, const size_t size, std::string filename) -{ - - std::vector h_ptr(size); - cudaD2Hcpy(h_ptr.data(), ptr, size); - std::vector float_ptr(size); - for (size_t i = 0; i < size; i++) - { - float_ptr[i] = (float) h_ptr[i]; - } - - std::ofstream out(filename, std::ios::out | std::ios::binary); - TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); - - out.write((char*) float_ptr.data(), size * sizeof(float)); -} - -template void saveToBinary(float const* ptr, const size_t size, std::string filename); -template void saveToBinary(half const* ptr, const size_t size, std::string filename); -#ifdef ENABLE_BF16 -template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename); -#endif // ENABLE_BF16 - -template <> -void saveToBinary(int const* ptr, const size_t size, std::string filename) -{ - std::vector h_ptr(size); - cudaD2Hcpy(h_ptr.data(), ptr, size); - std::ofstream out(filename, std::ios::out | std::ios::binary); - TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); - out.write((char*) h_ptr.data(), size * sizeof(int)); -} - -template -__global__ void fakeCast(T_IN* input_ptr, 
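cudaD2DScaleCpyConvert above applies a single per-tensor scale during the copy, with invert_scale selecting between multiplying by the scale and dividing by it (the usual dequantize/quantize switch). A hedged host-side reference of the same arithmetic on plain vectors, not the CUDA kernel itself:

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Convert src into dst while applying one scale factor; invert_scale flips
// between "multiply by scale" (dequantize) and "divide by scale" (quantize),
// matching the flag in cudaD2DScaleCpyConvert.
template <typename TOut, typename TIn>
void scaleCopyConvert(std::vector<TOut>& dst, std::vector<TIn> const& src, float scale, bool invert_scale)
{
    float const s = invert_scale ? 1.0f / scale : scale;
    dst.resize(src.size());
    for (std::size_t i = 0; i < src.size(); ++i)
    {
        dst[i] = static_cast<TOut>(static_cast<float>(src[i]) * s);
    }
}

int main()
{
    std::vector<int32_t> quantized{12, -7, 30};
    std::vector<float> dequantized;
    scaleCopyConvert(dequantized, quantized, 0.05f, /*invert_scale=*/false); // 0.6 -0.35 1.5
    for (float v : dequantized) std::cout << v << " ";
    std::cout << "\n";
    return 0;
}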
const size_t size) -{ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) - { - T_fake_type tmp_val = (T_fake_type) ((float) input_ptr[i]); - input_ptr[i] = (T_IN) ((float) tmp_val); - } -} - -template -void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream) -{ - dim3 block(256); - dim3 grid((size + 255) / 256); - fakeCast<<>>(input_ptr, size); -} - -#ifdef ENABLE_FP8 -__global__ void cudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (float) (src[tid]); - } -} - -void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream) -{ - cudaD2Dcpyfp82Float<<<256, 256, 0, stream>>>(dst, src, size); -} - -__global__ void cudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (half) ((float) (src[tid])); - } -} - -void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream) -{ - cudaD2Dcpyfp82Half<<<256, 256, 0, stream>>>(dst, src, size); -} - -__global__ void cudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (__nv_fp8_e4m3) src[tid]; - } -} - -void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size, cudaStream_t stream) -{ - cudaD2DcpyFloat2fp8<<<256, 256, 0, stream>>>(dst, src, size); -} - -__global__ void cudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (__nv_fp8_e4m3) src[tid]; - } -} - -void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size, cudaStream_t stream) -{ - cudaD2DcpyHalf2fp8<<<256, 256, 0, stream>>>(dst, src, size); -} - -__global__ void cudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (__nv_fp8_e4m3) src[tid]; - } -} - -void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size, cudaStream_t stream) -{ - cudaD2DcpyBfloat2fp8<<<256, 256, 0, stream>>>(dst, src, size); -} - -#endif // ENABLE_FP8 - -template -__global__ void transpose(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1; tid += blockDim.x * gridDim.x) - { - const size_t src_col_id = tid % dim1; - const size_t src_row_id = tid / dim1; - dst[src_col_id * dim0 + src_row_id] = (T_OUT) (src[tid]); - } -} - -template -void invokeInPlaceTranspose(T* data, T* workspace, const size_t dim0, const size_t dim1) -{ - // copy data to workspace, and then transpose from workspace to data - cudaD2Dcpy(workspace, data, dim0 * dim1); - transpose<<<256, 256>>>(data, workspace, dim0, dim1); -} - -#ifdef ENABLE_FP8 -template void invokeInPlaceTranspose( - __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1); -#endif // ENABLE_FP8 -#ifdef ENABLE_BF16 -template void invokeInPlaceTranspose( - __nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1); -#endif // ENABLE_BF16 
-template void invokeInPlaceTranspose(float* data, float* workspace, const size_t dim0, const size_t dim1); - -template -__global__ void transpose0213( - T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3) -{ - // src permutation: [0, 1, 2, 3] - // dst permutation: [0, 2, 1, 3] - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2 * dim3; - tid += blockDim.x * gridDim.x) - { - size_t tmp_idx = tid; - const size_t dim_3_idx = tmp_idx % dim3; - tmp_idx = (tmp_idx - dim_3_idx) / dim3; - const size_t dim_2_idx = tmp_idx % dim2; - tmp_idx = (tmp_idx - dim_2_idx) / dim2; - const size_t dim_1_idx = tmp_idx % dim1; - tmp_idx = (tmp_idx - dim_1_idx) / dim1; - const size_t dim_0_idx = tmp_idx % dim0; - dst[dim_0_idx * dim1 * dim2 * dim3 + dim_2_idx * dim1 * dim3 + dim_1_idx * dim3 + dim_3_idx] = src[tid]; - } -} - -template -void invokeInPlaceTranspose0213( - T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3) -{ - // copy data to workspace, and then transpose from workspace to data - // Note that this kernel is used for pre-processing and not very efficient. - cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2 * dim3); - transpose0213<<<256, 256>>>(data, workspace, dim0, dim1, dim2, dim3); -} - -#ifdef ENABLE_FP8 -template void invokeInPlaceTranspose0213(__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, - const size_t dim1, const size_t dim2, const size_t dim3); -#endif // ENABLE_FP8 -#ifdef ENABLE_BF16 -template void invokeInPlaceTranspose0213(__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, - const size_t dim1, const size_t dim2, const size_t dim3); -#endif // ENABLE_BF16 -template void invokeInPlaceTranspose0213( - float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3); - -template -__global__ void transpose102(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2) -{ - // src permutation: [0, 1, 2] - // dst permutation: [1, 0, 2] - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2; tid += blockDim.x * gridDim.x) - { - size_t tmp_idx = tid; - const size_t dim_2_idx = tmp_idx % dim2; - tmp_idx = (tmp_idx - dim_2_idx) / dim2; - const size_t dim_1_idx = tmp_idx % dim1; - tmp_idx = (tmp_idx - dim_1_idx) / dim1; - const size_t dim_0_idx = tmp_idx % dim0; - dst[dim_1_idx * dim0 * dim2 + dim_0_idx * dim2 + dim_2_idx] = src[tid]; - } -} - -template -void invokeInPlaceTranspose102(T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2) -{ - // copy data to workspace, and then transpose from workspace to data - // Note that this kernel is used for pre-processing and not very efficient. 
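transpose0213 above decomposes each flat source index into four coordinates, swaps the middle two dimensions, and recomposes the destination offset. A hedged host-side reference of that index arithmetic, handy for checking the kernel against small tensors:

#include <cstddef>
#include <iostream>
#include <vector>

// Host reference for the [0, 1, 2, 3] -> [0, 2, 1, 3] permutation used by
// transpose0213: decompose the flat source index, swap the middle two
// dimensions, and recompose the destination index.
void transpose0213Host(std::vector<float>& dst, std::vector<float> const& src,
                       std::size_t d0, std::size_t d1, std::size_t d2, std::size_t d3)
{
    dst.resize(src.size());
    for (std::size_t tid = 0; tid < d0 * d1 * d2 * d3; ++tid)
    {
        std::size_t tmp = tid;
        std::size_t const i3 = tmp % d3; tmp /= d3;
        std::size_t const i2 = tmp % d2; tmp /= d2;
        std::size_t const i1 = tmp % d1; tmp /= d1;
        std::size_t const i0 = tmp;
        dst[((i0 * d2 + i2) * d1 + i1) * d3 + i3] = src[tid];
    }
}

int main()
{
    std::vector<float> src(2 * 3 * 4 * 5);
    for (std::size_t i = 0; i < src.size(); ++i) src[i] = static_cast<float>(i);
    std::vector<float> dst;
    transpose0213Host(dst, src, 2, 3, 4, 5);
    std::cout << dst[5] << "\n"; // source element (i0,i1,i2,i3) = (0,1,0,0), i.e. 20
    return 0;
}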
- cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2); - transpose102<<<256, 256>>>(data, workspace, dim0, dim1, dim2); -} - -#ifdef ENABLE_FP8 -template void invokeInPlaceTranspose102( - __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1, const size_t dim2); -#endif // ENABLE_FP8 -#ifdef ENABLE_BF16 -template void invokeInPlaceTranspose102( - __nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1, const size_t dim2); -#endif // ENABLE_BF16 -template void invokeInPlaceTranspose102( - float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2); - -template -void __global__ multiplyScale(T* tensor, float scale, const size_t size) -{ - for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x) - { - tensor[index] = (T) (((float) tensor[index]) * scale); - } -} - -template -void invokeMultiplyScale(T* tensor, float scale, const size_t size, cudaStream_t stream) -{ - int block = 256; - int grid = (size + 255) / 256; - multiplyScale<<>>(tensor, scale, size); -} - -template void invokeMultiplyScale(float* tensor, float scale, const size_t size, cudaStream_t stream); -template void invokeMultiplyScale(half* tensor, float scale, const size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeMultiplyScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream); -#endif -#ifdef ENABLE_FP8 -template void invokeMultiplyScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream); -#endif - -template -void __global__ divideScale(T* tensor, float scale, const size_t size) -{ - for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x) - { - tensor[index] = (T) (((float) tensor[index]) / scale); - } -} - -template -void invokeDivideScale(T* tensor, float scale, const size_t size, cudaStream_t stream) -{ - int block = 256; - int grid = (size + 255) / 256; - divideScale<<>>(tensor, scale, size); -} - -template void invokeDivideScale(float* tensor, float scale, const size_t size, cudaStream_t stream); -template void invokeDivideScale(half* tensor, float scale, const size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeDivideScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream); -#endif -#ifdef ENABLE_FP8 -template void invokeDivideScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream); -#endif -#ifdef ENABLE_BF16 -template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); -template void invokeFakeCast<__nv_bfloat16, __nv_bfloat16>( - __nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); -template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); -#endif -template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); -template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); -#ifdef ENABLE_FP8 -template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); -template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); -template void invokeFakeCast<__nv_bfloat16, __nv_fp8_e4m3>( - __nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); -#endif - -size_t cuda_datatype_size(TRTLLMCudaDataType dt) -{ - static const std::unordered_map sizes{ - {TRTLLMCudaDataType::FP32, sizeof(float)}, {TRTLLMCudaDataType::FP16, 
sizeof(half)} -#ifdef ENABLE_BF16 - , - {TRTLLMCudaDataType::BF16, sizeof(__nv_bfloat16)} -#endif - }; - - return sizes.at(dt); -} - -template -__global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range) -{ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) - { - const T val = buffer[i]; - if (val < min || val > max) - { - *d_within_range = false; - } - } -} - -template -bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream) -{ - cudaMemsetAsync(d_within_range, true, sizeof(bool), stream); - - dim3 block(256); - dim3 grid((size + 255) / 256); - check_range<<>>(buffer, size, min, max, d_within_range); - - bool result; - cudaD2Hcpy(&result, d_within_range, 1); - return result; -} - -template bool invokeCheckRange( - int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream); - -/* - * Determine the total workspace size based on a vector containing multiple variable sizes. - */ -size_t calcAlignedSize(std::vector const& sizes, const size_t ALIGN_BYTES) -{ - const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); - // Check ALIGN_BYTES is a power of 2 - assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0); - - size_t total = 0; - for (auto sz : sizes) - { - total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; - } - - // We add extra "ALIGN_BYTES - 1" bytes in case the start address passed to the function calcAlignedPointers() is - // not aligned. - return total + ALIGN_BYTES - 1; -} - -/* - * Given the address of the workspace and the vector containing multiple variable sizes, calculate the start addresses - * of each variable. - */ -void calcAlignedPointers( - std::vector& outPtrs, void const* p, std::vector const& sizes, size_t ALIGN_BYTES) -{ - const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); - // Check ALIGN_BYTES is a power of 2 - assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0); - - // In case the start address is not aligned - char* ptr = reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); - - outPtrs.reserve(sizes.size()); - for (auto sz : sizes) - { - outPtrs.push_back(ptr); - ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; - } -} - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h deleted file mode 100644 index 9e413a1beb..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
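calcAlignedSize and calcAlignedPointers above size a single workspace by rounding each sub-buffer up to the alignment, then add ALIGN_BYTES - 1 so the budget still suffices when the base pointer itself is unaligned. A hedged standalone sketch of that partitioning scheme; alignedTotal and partition are illustrative names, not APIs from this patch.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Round each sub-buffer size up to `align` and sum them; add align - 1 so the
// result is still sufficient for an unaligned base pointer.
std::size_t alignedTotal(std::vector<std::size_t> const& sizes, std::size_t align)
{
    assert((align & (align - 1)) == 0); // power of two
    std::size_t total = 0;
    for (std::size_t sz : sizes)
    {
        total += (sz + align - 1) & ~(align - 1);
    }
    return total + align - 1;
}

// Hand out one aligned start address per sub-buffer inside the workspace.
std::vector<void*> partition(void* base, std::vector<std::size_t> const& sizes, std::size_t align)
{
    auto addr = (reinterpret_cast<std::uintptr_t>(base) + align - 1) & ~(align - 1);
    std::vector<void*> ptrs;
    ptrs.reserve(sizes.size());
    for (std::size_t sz : sizes)
    {
        ptrs.push_back(reinterpret_cast<void*>(addr));
        addr += (sz + align - 1) & ~(align - 1);
    }
    return ptrs;
}

int main()
{
    std::vector<std::size_t> const sizes{100, 40, 260};
    std::size_t const align = 256;
    std::vector<char> workspace(alignedTotal(sizes, align));
    auto ptrs = partition(workspace.data(), sizes, align);
    std::cout << ptrs.size() << " sub-buffers, " << workspace.size() << " bytes total\n"; // 3, 1279
    return 0;
}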
- */ - -#pragma once - -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include "tensorrt_llm/common/cudaUtils.h" - -#include - -namespace tensorrt_llm -{ -namespace common -{ - -template -void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true); - -template -void deviceMemSetZero(T* ptr, size_t size); - -template - -void deviceFree(T*& ptr); - -template -void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0); - -template -void cudaD2Hcpy(T* tgt, T const* src, size_t const size); - -template -void cudaH2Dcpy(T* tgt, T const* src, size_t const size); - -template -void cudaD2Dcpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL); - -template -void cudaAutoCpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL); - -template -void cudaRandomUniform(T* buffer, size_t const size); - -template -int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, - TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32); - -// template -// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr, -// T* scale_ptr, -// std::vector shape, -// std::string filename, -// TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32); - -void invokeCudaD2DcpyHalf2Float(float* dst, half* src, size_t const size, cudaStream_t stream); -void invokeCudaD2DcpyFloat2Half(half* dst, float* src, size_t const size, cudaStream_t stream); -#ifdef ENABLE_FP8 -void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream); -void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream); -void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, size_t const size, cudaStream_t stream); -void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, size_t const size, cudaStream_t stream); -void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream); -#endif // ENABLE_FP8 -#ifdef ENABLE_BF16 -void invokeCudaD2DcpyBfloat2Float(float* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream); -#endif // ENABLE_BF16 - -template -void invokeCudaCast(T_OUT* dst, T_IN const* const src, size_t const size, cudaStream_t stream); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// The following functions implement conversion of multi-dimensional indices to an index in a flat array. -// The shape of the Tensor dimensions is passed as one array (`dims`), the indices are given as individual arguments. -// For examples on how to use these functions, see their tests `test_memory_utils.cu`. -// All of these functions can be evaluated at compile time by recursive template expansion. - -template -__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( - T const& acc, TDim dims, TIndex const& index) -{ - assert(index < dims[0]); - return acc * dims[0] + index; -} - -template -__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( - T const& acc, TDim dims, TIndex const& index, TIndices... 
indices) -{ - assert(index < dims[0]); - return flat_index(acc * dims[0] + index, dims + 1, indices...); -} - -template -__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( - [[maybe_unused]] TDim dims, T const& index) -{ - assert(index < dims[0]); - return index; -} - -template -__inline__ __host__ __device__ - std::enable_if_t::value, typename std::remove_pointer::type> constexpr flat_index( - TDim dims, TIndex const& index, TIndices... indices) -{ - assert(index < dims[0]); - return flat_index(static_cast::type>(index), dims + 1, indices...); -} - -template -__inline__ __host__ __device__ T constexpr flat_index( - std::array const& dims, TIndex const& index, TIndices... indices) -{ - static_assert(skip < N); - static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); - return flat_index(&dims[skip], index, indices...); -} - -template -__inline__ __host__ __device__ T constexpr flat_index( - T const& acc, std::array const& dims, TIndex const& index, TIndices... indices) -{ - static_assert(skip < N); - static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); - return flat_index(acc, &dims[skip], index, indices...); -} - -template -__inline__ __host__ __device__ T constexpr flat_index(T const (&dims)[N], TIndex const& index, TIndices... indices) -{ - static_assert(skip < N); - static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); - return flat_index(static_cast(dims) + skip, index, indices...); -} - -template -__inline__ __host__ __device__ T constexpr flat_index( - T const& acc, T const (&dims)[N], TIndex const& index, TIndices... indices) -{ - static_assert(skip < N); - static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); - return flat_index(acc, static_cast(dims) + skip, index, indices...); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// These are simpler functions for multi-dimensional index conversion. Indices and dimensions are passed as individual -// arguments. These functions are more suitable for usage inside kernels than the corresponding flat_index functions -// which require arrays as arguments. Usage examples can be found in `test_memory_utils.cu`. The functions can be -// evaluated at compile time. 
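The flat_indexN helpers that follow fold multi-dimensional indices into a row-major offset and are written so the compiler can fold them at compile time. A hedged host-only sketch of the 3-D case; flatIndex3 here is an illustrative stand-in and omits the bounds asserts of the originals.

#include <cstddef>
#include <iostream>

// Row-major flattening for a 3-D tensor of shape [d0, d1, d2]:
// offset = (i0 * d1 + i1) * d2 + i2, as in flat_index3 of the deleted header
// (which additionally asserts that the trailing indices are in range).
constexpr std::size_t flatIndex3(std::size_t i0, std::size_t i1, std::size_t i2,
                                 std::size_t d1, std::size_t d2)
{
    return (i0 * d1 + i1) * d2 + i2;
}

int main()
{
    // Element (1, 2, 3) of a [4, 5, 6] tensor lives at offset (1*5 + 2)*6 + 3 = 45.
    static_assert(flatIndex3(1, 2, 3, 5, 6) == 45, "compile-time evaluable, like the originals");
    std::cout << flatIndex3(1, 2, 3, 5, 6) << "\n";
    return 0;
}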
- -template -__inline__ __host__ __device__ T constexpr flat_index2(TIndex const& index_0, TIndex const& index_1, T const& dim_1) -{ - assert(index_1 < dim_1); - return index_0 * dim_1 + index_1; -} - -template -__inline__ __host__ __device__ T constexpr flat_index3( - TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& dim_1, T const& dim_2) -{ - assert(index_2 < dim_2); - return flat_index2(index_0, index_1, dim_1) * dim_2 + index_2; -} - -template -__inline__ __host__ __device__ T constexpr flat_index4(TIndex const& index_0, TIndex const& index_1, - TIndex const& index_2, TIndex const& index_3, T const& dim_1, T const& dim_2, T const& dim_3) -{ - assert(index_3 < dim_3); - return flat_index3(index_0, index_1, index_2, dim_1, dim_2) * dim_3 + index_3; -} - -template -__inline__ __host__ __device__ T constexpr flat_index5(TIndex const& index_0, TIndex const& index_1, - TIndex const& index_2, TIndex const& index_3, TIndex const& index_4, T const& dim_1, T const& dim_2, T const& dim_3, - T const& dim_4) -{ - assert(index_4 < dim_4); - return flat_index4(index_0, index_1, index_2, index_3, dim_1, dim_2, dim_3) * dim_4 + index_4; -} - -template -__inline__ __host__ __device__ T constexpr flat_index_strided3( - TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& stride_1, T const& stride_2) -{ - assert(index_1 < stride_1 / stride_2); - assert(index_2 < stride_2); - return index_0 * stride_1 + index_1 * stride_2 + index_2; -} - -template -__inline__ __host__ __device__ T constexpr flat_index_strided4(TIndex const& index_0, TIndex const& index_1, - TIndex const& index_2, TIndex const& index_3, T const& stride_1, T const& stride_2, T const& stride_3) -{ - assert(index_1 < stride_1 / stride_2); - assert(index_2 < stride_2 / stride_3); - assert(index_3 < stride_3); - return index_0 * stride_1 + index_1 * stride_2 + index_2 * stride_3 + index_3; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void invokeInPlaceTranspose(T* data, T* workspace, size_t const dim0, size_t const dim1); - -template -void invokeInPlaceTranspose0213( - T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2, size_t const dim3); - -template -void invokeInPlaceTranspose102(T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2); - -template -void invokeMultiplyScale(T* tensor, float scale, size_t const size, cudaStream_t stream); - -template -void invokeDivideScale(T* tensor, float scale, size_t const size, cudaStream_t stream); - -template -void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, size_t const size, cudaStream_t stream = 0); - -template -void invokeCudaD2DScaleCpyConvert( - T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, size_t const size, cudaStream_t stream = 0); - -inline bool checkIfFileExist(std::string const& file_path) -{ - std::ifstream in(file_path, std::ios::in | std::ios::binary); - if (in.is_open()) - { - in.close(); - return true; - } - return false; -} - -template -void saveToBinary(T const* ptr, size_t const size, std::string filename); - -template -void invokeFakeCast(T_IN* input_ptr, size_t const size, cudaStream_t stream); - -size_t cuda_datatype_size(TRTLLMCudaDataType dt); - -template -bool invokeCheckRange(T const* buffer, size_t const size, T min, T max, bool* d_within_range, cudaStream_t stream); - -constexpr size_t DEFAULT_ALIGN_BYTES = 256; - -size_t calcAlignedSize(std::vector const& sizes, 
size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES); -void calcAlignedPointers(std::vector& outPtrs, void const* p, std::vector const& sizes, - size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES); - -struct AlignedPointersUnpacker -{ - template - void operator()(T*&... outPtrs) - { - assert(sizeof...(T) == alignedPointers.size()); - auto it = alignedPointers.begin(); - ((outPtrs = static_cast(*it++)), ...); - } - - std::vector alignedPointers; -}; - -AlignedPointersUnpacker inline calcAlignedPointers( - void const* p, std::vector const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES) -{ - AlignedPointersUnpacker unpacker{}; - calcAlignedPointers(unpacker.alignedPointers, p, sizes, ALIGN_BYTES); - return unpacker; -} - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp deleted file mode 100644 index dbdaca4ee7..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp +++ /dev/null @@ -1,588 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "tensorrt_llm/common/mpiUtils.h" - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/runtime/common.h" -#include "tensorrt_llm/runtime/iBuffer.h" - -#include -#include -#include -#include -#include -#ifndef _WIN32 -#include -#endif - -// We rely on SizeType32 being int32_t in some places with weak type checking, -// i.e. we're passing void ptr to some function. To prevent mysterious errors -// in the future, we trigger a compilation error here if SizeType32 isn't int32_t. 
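getMpiDtype and getMpiOp, defined just below, resolve the library's own enums through a function-local static map, so the table is built once and an unregistered enum surfaces as an exception from at(). A hedged pure-C++ sketch of that lookup pattern; ExampleType and the byte sizes are placeholders, not MPI constants.

#include <cstddef>
#include <iostream>
#include <unordered_map>

enum class ExampleType { kFLOAT, kHALF, kINT32 };

// One static table, built on first use; at() throws std::out_of_range for any
// enum value that was never registered, same as getMpiDtype()/getMpiOp().
std::size_t byteSize(ExampleType t)
{
    static std::unordered_map<ExampleType, std::size_t> const table{
        {ExampleType::kFLOAT, 4},
        {ExampleType::kHALF, 2},
        {ExampleType::kINT32, 4},
    };
    return table.at(t);
}

int main()
{
    std::cout << byteSize(ExampleType::kHALF) << "\n"; // 2
    return 0;
}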
-static_assert(std::is_same::value); - -namespace tensorrt_llm::mpi -{ - -MPI_Datatype getMpiDtype(MpiType dtype) -{ -#if ENABLE_MULTI_DEVICE - static std::unordered_map const dtype_map{ - {MpiType::kBYTE, MPI_BYTE}, - {MpiType::kHALF, MPI_UINT16_T}, - {MpiType::kFLOAT, MPI_FLOAT}, - {MpiType::kDOUBLE, MPI_DOUBLE}, - {MpiType::kBOOL, MPI_C_BOOL}, - {MpiType::kINT8, MPI_INT8_T}, - {MpiType::kUINT8, MPI_UINT8_T}, - {MpiType::kINT32, MPI_INT32_T}, - {MpiType::kUINT32, MPI_UINT32_T}, - {MpiType::kINT64, MPI_INT64_T}, - {MpiType::kUINT64, MPI_UINT64_T}, - {MpiType::kFP8, MPI_UINT8_T}, - {MpiType::kBF16, MPI_UINT16_T}, - {MpiType::kCHAR, MPI_CHAR}, - }; - return dtype_map.at(dtype); -#else - TLLM_THROW("Multi device support is disabled."); -#endif -} - -MPI_Op getMpiOp(MpiOp op) -{ -#if ENABLE_MULTI_DEVICE - static std::unordered_map const op_map{ - {MpiOp::NULLOP, MPI_OP_NULL}, - {MpiOp::MAX, MPI_MAX}, - {MpiOp::MIN, MPI_MIN}, - {MpiOp::SUM, MPI_SUM}, - {MpiOp::PROD, MPI_PROD}, - {MpiOp::LAND, MPI_LAND}, - {MpiOp::BAND, MPI_BAND}, - {MpiOp::LOR, MPI_LOR}, - {MpiOp::BOR, MPI_BOR}, - {MpiOp::LXOR, MPI_LXOR}, - {MpiOp::BXOR, MPI_BXOR}, - {MpiOp::MINLOC, MPI_MINLOC}, - {MpiOp::MAXLOC, MPI_MAXLOC}, - {MpiOp::REPLACE, MPI_REPLACE}, - }; - return op_map.at(op); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -namespace -{ - -bool mpiInitialized = false; -std::recursive_mutex mpiMutex; - -MpiComm initLocalSession() -{ -#if ENABLE_MULTI_DEVICE - MPI_Comm localComm = nullptr; - MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm); - MpiComm localSession{localComm, false}; -#else - MpiComm localSession{COMM_SESSION, false}; -#endif // ENABLE_MULTI_DEVICE - return localSession; -} - -} // namespace - -std::vector getWorldRanks(MpiComm const& comm) -{ -#if ENABLE_MULTI_DEVICE - MPI_Group group = nullptr; - MPI_Group worldGroup = nullptr; - - MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); - MPICHECK(MPI_Comm_group(comm, &group)); - - int groupSize = 0; - MPICHECK(MPI_Group_size(group, &groupSize)); - std::vector ranks(groupSize); - std::vector worldRanks(groupSize); - std::iota(ranks.begin(), ranks.end(), 0); - - MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data())); - MPICHECK(MPI_Group_free(&group)); - MPICHECK(MPI_Group_free(&worldGroup)); -#else - std::vector worldRanks{0}; -#endif - return worldRanks; -} - -void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent) -{ - // double-checked locking - if (mpiInitialized) - { - return; - } - std::lock_guard lk(mpiMutex); - if (mpiInitialized) - { - return; - } -#if ENABLE_MULTI_DEVICE - int initialized = 0; - TLLM_MPI_CHECK(MPI_Initialized(&initialized)); - if (!initialized) - { - TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode); - int providedMode = 0; - auto requiredMode = static_cast(threadMode); - MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode)); - TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed"); - std::atexit([]() { MPI_Finalize(); }); - - /* - * We only catch SIGABRT and SIGSEGV because most, of not all errors in the worker will cause one of these 2 - * signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers - * correctly. 
- */ - for (int sig : {SIGABRT, SIGSEGV}) - { - __sighandler_t previousHandler = nullptr; - if (forwardAbortToParent) - { - previousHandler = std::signal(sig, - [](int signal) - { -#ifndef _WIN32 - pid_t parentProcessId = getppid(); - kill(parentProcessId, SIGKILL); -#endif - MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); - }); - } - else - { - previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); }); - } - TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed"); - } - - // ensure local MPI communicator is initialized - MpiComm::localSession(); - TLLM_LOG_INFO("Initialized MPI"); - } -#endif // ENABLE_MULTI_DEVICE - mpiInitialized = true; -} - -void MpiComm::barrier() const -{ -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Barrier(mComm)); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -#if ENABLE_MULTI_DEVICE -template >>> -size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args) -{ - constexpr auto maxP1 = static_cast(std::numeric_limits::max()) + 1; - if (TLLM_LIKELY(size < maxP1)) - { - MPICHECK(func(buffer, size, dtype, args...)); - return 1; - } - - constexpr size_t alignment = 256; - int elementSize = 1; - MPICHECK(MPI_Type_size(dtype, &elementSize)); - elementSize = std::min(elementSize, alignment); - - // We cap at max alignment-bytes chunks that can be sent at once. - auto const step = maxP1 - (alignment / elementSize); - - using TCast = std::conditional_t, uint8_t const, uint8_t>; - size_t count = 0; - while (size != 0) - { - auto currentStep = static_cast(std::min(size, step)); - MPICHECK(func(buffer, currentStep, dtype, args...)); - size -= currentStep; - size_t diff = static_cast(currentStep) * elementSize; - buffer = static_cast(buffer) + diff; - ++count; - } - - return count; -} -#endif // ENABLE_MULTI_DEVICE - -std::shared_ptr MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const -{ - std::shared_ptr r = std::make_shared(); -#if ENABLE_MULTI_DEVICE - invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE - return r; -} - -std::shared_ptr MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const -{ - TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU); - return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); -} - -void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const -{ -#if ENABLE_MULTI_DEVICE - invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -void MpiComm::bcast(runtime::IBuffer& buf, int root) const -{ - bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); -} - -std::shared_ptr MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const -{ - TLLM_LOG_DEBUG("start MPI_Isend with size %d", size); - std::shared_ptr r = std::make_shared(); -#if ENABLE_MULTI_DEVICE - invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest); -#else - TLLM_THROW("Multi device support is disabled."); -#endif - TLLM_LOG_DEBUG("end MPI_Isend with size %d", size); - return r; -} - -std::shared_ptr MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const -{ - return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); -} - -void 
MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const -{ - TLLM_LOG_DEBUG("start MPI_Send with size %d", size); -#if ENABLE_MULTI_DEVICE - invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE - TLLM_LOG_DEBUG("end MPI_Send with size %d", size); -} - -void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const -{ - send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); -} - -MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const -{ - TLLM_LOG_DEBUG("start MPI_Recv with size %d", size); - MPI_Status status{}; -#if ENABLE_MULTI_DEVICE - invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE - TLLM_LOG_DEBUG("end MPI_Recv with size %d", size); - return status; -} - -MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const -{ - return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag); -} - -MpiComm MpiComm::split(int color, int key) const -{ - MPI_Comm splitComm = nullptr; -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm)); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE - return MpiComm{splitComm, true}; -} - -void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const -{ -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm)); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const -{ -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm)); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf, - std::vector const& recvcounts, std::vector const& displs, MpiType recvtype) const -{ -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(), - getMpiDtype(recvtype), mComm)); - -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const -{ -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status)); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const -{ -#if ENABLE_MULTI_DEVICE - int flag{0}; - MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status)); - return flag != 0; -#else - TLLM_THROW("Multi device support is disabled."); - return false; -#endif -} - -bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const -{ -#if ENABLE_MULTI_DEVICE - int flag{0}; - MPICHECK(MPI_Iprobe(source, tag, mComm, &flag, status)); - return flag != 0; -#else - TLLM_THROW("Multi device support is disabled."); - return false; -#endif -} - -void MpiComm::recvPoll(int source, int tag, int periodMs) const -{ - MPI_Status status; - while (!iprobe(source, tag, &status)) - { - 
std::this_thread::sleep_for(std::chrono::milliseconds(periodMs)); - } -} - -int MpiComm::getRank() const -{ - int rank = 0; -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Comm_rank(mComm, &rank)); -#endif - return rank; -} - -int MpiComm::getSize() const -{ - int world_size = 1; -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Comm_size(mComm, &world_size)); -#endif - return world_size; -} - -MpiComm const& MpiComm::world() -{ - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - static MpiComm commWorld{MPI_COMM_WORLD, false}; - initialize(); - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return commWorld; -} - -MpiComm& MpiComm::mutableSession() -{ - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - static MpiComm commSession{MPI_COMM_WORLD, false}; - initialize(); - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return commSession; -} - -MpiComm& MpiComm::mutableLocalSession() -{ - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - static MpiComm localSession = initLocalSession(); - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return localSession; -} - -void MpiComm::refreshLocalSession() -{ -#if ENABLE_MULTI_DEVICE - static std::mutex mutex; - std::unique_lock lock(mutex); - auto initSessionRanks = getWorldRanks(MpiComm::session()); - auto localSessionRanks = getWorldRanks(MpiComm::localSession()); - - // Add to intersectionRanks in order of initSessionRanks - std::vector intersectionRanks; - std::unordered_set localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end()); - for (auto rank : initSessionRanks) - { - if (localSessionRanksSet.find(rank) != localSessionRanksSet.end()) - { - intersectionRanks.push_back(rank); - } - } - - MPI_Group worldGroup = nullptr; - MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); - MPI_Group localGroup = nullptr; - MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup)); - MPI_Comm localComm = nullptr; - MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm)); - MpiComm::mutableLocalSession().mFreeComm = true; - MpiComm::mutableLocalSession() = MpiComm{localComm, false}; - TLLM_LOG_INFO("Refreshed the MPI local session"); -#endif // ENABLE_MULTI_DEVICE -} - -MpiComm::MpiComm(MPI_Comm g, bool freeComm) - : mComm{g} - , mFreeComm{freeComm} -{ - TLLM_CHECK(mComm != MPI_COMM_NULL); -} - -MpiComm::~MpiComm() noexcept -{ -#if ENABLE_MULTI_DEVICE - if (mFreeComm && mComm) - { - if (MPI_Comm_free(&mComm) != MPI_SUCCESS) - { - TLLM_LOG_ERROR("MPI_Comm_free failed"); - } - } -#endif // ENABLE_MULTI_DEVICE -} - -MpiComm::MpiComm(MpiComm&& comm) noexcept - : mComm{comm.mComm} - , mFreeComm{comm.mFreeComm} -{ - comm.mFreeComm = false; -} - -MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept -{ - this->~MpiComm(); - mComm = comm.mComm; - mFreeComm = comm.mFreeComm; - comm.mFreeComm = false; - return *this; -} - -MpiWaitThread::MpiWaitThread(std::string name, std::function funcWait, std::function funcSetup) - : mName{name.c_str()} - , mFuncWait{funcWait} - , mFuncSetup{funcSetup} -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - mThread = std::make_unique(&MpiWaitThread::sideThread, this); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} - -MpiWaitThread::~MpiWaitThread() -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - waitStop(); - mShouldExit.store(true); - notifyStart(); - mThread->join(); - mThread.reset(nullptr); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} 
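MpiWaitThread, defined in this hunk, gates a long-lived side thread with a mutex and a condition variable: the owner hands it one unit of work with notifyStart() and blocks in waitStop() until the thread has parked itself again. The snippet below is a minimal, self-contained sketch of that same start/stop handshake only; it is not code from TensorRT-LLM or from this patch, GatedWorker, start(), and waitIdle() are illustrative names, and shutdown is simplified so that a start issued during destruction is simply dropped.

#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>

// Hypothetical sketch of the condition-variable handshake used by MpiWaitThread:
// the worker parks between units of work, the owner wakes it with start() and
// can block in waitIdle() until the worker has parked again.
class GatedWorker
{
public:
    explicit GatedWorker(std::function<void()> work)
        : mWork(std::move(work))
        , mThread(&GatedWorker::loop, this)
    {
    }

    ~GatedWorker()
    {
        {
            std::lock_guard<std::mutex> lock(mMutex);
            mShouldExit = true;    // any not-yet-started work is dropped
            mCondVar.notify_all();
        }
        mThread.join();
    }

    void start() // hand the worker one unit of work
    {
        std::lock_guard<std::mutex> lock(mMutex);
        mRunning = true;
        mCondVar.notify_all();
    }

    void waitIdle() // block until the worker has finished and parked again
    {
        std::unique_lock<std::mutex> lock(mMutex);
        mCondVar.wait(lock, [this] { return mIdle && !mRunning; });
    }

private:
    void loop()
    {
        while (true)
        {
            {
                std::unique_lock<std::mutex> lock(mMutex);
                mIdle = true;
                mCondVar.notify_all(); // wake any waitIdle() caller
                mCondVar.wait(lock, [this] { return mRunning || mShouldExit; });
                if (mShouldExit)
                {
                    return;
                }
                mRunning = false; // consume the start request
                mIdle = false;
            }
            mWork(); // run the unit of work outside the lock
        }
    }

    std::function<void()> mWork;
    std::mutex mMutex;
    std::condition_variable mCondVar;
    bool mRunning{false};
    bool mIdle{false};
    bool mShouldExit{false};
    std::thread mThread; // declared last so every other member is ready before the loop starts
};

int main()
{
    GatedWorker worker([] { std::puts("one unit of work"); });
    worker.start();    // release the worker for one iteration
    worker.waitIdle(); // returns only after the work has run and the worker parked again
    return 0;
}

Unlike the original, this sketch keeps the exit flag under the same mutex as the start/idle flags, which avoids any lost-wakeup window at the cost of not flushing a pending start during destruction.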
- -void MpiWaitThread::sideThread() -{ - if (mFuncSetup) - { - mFuncSetup(); - } - while (!mShouldExit.load()) - { - notifyStop(); - waitStart(); - mFuncWait(); - } -} - -void MpiWaitThread::waitStart() -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - std::unique_lock lock(mMutex); - mCondVar.wait(lock, [this] { return mRunning; }); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} - -void MpiWaitThread::waitStop() -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - std::unique_lock lock(mMutex); - mCondVar.wait(lock, [this] { return !mRunning; }); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} - -void MpiWaitThread::notifyStart() -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - std::lock_guard lock(mMutex); - mRunning = true; - mCondVar.notify_one(); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} - -void MpiWaitThread::notifyStop() -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - std::lock_guard lock(mMutex); - mRunning = false; - mCondVar.notify_one(); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} - -} // namespace tensorrt_llm::mpi diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h deleted file mode 100644 index 0a9d51975a..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include - -namespace tensorrt_llm::common::nvtx -{ -inline nvtx3::color nextColor() -{ -#ifndef NVTX_DISABLE - constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00}, - nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}}; - constexpr auto numColors = kColors.size(); - - static thread_local std::size_t colorId = 0; - auto const color = kColors[colorId]; - colorId = colorId + 1 >= numColors ? 0 : colorId + 1; - return color; -#else - return nvtx3::color{0}; -#endif -} - -} // namespace tensorrt_llm::common::nvtx - -#define NVTX3_SCOPED_RANGE_WITH_NAME(range, name) \ - ::nvtx3::scoped_range range(::tensorrt_llm::common::nvtx::nextColor(), name) -#define NVTX3_SCOPED_RANGE(range) NVTX3_SCOPED_RANGE_WITH_NAME(range##_range, #range) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp deleted file mode 100644 index 39aefda481..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp +++ /dev/null @@ -1,323 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tensorrt_llm/common/opUtils.h" -#include "tensorrt_llm/common/mpiUtils.h" - -#include "cuda.h" -#include -#include -#include -#include -#include -#include -#include - -#ifdef _MSC_VER -#define FN_NAME __FUNCTION__ -#else -#define FN_NAME __func__ -#endif - -#if ENABLE_MULTI_DEVICE - -std::unordered_map* getDtypeMap() -{ - static std::unordered_map dtypeMap = {{nvinfer1::DataType::kFLOAT, ncclFloat32}, - {nvinfer1::DataType::kHALF, ncclFloat16}, {nvinfer1::DataType::kBF16, ncclBfloat16}}; - return &dtypeMap; -} - -namespace -{ - -// Get NCCL unique ID for a group of ranks. -ncclUniqueId getUniqueId(std::set const& group) noexcept -{ - auto const rank = COMM_SESSION.getRank(); - TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); - ncclUniqueId id; - if (rank == *group.begin()) - { - NCCLCHECK(ncclGetUniqueId(&id)); - for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it) - { - COMM_SESSION.sendValue(id, *it, 0); - } - } - else - { - COMM_SESSION.recvValue(id, *group.begin(), 0); - } - TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); - return id; -} -} // namespace - -std::shared_ptr getComm(std::set const& group) -{ - auto const rank = COMM_SESSION.getRank(); - TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); - static std::map, std::shared_ptr> commMap; - static std::mutex mutex; - std::lock_guard lock(mutex); - std::ostringstream oss; - int index = 0; - for (auto const& rank : group) - { - if (index != 0) - { - oss << ","; - } - oss << rank; - index++; - } - auto groupStr = oss.str(); - auto it = commMap.find(group); - if (it != commMap.end()) - { - auto ncclComm = it->second; - TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank); - return ncclComm; - } - - TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank); - ncclUniqueId id = getUniqueId(group); - int groupRank = 0; - for (auto const& currentRank : group) - { - if (rank == currentRank) - break; - ++groupRank; - } - TLLM_CHECK(groupRank < group.size()); - std::shared_ptr ncclComm(new ncclComm_t, - [](ncclComm_t* comm) - { - ncclCommDestroy(*comm); - delete comm; - }); - NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank)); - commMap[group] = ncclComm; - TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); - return ncclComm; -} -#endif // ENABLE_MULTI_DEVICE - -void const* tensorrt_llm::common::getCommSessionHandle() -{ -#if ENABLE_MULTI_DEVICE - return &COMM_SESSION; -#else - return nullptr; -#endif // ENABLE_MULTI_DEVICE -} - -namespace -{ - -// Get current cuda context, a default context will be created if there is no context. 
-inline CUcontext getCurrentCudaCtx() -{ - CUcontext ctx{}; - CUresult err = cuCtxGetCurrent(&ctx); - if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr) - { - TLLM_CUDA_CHECK(cudaFree(nullptr)); - err = cuCtxGetCurrent(&ctx); - } - TLLM_CHECK(err == CUDA_SUCCESS); - return ctx; -} - -// Helper to create per-cuda-context singleton managed by std::shared_ptr. -// Unlike conventional singletons, singleton created with this will be released -// when not needed, instead of on process exit. -// Objects of this class shall always be declared static / global, and shall never own CUDA -// resources. -template -class PerCudaCtxSingletonCreator -{ -public: - using CreatorFunc = std::function()>; - using DeleterFunc = std::function; - - // creator returning std::unique_ptr is by design. - // It forces separation of memory for T and memory for control blocks. - // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. - // creator itself must not own CUDA resources. Only the object it creates can. - PerCudaCtxSingletonCreator(CreatorFunc creator, DeleterFunc deleter) - : mCreator{std::move(creator)} - , mDeleter{std::move(deleter)} - { - } - - std::shared_ptr operator()() - { - std::lock_guard lk{mMutex}; - CUcontext ctx{getCurrentCudaCtx()}; - std::shared_ptr result = mObservers[ctx].lock(); - if (result == nullptr) - { - // Create the resource and register with an observer. - result = std::shared_ptr{mCreator().release(), - [this, ctx](T* obj) - { - if (obj == nullptr) - { - return; - } - mDeleter(obj); - - // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts - // frequently. - std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. - std::lock_guard lk{mMutex}; - // Must check observer again because another thread may created new instance for this ctx just - // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is - // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic - // operation, and the observer may be changed to observe another instance. - observedObjHolder = mObservers.at(ctx).lock(); - if (observedObjHolder == nullptr) - { - mObservers.erase(ctx); - } - }}; - mObservers.at(ctx) = result; - } - return result; - } - -private: - CreatorFunc mCreator; - DeleterFunc mDeleter; - mutable std::mutex mMutex; - // CUDA resources are per-context. - std::unordered_map> mObservers; -}; - -template -class PerThreadSingletonCreator -{ -public: - using CreatorFunc = std::function()>; - using DeleterFunc = std::function; - - // creator returning std::unique_ptr is by design. - // It forces separation of memory for T and memory for control blocks. - // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. - // creator itself must not own CUDA resources. Only the object it creates can. - PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter) - : mCreator{std::move(creator)} - , mDeleter{std::move(deleter)} - { - } - - std::shared_ptr operator()() - { - std::lock_guard lk{mMutex}; - - std::thread::id thread = std::this_thread::get_id(); - std::shared_ptr result = mObservers[thread].lock(); - - if (result == nullptr) - { - // Create the resource and register with an observer. 
- result = std::shared_ptr{mCreator().release(), - [this, thread](T* obj) - { - if (obj == nullptr) - { - return; - } - mDeleter(obj); - - // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts - // frequently. - std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. - std::lock_guard lk{mMutex}; - // Must check observer again because another thread may created new instance for this ctx just - // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is - // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic - // operation, and the observer may be changed to observe another instance. - observedObjHolder = mObservers.at(thread).lock(); - if (observedObjHolder == nullptr) - { - mObservers.erase(thread); - } - }}; - mObservers.at(thread) = result; - } - return result; - } - -private: - CreatorFunc mCreator; - DeleterFunc mDeleter; - mutable std::mutex mMutex; - // CUDA resources are per-thread. - std::unordered_map> mObservers; -}; - -} // namespace - -std::shared_ptr getCublasHandle() -{ - static PerThreadSingletonCreator creator( - []() -> auto - { - auto handle = std::unique_ptr(new cublasHandle_t); - TLLM_CUDA_CHECK(cublasCreate(handle.get())); - return handle; - }, - [](cublasHandle_t* handle) - { - TLLM_CUDA_CHECK(cublasDestroy(*handle)); - delete handle; - }); - return creator(); -} - -std::shared_ptr getCublasLtHandle() -{ - static PerThreadSingletonCreator creator( - []() -> auto - { - auto handle = std::unique_ptr(new cublasLtHandle_t); - TLLM_CUDA_CHECK(cublasLtCreate(handle.get())); - return handle; - }, - [](cublasLtHandle_t* handle) - { - TLLM_CUDA_CHECK(cublasLtDestroy(*handle)); - delete handle; - }); - return creator(); -} - -std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, - std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace) -{ - static PerThreadSingletonCreator creator( - [cublasHandle, cublasltHandle, stream, workspace]() -> auto - { - auto wrapper = std::unique_ptr( - new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace)); - return wrapper; - }, - [](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; }); - return creator(); -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h deleted file mode 100644 index 4e278e5cf2..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h +++ /dev/null @@ -1,215 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "tensorrt_llm/common/cublasMMWrapper.h" -#include "tensorrt_llm/common/workspace.h" - -#include -#include -#include -#include -#if ENABLE_MULTI_DEVICE -#include -#endif // ENABLE_MULTI_DEVICE - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tensorrt_llm::common -{ - -// Write values into buffer -template -void write(char*& buffer, T const& val) -{ - std::memcpy(buffer, &val, sizeof(T)); - buffer += sizeof(T); -} - -// Read values from buffer -template -void read(char const*& buffer, T& val) -{ - std::memcpy(&val, buffer, sizeof(T)); - buffer += sizeof(T); -} - -// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members. -// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and -// your clone() implementation is responsible for initializing such data members. -// With this we can simplify clone() implementation when there are many data members including at least one unique_ptr. -template > -class UniqPtrWNullCopy : public std::unique_ptr -{ -public: - using std::unique_ptr::unique_ptr; - - // for compatibility with std::make_unique - explicit UniqPtrWNullCopy(std::unique_ptr&& src) - : std::unique_ptr::unique_ptr{std::move(src)} - { - } - - // copy constructor produces nullptr - UniqPtrWNullCopy(UniqPtrWNullCopy const&) - : std::unique_ptr::unique_ptr{} - { - } -}; - -// for testing only -void const* getCommSessionHandle(); -} // namespace tensorrt_llm::common - -inline bool isBuilding() -{ - auto constexpr key = "IS_BUILDING"; - auto const val = getenv(key); - return val != nullptr && std::string(val) == "1"; -} - -#if ENABLE_MULTI_DEVICE -#define NCCLCHECK(cmd) \ - do \ - { \ - ncclResult_t r = cmd; \ - if (r != ncclSuccess) \ - { \ - printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) - -std::unordered_map* getDtypeMap(); - -std::shared_ptr getComm(std::set const& group); - -#endif // ENABLE_MULTI_DEVICE - -//! To save GPU memory, all the plugins share the same cublas and cublasLt handle globally. -//! Get cublas and cublasLt handle for current cuda context -std::shared_ptr getCublasHandle(); -std::shared_ptr getCublasLtHandle(); -std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, - std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace); - -#ifndef DEBUG - -#define PLUGIN_CHECK(status) \ - do \ - { \ - if (status != 0) \ - abort(); \ - } while (0) - -#define ASSERT_PARAM(exp) \ - do \ - { \ - if (!(exp)) \ - return STATUS_BAD_PARAM; \ - } while (0) - -#define ASSERT_FAILURE(exp) \ - do \ - { \ - if (!(exp)) \ - return STATUS_FAILURE; \ - } while (0) - -#define CSC(call, err) \ - do \ - { \ - cudaError_t cudaStatus = call; \ - if (cudaStatus != cudaSuccess) \ - { \ - return err; \ - } \ - } while (0) - -#define DEBUG_PRINTF(...) 
\ - do \ - { \ - } while (0) - -#else - -#define ASSERT_PARAM(exp) \ - do \ - { \ - if (!(exp)) \ - { \ - fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ - return STATUS_BAD_PARAM; \ - } \ - } while (0) - -#define ASSERT_FAILURE(exp) \ - do \ - { \ - if (!(exp)) \ - { \ - fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ - return STATUS_FAILURE; \ - } \ - } while (0) - -#define CSC(call, err) \ - do \ - { \ - cudaError_t cudaStatus = call; \ - if (cudaStatus != cudaSuccess) \ - { \ - printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ - return err; \ - } \ - } while (0) - -#define PLUGIN_CHECK(status) \ - { \ - if (status != 0) \ - { \ - DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ - abort(); \ - } \ - } - -#define DEBUG_PRINTF(...) \ - do \ - { \ - printf(__VA_ARGS__); \ - } while (0) - -#endif // DEBUG - -#define NVML_CHECK(cmd) \ - do \ - { \ - nvmlReturn_t r = cmd; \ - if (r != NVML_SUCCESS) \ - { \ - printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h b/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h new file mode 100644 index 0000000000..052d9c8c81 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h @@ -0,0 +1,358 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +class QuantMode +{ + // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py +public: + using BaseType = std::uint32_t; + + explicit constexpr QuantMode(BaseType value) noexcept + : mValue{value} + { + } + + QuantMode() noexcept = default; + + constexpr QuantMode(QuantMode const&) noexcept = default; + + constexpr QuantMode& operator=(QuantMode const& other) noexcept = default; + + static constexpr QuantMode none() noexcept + { + return QuantMode(BaseType(0)); + } + + static constexpr QuantMode int4Weights() noexcept + { + return QuantMode(BaseType(1u) << 0); + } + + static constexpr QuantMode int8Weights() noexcept + { + return QuantMode(BaseType(1u) << 1); + } + + static constexpr QuantMode activations() noexcept + { + return QuantMode(BaseType(1u) << 2); + } + + static constexpr QuantMode perChannelScaling() noexcept + { + return QuantMode(BaseType(1u) << 3); + } + + static constexpr QuantMode perTokenScaling() noexcept + { + return QuantMode(BaseType(1u) << 4); + } + + static constexpr QuantMode perGroupScaling() noexcept + { + return QuantMode(BaseType(1u) << 5); + } + + static constexpr QuantMode int8KvCache() noexcept + { + return QuantMode(BaseType(1u) << 6); + } + + static constexpr QuantMode fp8KvCache() noexcept + { + return QuantMode(BaseType(1u) << 7); + } + + static constexpr QuantMode fp8Qdq() noexcept + { + return QuantMode(BaseType(1u) << 8); + } + + static constexpr QuantMode fp8RowWise() noexcept + { + return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9); + } + + static constexpr QuantMode w4a8QServe() noexcept + { + return QuantMode(BaseType(1u) << 10); + } + + constexpr BaseType value() const noexcept + { + return mValue; + } + + constexpr bool isSet(QuantMode const& mode) const noexcept + { + return (mValue & mode.value()) == mode.value(); + } + + constexpr bool hasInt4Weights() const noexcept + { + return isSet(int4Weights()); + } + + constexpr bool hasInt8Weights() const noexcept + { + return isSet(int8Weights()); + } + + constexpr bool hasActivations() const noexcept + { + return isSet(activations()); + } + + constexpr bool hasPerChannelScaling() const noexcept + { + return isSet(perChannelScaling()); + } + + constexpr bool hasPerTokenScaling() const noexcept + { + return isSet(perTokenScaling()); + } + + constexpr bool hasPerGroupScaling() const noexcept + { + return isSet(perGroupScaling()); + } + + constexpr bool hasStaticActivationScaling() const noexcept + { + return !hasPerTokenScaling(); + } + + constexpr bool hasInt8KvCache() const noexcept + { + return isSet(int8KvCache()); + } + + constexpr bool hasFp8KvCache() const noexcept + { + return isSet(fp8KvCache()); + } + + constexpr bool hasFp8Qdq() const noexcept + { + return isSet(fp8Qdq()); + } + + constexpr bool hasFp8RowWise() const noexcept + { + return isSet(fp8RowWise()); + } + + constexpr bool hasKvCacheQuant() const noexcept + { + return hasInt8KvCache() || hasFp8KvCache(); + } + + static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false, + bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false, + bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false, + bool useW4a8QServe = false) + { + QuantMode quantMode{}; + if (quantizeWeights) + { + if (useInt4Weights) + quantMode += int4Weights(); + else + quantMode 
+= int8Weights(); + } + + if (quantizeActivations) + { + quantMode += activations(); + } + + if (perChannel) + { + quantMode += QuantMode::perChannelScaling(); + } + if (perToken) + { + quantMode += QuantMode::perTokenScaling(); + } + if (perGroup) + { + quantMode += QuantMode::perGroupScaling(); + } + + if (useInt8KvCache) + { + quantMode += int8KvCache(); + } + + if (useFp8KvCache) + { + quantMode += fp8KvCache(); + } + + if (useFp8Qdq) + { + quantMode += fp8Qdq(); + } + + if (useFp8RowWise) + { + quantMode += fp8RowWise(); + } + + if (useW4a8QServe) + { + quantMode += w4a8QServe(); + } + + return quantMode; + } + + static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false) + { + return fromDescription(true, true, perToken, perChannel); + } + + static constexpr QuantMode useQServe(bool perGroup) + { + return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true); + } + + static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false) + { + return fromDescription(true, false, false, false, perGroup, useInt4Weights); + } + + static QuantMode const fromQuantAlgo( + std::optional quantAlgo = std::nullopt, std::optional kvCacheQuantAlgo = std::nullopt) + { + QuantMode quantMode{}; + if (quantAlgo == "W8A16") + { + quantMode = useWeightOnly(false, false); + } + else if (quantAlgo == "W4A16") + { + quantMode = useWeightOnly(true, false); + } + else if (quantAlgo == "W4A16_AWQ") + { + quantMode = useWeightOnly(true, true); + } + else if (quantAlgo == "W4A8_AWQ") + { + quantMode = useWeightOnly(true, true); + } + else if (quantAlgo == "W4A8_QSERVE_PER_GROUP") + { + quantMode = useQServe(false); + } + else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL") + { + quantMode = useQServe(true); + } + else if (quantAlgo == "W4A16_GPTQ") + { + quantMode = useWeightOnly(true, true); + } + else if (quantAlgo == "W8A8_SQ_PER_CHANNEL") + { + quantMode = useSmoothQuant(false, true); + } + else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN") + { + quantMode = useSmoothQuant(false, false); + } + else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN") + { + quantMode = useSmoothQuant(true, true); + } + else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN") + { + quantMode = useSmoothQuant(false, true); + } + else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN") + { + quantMode = useSmoothQuant(true, false); + } + else if (quantAlgo == "FP8") + { + quantMode = fromDescription(false, false, false, false, false, false, false, false, true); + } + else if (quantAlgo == "FP8_ROWWISE") + { + quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true); + } + + if (kvCacheQuantAlgo == "INT8") + { + quantMode += int8KvCache(); + } + else if (kvCacheQuantAlgo == "FP8") + { + quantMode += fp8KvCache(); + } + + return quantMode; + } + + constexpr QuantMode operator+(QuantMode const& other) const noexcept + { + return QuantMode(mValue | other.mValue); + } + + constexpr QuantMode& operator+=(QuantMode const& other) noexcept + { + return *this = *this + other; + } + + constexpr QuantMode operator-(QuantMode const& other) const noexcept + { + return QuantMode(mValue & ~other.mValue); + } + + constexpr QuantMode& operator-=(QuantMode const& other) noexcept + { + return *this = *this - other; + } + + constexpr bool operator==(QuantMode const& other) const noexcept + { + return mValue == other.mValue; + } + + constexpr bool operator!=(QuantMode const& other) const noexcept + { + return 
!(*this == other); + } + +private: + BaseType mValue{0}; +}; + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h deleted file mode 100644 index 9cda9fa0d4..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -namespace tensorrt_llm::common::stl_utils -{ - -template -constexpr TOutputIt basicInclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, TBinOp op) -{ - if (first != last) - { - auto val = *first; - while (true) - { - *dFirst = val; - ++dFirst; - ++first; - if (first == last) - { - break; - } - val = op(std::move(val), *first); - } - } - return dFirst; -} - -template -constexpr TOutputIt inclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst) -{ -#if defined(__GNUC__) && __GNUC__ <= 8 - return basicInclusiveScan(first, last, dFirst, std::plus<>{}); -#else - return std::inclusive_scan(first, last, dFirst); -#endif -} - -template -constexpr TOutputIt basicExclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init, TBinOp op) -{ - if (first != last) - { - while (true) - { - T tmp{op(init, *first)}; - *dFirst = init; - ++dFirst; - ++first; - if (first == last) - { - break; - } - init = std::move(tmp); - } - } - return dFirst; -} - -template -constexpr TOutputIt exclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init) -{ -#if defined(__GNUC__) && __GNUC__ <= 8 - return basicExclusiveScan(first, last, dFirst, std::move(init), std::plus<>{}); -#else - return std::exclusive_scan(first, last, dFirst, std::move(init)); -#endif -} - -template -struct HasOperatorOutput : std::false_type -{ -}; - -template -struct HasOperatorOutput() << std::declval()))>> - : std::true_type -{ -}; - -template -std::string toString(T const& t, typename std::enable_if_t::value, int> = 0) -{ - std::ostringstream oss; - oss << t; - return oss.str(); -} - -template -std::string toString(std::optional const& t, typename std::enable_if_t::value, int> = 0) -{ - std::ostringstream oss; - if (t) - { - oss << t.value(); - } - else - { - oss << "None"; - } - return oss.str(); -} - -} // namespace tensorrt_llm::common::stl_utils diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h new file mode 100644 index 0000000000..9c5ecde98c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if ENABLE_BF16 +#include +#endif // ENABLE_BF16 +#include + +#include // std::make_unique +#include // std::stringstream +#include +#include +#include + +namespace tensorrt_llm::common +{ +#if ENABLE_BF16 +static inline std::basic_ostream& operator<<(std::basic_ostream& stream, __nv_bfloat16 const& val) +{ + stream << __bfloat162float(val); + return stream; +} +#endif // ENABLE_BF16 + +static inline std::basic_ostream& operator<<(std::basic_ostream& stream, __half const& val) +{ + stream << __half2float(val); + return stream; +} + +inline std::string fmtstr(std::string const& s) +{ + return s; +} + +inline std::string fmtstr(std::string&& s) +{ + return s; +} + +#if defined(_MSC_VER) +std::string fmtstr(char const* format, ...); +#else +std::string fmtstr(char const* format, ...) __attribute__((format(printf, 1, 2))); +#endif + +// __PRETTY_FUNCTION__ is used for neat debugging printing but is not supported on Windows +// The alternative is __FUNCSIG__, which is similar but not identical +#if defined(_WIN32) +#define __PRETTY_FUNCTION__ __FUNCSIG__ +#endif + +auto constexpr kDefaultDelimiter = ", "; + +template +inline TStream& arr2outCasted(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter) +{ + out << "("; + if (size > 0) + { + for (size_t i = 0; i < size - 1; ++i) + { + out << static_cast(arr[i]) << delim; + } + out << static_cast(arr[size - 1]); + } + out << ")"; + return out; +} + +template +inline TStream& arr2out(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter) +{ + return arr2outCasted(out, arr, size, delim); +} + +template +inline std::string arr2str(T* arr, size_t size, char const* delim = kDefaultDelimiter) +{ + std::stringstream ss; + return arr2out(ss, arr, size, delim).str(); +} + +template +inline std::string vec2str(std::vector const& vec, char const* delim = kDefaultDelimiter) +{ + return arr2str(vec.data(), vec.size(), delim); +} + +inline bool strStartsWith(std::string const& str, std::string const& prefix) +{ + return str.rfind(prefix, 0) == 0; +} + +/// @brief Split a string into a set of strings using a delimiter +std::unordered_set str2set(std::string const& input, char delimiter); + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp deleted file mode 100644 index c00041abda..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include "tensorrt_llm/common/timestampUtils.h" - -namespace tensorrt_llm::common -{ - -std::string getCurrentTimestamp() -{ - auto now = std::chrono::system_clock::now(); - auto now_t = std::chrono::system_clock::to_time_t(now); - auto tm = *std::localtime(&now_t); - - auto epoch_to_now = now.time_since_epoch(); - auto seconds = std::chrono::duration_cast(epoch_to_now); - auto us = std::chrono::duration_cast(epoch_to_now - seconds); - - std::ostringstream stream; - stream << std::put_time(&tm, "%m-%d-%Y %H:%M:%S"); - stream << "." << std::setfill('0') << std::setw(6) << us.count(); - return stream.str(); -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h similarity index 50% rename from sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h rename to sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h index f52f23028c..47e0e63d3f 100644 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h @@ -14,12 +14,35 @@ * limitations under the License. */ +#pragma once + +#include +#include +#include #include +#define NEW_TLLM_EXCEPTION(...) \ + tensorrt_llm::common::TllmException(__FILE__, __LINE__, tensorrt_llm::common::fmtstr(__VA_ARGS__)) + namespace tensorrt_llm::common { -/// @brief Get the current timestamp in the format "MM-DD-YYYY HH:MM:SS:uuuuuu" -std::string getCurrentTimestamp(); +class TllmException : public std::runtime_error +{ +public: + static auto constexpr MAX_FRAMES = 128; + + explicit TllmException(char const* file, std::size_t line, std::string const& msg); + + ~TllmException() noexcept override; + + [[nodiscard]] std::string getTrace() const; + + static std::string demangle(char const* name); + +private: + std::array mCallstack{}; + int mNbFrames; +}; } // namespace tensorrt_llm::common From cde4bbd5cca252589d9d9bf2b5f3b1c0ad48355a Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Thu, 30 Jan 2025 18:28:22 -0800 Subject: [PATCH 07/52] docs: add Novita for adoption and sponsorship (#3227) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e4c5f12f39..b27271a181 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487) ## Adoption and Sponsorship -The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. +The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. ## Acknowledgment and Citation We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). 
Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful. From 9829e77e3fe93a86294d70722eff7bff4cdf4540 Mon Sep 17 00:00:00 2001 From: Ravi Theja Date: Fri, 31 Jan 2025 13:31:46 +0530 Subject: [PATCH 08/52] Docs: Update supported models with Mistral 3 (#3229) Co-authored-by: Ravi Theja Desetty --- docs/references/supported_models.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 93c4273765..85de12f9f4 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -2,7 +2,7 @@ ## Generative Models - Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 -- Mistral / Mixtral / Mistral NeMo +- Mistral / Mixtral / Mistral NeMo / Mistral Small 3 - Gemma / Gemma 2 - Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL - DeepSeek / DeepSeek 2 / [DeepSeek 3](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3) From 3ee62235c612a0f5f1b7a8eee272961869574efc Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 31 Jan 2025 16:51:41 +0800 Subject: [PATCH 09/52] revert the MoE dependence (#3230) --- .../3rdparty/tensorrt_llm/common/assert.cpp | 34 - .../3rdparty/tensorrt_llm/common/assert.h | 92 -- .../tensorrt_llm/common/cublasMMWrapper.cpp | 360 -------- .../tensorrt_llm/common/cublasMMWrapper.h | 148 ---- .../tensorrt_llm/common/cublasVersionCheck.h | 35 - .../tensorrt_llm/common/cudaBf16Fallbacks.cuh | 313 ------- .../tensorrt_llm/common/cudaBf16Wrapper.h | 21 - .../tensorrt_llm/common/cudaDriverWrapper.cpp | 187 ---- .../tensorrt_llm/common/cudaDriverWrapper.h | 138 --- .../tensorrt_llm/common/cudaFp8Utils.cu | 436 ---------- .../tensorrt_llm/common/cudaFp8Utils.h | 239 ----- .../tensorrt_llm/common/cudaTypeUtils.cuh | 752 ---------------- .../3rdparty/tensorrt_llm/common/cudaUtils.h | 641 -------------- .../3rdparty/tensorrt_llm/common/logger.cpp | 70 -- .../3rdparty/tensorrt_llm/common/logger.h | 190 ---- .../tensorrt_llm/common/quantTypeUtils.cuh | 55 -- .../tensorrt_llm/common/quantization.h | 358 -------- .../tensorrt_llm/common/reduceKernelUtils.cuh | 399 --------- .../tensorrt_llm/common/stringUtils.cpp | 76 -- .../tensorrt_llm/common/stringUtils.h | 113 --- .../tensorrt_llm/common/tllmException.cpp | 105 --- .../tensorrt_llm/common/tllmException.h | 48 - .../3rdparty/tensorrt_llm/common/workspace.h | 87 -- .../arch/copy_red_global.hpp | 352 -------- .../include/cutlass_extensions/arch/mma.h | 120 --- .../cutlass_extensions/compute_occupancy.h | 88 -- .../collective/epilogue_moe_finalize.hpp | 550 ------------ .../epilogue/thread/fused_activations.h | 105 --- .../epilogue_per_row_per_col_scale.h | 352 -------- .../threadblock/epilogue_tensor_op_int32.h | 282 ------ .../cutlass_extensions/epilogue_helpers.h | 141 --- .../builders/sm90_gmma_builder_gated.inl | 221 ----- .../collective/collective_builder_gated.hpp | 58 -- .../gemm/collective/collective_mma_gated.hpp | 59 -- ..._mma_gated_tma_gmma_ss_warpspecialized.hpp | 642 -------------- ..._gated_tma_gmma_ss_warpspecialized_fp8.hpp | 665 -------------- .../gemm/device/gemm_universal_base_compat.h | 438 ---------- .../gemm/device/splitk_gemm_grouped.h | 542 ------------ .../gemm/kernel/default_fpA_intB_traits.h | 162 ---- .../gemm/kernel/default_int8_traits.h | 57 -- .../gemm/kernel/default_splitk_gemm_grouped.h | 207 ----- .../gemm/kernel/fpA_intB_gemm.h | 566 ------------ .../gemm/kernel/fused_moe_kernel.cuh | 218 
----- .../gemm/kernel/fused_moe_kernel_routine.cuh | 799 ----------------- .../gemm/kernel/fused_moe_kernel_traits.cuh | 215 ----- .../gemm/kernel/gemm_moe_problem_visitor.h | 73 -- .../gemm/kernel/gemm_universal_gated.hpp | 70 -- .../gemm/kernel/gemm_with_epilogue_visitor.h | 585 ------------- .../gemm/kernel/mixed_gemm_B_layout.h | 143 --- .../gemm/kernel/moe_cute_util.cuh | 185 ---- .../gemm/kernel/moe_cutlass_kernel.h | 553 ------------ .../gemm/kernel/moe_problem_visitor.h | 344 -------- ..._gated_tma_warpspecialized_cooperative.hpp | 646 -------------- ...emm_gated_tma_warpspecialized_pingpong.hpp | 621 ------------- .../gemm/kernel/splitk_gemm_grouped.h | 494 ----------- .../gemm/threadblock/default_dq_mma.h | 125 --- .../threadblock/default_dq_mma_multistage.h | 302 ------- .../threadblock/default_dq_mma_pipelined.h | 284 ------ .../gemm/threadblock/default_mma.h | 351 -------- .../gemm/threadblock/default_mma_bf16.h | 353 -------- .../gemm/threadblock/dq_mma_base.h | 257 ------ .../gemm/threadblock/dq_mma_multistage.h | 110 --- .../dq_mma_multistage_finegrained.h | 708 --------------- .../threadblock/dq_mma_multistage_percol.h | 647 -------------- .../gemm/threadblock/dq_mma_pipelined.h | 106 --- .../dq_mma_pipelined_finegrained.h | 486 ----------- .../threadblock/dq_mma_pipelined_percol.h | 399 --------- .../gemm/warp/default_mma_tensor_op.h | 107 --- .../warp/mma_tensorop_compute_B_with_f16.h | 306 ------- .../gemm/warp/mma_tensorop_dequantizer.h | 463 ---------- .../include/cutlass_extensions/gemm_configs.h | 224 ----- .../interleaved_numeric_conversion.h | 447 ---------- .../tile_interleaved_layout.h | 66 -- .../fine_grained_scale_zero_iterator.h | 250 ------ .../cutlass_extensions/util/gather_tensor.hpp | 181 ---- .../cutlass_extensions/weight_only_quant_op.h | 58 -- .../launchers/fused_moe_gemm_launcher_sm80.h | 25 - .../fused_moe_gemm_launcher_sm80.inl | 96 -- .../launchers/moe_gemm_launcher_sm90.h | 37 - .../launchers/moe_gemm_launcher_sm90.inl | 348 -------- .../moe_gemm/moe_gemm_hopper_input.cu | 131 --- .../moe_gemm/moe_gemm_kernels.h | 230 ----- .../moe_gemm/moe_gemm_kernels_bf16_bf16.cu | 24 - .../moe_gemm/moe_gemm_kernels_bf16_uint4.cu | 24 - .../moe_gemm/moe_gemm_kernels_bf16_uint8.cu | 24 - .../moe_gemm/moe_gemm_kernels_fp16_fp16.cu | 22 - .../moe_gemm/moe_gemm_kernels_fp16_uint4.cu | 22 - .../moe_gemm/moe_gemm_kernels_fp16_uint8.cu | 22 - .../moe_gemm/moe_gemm_kernels_fp32_fp32.cu | 22 - .../moe_gemm/moe_gemm_kernels_fp8_fp8.cu | 28 - .../moe_gemm/moe_gemm_kernels_template.h | 823 ------------------ .../moe_gemm/moe_gemm_kernels_template_sm90.h | 222 ----- .../moe_gemm/moe_sm90_traits.h | 44 - sgl-kernel/setup.py | 4 - 94 files changed, 23828 deletions(-) delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/assert.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu delete mode 100644 
sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/logger.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h 
delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h delete mode 100644 
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp deleted file mode 100644 index eaaf662447..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/assert.h" - -namespace -{ - -bool initCheckDebug() -{ - auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE"; - auto const debugEnabled = std::getenv(kDebugEnabled); - return debugEnabled && debugEnabled[0] == '1'; -} -} // namespace - -bool DebugConfig::isCheckDebugEnabled() -{ - static bool const debugEnabled = initCheckDebug(); - return debugEnabled; -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h deleted file mode 100644 index 7f51dbf1b4..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "tensorrt_llm/common/stringUtils.h" -#include "tensorrt_llm/common/tllmException.h" - -#include - -namespace tensorrt_llm::common -{ -[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "") -{ - throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str())); -} - -} // namespace tensorrt_llm::common - -class DebugConfig -{ -public: - static bool isCheckDebugEnabled(); -}; - -#if defined(_WIN32) -#define TLLM_LIKELY(x) (__assume((x) == 1), (x)) -#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x)) -#else -#define TLLM_LIKELY(x) __builtin_expect((x), 1) -#define TLLM_UNLIKELY(x) __builtin_expect((x), 0) -#endif - -#define TLLM_CHECK(val) \ - do \ - { \ - TLLM_LIKELY(static_cast(val)) ? ((void) 0) \ - : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \ - } while (0) - -#define TLLM_CHECK_WITH_INFO(val, info, ...) \ - do \ - { \ - TLLM_LIKELY(static_cast(val)) \ - ? ((void) 0) \ - : tensorrt_llm::common::throwRuntimeError( \ - __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \ - } while (0) - -#define TLLM_CHECK_DEBUG(val) \ - do \ - { \ - if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \ - { \ - TLLM_LIKELY(static_cast(val)) ? ((void) 0) \ - : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \ - } \ - } while (0) - -#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \ - do \ - { \ - if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \ - { \ - TLLM_LIKELY(static_cast(val)) \ - ? 
((void) 0) \ - : tensorrt_llm::common::throwRuntimeError( \ - __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \ - } \ - } while (0) - -#define TLLM_THROW(...) \ - do \ - { \ - throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \ - } while (0) - -#define TLLM_WRAP(ex) \ - NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what()) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp deleted file mode 100644 index 351257f4d2..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp +++ /dev/null @@ -1,360 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/cublasMMWrapper.h" -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cublasVersionCheck.h" -#include - -#ifndef CUDART_VERSION -#error CUDART_VERSION Undefined! -#endif - -namespace tensorrt_llm -{ -namespace common -{ - -CublasMMWrapper::CublasMMWrapper(std::shared_ptr cublasHandle, - std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace) - : mCublasHandle(cublasHandle) - , mCublasLtHandle(cublasltHandle) - , mStream(stream) - , mCublasWorkspace(workspace) -{ -} - -CublasMMWrapper::~CublasMMWrapper() {} - -CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper) - : mCublasHandle(wrapper.mCublasHandle) - , mCublasLtHandle(wrapper.mCublasLtHandle) - , mStream(wrapper.mStream) -{ -} - -void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, - int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc) -{ - // -------------------------------------- - // Create descriptors for the original matrices - check_cuda_error( - cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda)); - check_cuda_error( - cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? 
n : k, ldb)); - check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc)); - check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType)); - check_cuda_error(cublasLtMatmulDescSetAttribute( - mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t))); - check_cuda_error(cublasLtMatmulDescSetAttribute( - mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t))); - check_cuda_error( - cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t))); -} - -void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b) -{ - check_cuda_error( - cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*))); - check_cuda_error( - cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*))); -} - -void CublasMMWrapper::destroyDescriptors() -{ - check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc)); - check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc)); - check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc)); - check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc)); - mOperationDesc = NULL; - mADesc = NULL; - mBDesc = NULL; - mCDesc = NULL; -} - -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc) -{ - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f); -} - -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, - std::optional const& heuristic) -{ - if (heuristic) - { - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /* hasAlgo */ (*heuristic).algo, - (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE, - /* usingCublasLt */ true); - } - else - { - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false, - /* usingCublasLt */ true); - } -} - -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, - std::optional const& heuristic) -{ - if (heuristic) - { - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, /* hasAlgo */ (*heuristic).algo, - (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE, - /* usingCublasLt */ true); - } - else - { - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false, - /* usingCublasLt */ true); - } -} - -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta) -{ - bool usingCublasLt = mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3; - - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false, - /* usingCublasLt */ usingCublasLt); -} - -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, void const* B, int const ldb, void* C, int const 
ldc, float f_alpha, float f_beta, - cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt) -{ - half h_alpha = (half) (f_alpha); - half h_beta = (half) (f_beta); - - // TODO: default cublas libs - usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3); - bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F; - int batch_count = 1; - // fp32 use cublas as default - // fp16 use cublasLt as default - void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); - int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; - - if (usingCublasLt) - { - if (hasAlgo) - { - hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo); - } - - check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C, - mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream)); - - sync_check_cuda_error(); - } - else - { - check_cuda_error(cublasSetStream(getCublasHandle(), mStream)); - check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize)); - // Go with default heuristic to choose tactic as cuBLAS does not allow to choose tactics in Ampere+ - cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT; - check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb, - beta, C, mCType, ldc, mComputeType, static_cast(cublasAlgo))); - sync_check_cuda_error(); - } -} - -void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, - int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, - const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha, - float const f_beta) -{ - half h_alpha = (half) f_alpha; - half h_beta = (half) f_beta; - - int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; - void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); - - check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, - strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType, - mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, - int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, - void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, - cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType) -{ - half h_alpha = (half) f_alpha; - half h_beta = (half) f_beta; - - bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; - void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - void const* beta = isFp16ComputeType ? 
reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); - - check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda, - strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType, - mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -void CublasMMWrapper::setWorkspace(void* workspace) -{ - mCublasWorkspace = workspace; -} - -void CublasMMWrapper::setFP32GemmConfig() -{ - setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F); -} - -void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType) -{ - setGemmConfig(CUDA_R_16F, CUDA_R_16F, outputType, CUDA_R_32F); -} - -#ifdef ENABLE_BF16 -void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType) -{ - setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, outputType, CUDA_R_32F); -} -#endif - -#ifdef ENABLE_FP8 -void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType) -{ - setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F); -} -#endif - -void CublasMMWrapper::setGemmConfig( - cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType) -{ - mAType = aType; - mBType = bType; - mCType = cType; - bool isFp16ComputeType = computeType == CUDA_R_16F; - if (isFp16ComputeType) - { - mComputeType = CUBLAS_COMPUTE_16F; - mScaleType = CUDA_R_16F; - } - else - { - mComputeType = CUBLAS_COMPUTE_32F; - mScaleType = CUDA_R_32F; - } -} - -CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type) -{ - if (data_type == CUDA_R_16F) - { - return HALF_DATATYPE; - } - else if (data_type == CUDA_R_32F) - { - return FLOAT_DATATYPE; - } - else if (data_type == CUDA_R_8I) - { - return INT8_DATATYPE; - } -#ifdef ENABLE_BF16 - else if (data_type == CUDA_R_16BF) - { - return BFLOAT16_DATATYPE; - } -#endif - return FLOAT_DATATYPE; -} - -void CublasMMWrapper::setStream(cudaStream_t stream) -{ - mStream = stream; -} - -bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, - int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo) -{ - TLLM_CHECK_WITH_INFO( - descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function"); - - int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; - - cublasLtMatmulHeuristicResult_t heurResult; - cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( - getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult); - - if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS - || heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE) - { - return false; - } - - sync_check_cuda_error(); - - return true; -} - -std::vector CublasMMWrapper::getTactics(cublasOperation_t transa, - cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc) -{ - TLLM_CHECK_WITH_INFO( - descriptorsCreated(), "Descriptors are not created! 
Call createDescriptors before calling this function"); - - auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc); - - sync_check_cuda_error(); - - return heuristics; -} - -std::vector CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle, - cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, - cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc) -{ -#if TLLM_CUBLAS_VER_LE(11, 4, 2) - TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2."); - return {}; -#else - std::vector heuristics(200); - cublasLtMatmulPreference_t preference; - check_cuda_error(cublasLtMatmulPreferenceCreate(&preference)); - check_cuda_error(cublasLtMatmulPreferenceInit(preference)); - uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE; - check_cuda_error(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); - // Restrict reduction algorithms for numerical stability and better determinism - uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK; - check_cuda_error(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask))); -#if TLLM_CUBLAS_VER_LT(12, 0, 0) - uint32_t pointer_mode_mask = 0; - check_cuda_error(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask))); -#endif - - int return_count = 0; - check_cuda_error(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, - heuristics.size(), heuristics.data(), &return_count)); - heuristics.resize(return_count); - - return heuristics; -#endif -} - -} // namespace common - -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h deleted file mode 100644 index 79b7c92a47..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "tensorrt_llm/common/cudaUtils.h" -#include -#include -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ - -class CublasMMWrapper -{ -protected: - std::shared_ptr mCublasHandle; - std::shared_ptr mCublasLtHandle; - - cudaDataType_t mAType{}; - cudaDataType_t mBType{}; - cudaDataType_t mCType{}; - cublasComputeType_t mComputeType{}; - cudaDataType_t mScaleType{}; - - cublasLtMatmulDesc_t mOperationDesc{NULL}; - cublasLtMatrixLayout_t mADesc{NULL}; - cublasLtMatrixLayout_t mBDesc{NULL}; - cublasLtMatrixLayout_t mCDesc{NULL}; - - cudaStream_t mStream; - - void* mCublasWorkspace = nullptr; - -private: - bool descriptorsCreated() const - { - return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL; - } - -public: - CublasMMWrapper(std::shared_ptr cublasHandle, std::shared_ptr cublasLtHandle, - cudaStream_t stream, void* workspace); - - ~CublasMMWrapper(); - - CublasMMWrapper(CublasMMWrapper const& wrapper); - - /********************** GEMMs **********************/ - void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, - int const lda, void const* B, int const ldb, void* C, int const ldc); - - void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, - int const lda, void const* B, int const ldb, void* C, int const ldc, - std::optional const& algo); - - void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, - int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, - std::optional const& algo); - - void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, - int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta); - - void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, - int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, - cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt); - - void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB, - void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f, - float const f_beta = 0.0f); - - void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B, - cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType, - int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType); - - /********************** Tactic selection helpers **********************/ - bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo); - - std::vector getTactics(cublasOperation_t transa, cublasOperation_t transb, - int const m, int const n, int const k, int const lda, int const ldb, int const ldc); - - std::vector getTactics(cublasLtHandle_t lightHandle, - cublasLtMatmulDesc_t 
computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, - cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc); - - using MatrixLayout = std::tuple; - using cache_idx_t = std::tuple>; - - MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc); - - /********************** Utils **********************/ - void setWorkspace(void* workspace); - - void setFP32GemmConfig(); - void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F); -#ifdef ENABLE_BF16 - void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF); -#endif -#ifdef ENABLE_FP8 - void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F); -#endif - - void setStream(cudaStream_t stream); - - void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType); - - CublasDataType getCublasDataType(cudaDataType_t data_type); - - void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - int const lda, int const ldb, int const ldc, int8_t fastAcc = 0); - void setScaleDescriptors(void* scale_a, void* scale_b); - void destroyDescriptors(); - - cublasHandle_t getCublasHandle() - { - return *(this->mCublasHandle); - } - - cublasLtHandle_t getCublasLtHandle() const - { - return *(this->mCublasLtHandle); - } -}; - -} // namespace common - -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h deleted file mode 100644 index 1ee72c6356..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// We don't want to include cublas_api.h. It contains the CUBLAS_VER_* macro -// definition which is not sufficient to determine if we include cublas.h, -// cublas_v2.h or cublasLt.h. 
- -#define TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) (MAJOR * 10000 + MINOR * 100 + PATCH) -#define TLLM_CUBLAS_VER_LE(MAJOR, MINOR, PATCH) \ - TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ - <= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) -#define TLLM_CUBLAS_VER_LT(MAJOR, MINOR, PATCH) \ - TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ - < TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) -#define TLLM_CUBLAS_VER_GE(MAJOR, MINOR, PATCH) \ - TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ - >= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) -#define TLLM_CUBLAS_VER_GT(MAJOR, MINOR, PATCH) \ - TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ - > TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh deleted file mode 100644 index 0519251e6f..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh +++ /dev/null @@ -1,313 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "tensorrt_llm/common/cudaBf16Wrapper.h" -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ - -#ifdef ENABLE_BF16 -inline __device__ float2 bf1622float2(const __nv_bfloat162 val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; -#else - return __bfloat1622float2(val); -#endif -} - -inline __device__ int16_t bf1622int16(__nv_bfloat162 val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = max(min(__low2float(val), 127.f), -128.f); - f_val.y = max(min(__high2float(val), 127.f), -128.f); - - union - { - int8_t int8[2]; - int16_t int16; - }; - - int8[0] = static_cast(static_cast(f_val.x)); - int8[1] = static_cast(static_cast(f_val.y)); - return int16; -#else - val = __hmin2(val, make_bfloat162(127., 127.)); - val = __hmax2(val, make_bfloat162(-128., -128.)); - - union - { - int8_t int8[2]; - int16_t int16; - }; - - int8[0] = static_cast(static_cast(val.x)); - int8[1] = static_cast(static_cast(val.y)); - return int16; -#endif -} - -inline __device__ __nv_bfloat162 float22bf162(const float2 val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __floats2bfloat162_rn(val.x, val.y); -#else - return __float22bfloat162_rn(val); -#endif -} - -inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __nv_bfloat162 val2; - val2.x = val; - val2.y = val; - return val2; -#else - return __bfloat162bfloat162(val); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = 
__low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); -#else - return __hadd2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y)); -#else - return __hadd(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); -#else - return __hsub2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y)); -#else - return __hsub(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); -#else - return __hmul2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y)); -#else - return __hmul(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh, fzl, fzh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - fzl = __low2float(z); - fzh = __high2float(z); - return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); -#else - return __hfma2(x, y, z); -#endif -} - -inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); -#else - return __hfma(x, y, z); -#endif -} - -inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x); - ; - return __floats2bfloat162_rn(expf(fxl), expf(fxh)); -#else - return h2exp(x); -#endif -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020) - -inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ - __nv_bfloat162 t; - t.x = x; - t.y = y; - return t; -} -#endif -#endif - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + 
__bfloat162float(c) + __bfloat162float(d)); -#else - return (__nv_bfloat16) ((float) a + (float) b + (float) c + (float) d); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - fdl = __low2float(d); - fdh = __high2float(d); - return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); -#else - return a * b * c + d; -#endif -} - -#endif // ENABLE_BF16 - -} // namespace common -} // namespace tensorrt_llm - -// Operator definitions intentionally in global namespace -namespace -{ -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020) - -inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) -{ - return tensorrt_llm::common::bf16hmul2(x, y); -}; - -inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) -{ - return tensorrt_llm::common::bf16hadd2(x, y); -}; -#endif -#endif -} // namespace diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h deleted file mode 100644 index fb2a89af5c..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#ifdef ENABLE_BF16 -#include -#endif diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp deleted file mode 100644 index 7eca46a1ca..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#define CUDA_LIB_NAME "cuda" - -#if defined(_WIN32) -#include -#define dllOpen(name) LoadLibrary("nv" name ".dll") -#define dllClose(handle) FreeLibrary(static_cast(handle)) -#define dllGetSym(handle, name) static_cast(GetProcAddress(static_cast(handle), name)) -#else // For non-Windows platforms -#include -#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY) -#define dllClose(handle) dlclose(handle) -#define dllGetSym(handle, name) dlsym(handle, name) -#endif // defined(_WIN32) - -#include "cudaDriverWrapper.h" -#include "tensorrt_llm/common/assert.h" -#include -#include - -namespace tensorrt_llm::common -{ - -std::shared_ptr CUDADriverWrapper::getInstance() -{ - static std::mutex mutex; - static std::weak_ptr instance; - std::shared_ptr result = instance.lock(); - if (result) - { - return result; - } - - std::lock_guard lock(mutex); - result = instance.lock(); - if (!result) - { - result = std::shared_ptr(new CUDADriverWrapper()); - instance = result; - } - return result; -} - -CUDADriverWrapper::CUDADriverWrapper() - : handle(dllOpen(CUDA_LIB_NAME)) -{ - - TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); - - auto load_sym = [](void* handle, char const* name) - { - void* ret = dllGetSym(handle, name); - return ret; - }; - - *reinterpret_cast(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName"); - *reinterpret_cast(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage"); - *reinterpret_cast(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute"); - *reinterpret_cast(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete"); - *reinterpret_cast(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload"); - *reinterpret_cast(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy"); - *reinterpret_cast(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData"); - *reinterpret_cast(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2"); - *reinterpret_cast(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction"); - *reinterpret_cast(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2"); - *reinterpret_cast(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2"); - *reinterpret_cast(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2"); - *reinterpret_cast(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel"); - *reinterpret_cast(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel"); - *reinterpret_cast(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled"); - *reinterpret_cast(&_cuMemcpyDtoH) = 
load_sym(handle, "cuMemcpyDtoH_v2"); -} - -CUDADriverWrapper::~CUDADriverWrapper() -{ - dllClose(handle); -} - -CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const -{ - return (*_cuGetErrorName)(error, pStr); -} - -CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const -{ - return (*_cuGetErrorMessage)(error, pStr); -} - -CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const -{ - return (*_cuFuncSetAttribute)(hfunc, attrib, value); -} - -CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const -{ - return (*_cuLinkComplete)(state, cubinOut, sizeOut); -} - -CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const -{ - return (*_cuModuleUnload)(hmod); -} - -CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const -{ - return (*_cuLinkDestroy)(state); -} - -CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const -{ - return (*_cuModuleLoadData)(module, image); -} - -CUresult CUDADriverWrapper::cuLinkCreate( - unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const -{ - return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); -} - -CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const -{ - return (*_cuModuleGetFunction)(hfunc, hmod, name); -} - -CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const -{ - return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); -} - -CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, - unsigned int numOptions, CUjit_option* options, void** optionValues) const -{ - return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); -} - -CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, - char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const -{ - return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); -} - -CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const -{ - return (*_cuLaunchCooperativeKernel)( - f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams); -} - -CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const -{ - return (*_cuLaunchKernel)( - f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra); -} - -CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, - cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const -{ - 
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides, - boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill); -} - -CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const -{ - return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount); -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h deleted file mode 100644 index c4d470a85f..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef CUDA_DRIVER_WRAPPER_H -#define CUDA_DRIVER_WRAPPER_H - -#include "tensorrt_llm/common/assert.h" -#include -#include -#include -#include - -namespace tensorrt_llm::common -{ - -class CUDADriverWrapper -{ -public: - static std::shared_ptr getInstance(); - - ~CUDADriverWrapper(); - CUDADriverWrapper(CUDADriverWrapper const&) = delete; - CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete; - CUDADriverWrapper(CUDADriverWrapper&&) = delete; - CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete; - - CUresult cuGetErrorName(CUresult error, char const** pStr) const; - - CUresult cuGetErrorMessage(CUresult error, char const** pStr) const; - - CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; - - CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const; - - CUresult cuModuleUnload(CUmodule hmod) const; - - CUresult cuLinkDestroy(CUlinkState state) const; - - CUresult cuModuleLoadData(CUmodule* module, void const* image) const; - - CUresult cuLinkCreate( - unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; - - CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const; - - CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const; - - CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, - CUjit_option* options, void** optionValues) const; - - CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name, - unsigned int numOptions, CUjit_option* options, void** optionValues) const; - - CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const; - - CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, - unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, - CUstream hStream, void** 
kernelParams, void** extra) const; - - CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, - void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim, - cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, - CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const; - - CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const; - -private: - void* handle; - CUDADriverWrapper(); - - CUresult (*_cuGetErrorName)(CUresult, char const**); - CUresult (*_cuGetErrorMessage)(CUresult, char const**); - CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); - CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); - CUresult (*_cuModuleUnload)(CUmodule); - CUresult (*_cuLinkDestroy)(CUlinkState); - CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); - CUresult (*_cuModuleLoadData)(CUmodule*, void const*); - CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*); - CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*); - CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**); - CUresult (*_cuLinkAddData)( - CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**); - CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, - unsigned int, unsigned int, unsigned int, CUstream, void**); - CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, - unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, - CUstream hStream, void** kernelParams, void** extra); - CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, - cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); - CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount); -}; - -template -void checkDriver( - T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line) -{ - if (result) - { - char const* errorName = nullptr; - char const* errorMsg = nullptr; - wrap.cuGetErrorName(result, &errorName); - wrap.cuGetErrorMessage(result, &errorMsg); - throw TllmException( - file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg)); - } -} - -} // namespace tensorrt_llm::common - -/* - * Macros compliant with TensorRT coding conventions - */ -#define TLLM_CU_CHECK(stat) \ - do \ - { \ - tensorrt_llm::common::checkDriver( \ - (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \ - } while (0) - -#endif // CUDA_DRIVER_WRAPPER_H diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu deleted file mode 100644 index 8e140609f2..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu +++ /dev/null @@ -1,436 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/common/reduceKernelUtils.cuh" -#include -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ -#ifdef ENABLE_FP8 - -constexpr int CTA_SIZE = 256; - -template -__inline__ __device__ float scale(float a, float b) -{ - return QUANTIZE ? a / b : a * b; -} - -template -__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda) -{ - for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x) - { - - if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL) - { - output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[i % lda]))); - } - else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) - { - output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[i / lda]))); - } - else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR) - { - output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[0]))); - } - } -} - -template -void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream) -{ - dim3 grid(1024); - dim3 block(CTA_SIZE); - if (quantize_mode == QuantizeMode::PER_CHANNEL) - { - scaleMatrix - <<>>(output, input_scale, input, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_TOKEN) - { - scaleMatrix<<>>(output, input_scale, input, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_TENSOR) - { - scaleMatrix<<>>(output, input_scale, input, numel, lda); - } - sync_check_cuda_error(); -} - -template -void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream) -{ - dim3 grid(1024); - dim3 block(CTA_SIZE); - if (quantize_mode == QuantizeMode::PER_CHANNEL) - { - scaleMatrix - <<>>(output, input_scale, input, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_TOKEN) - { - scaleMatrix<<>>(output, input_scale, input, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_TENSOR) - { - scaleMatrix - <<>>(output, input_scale, input, numel, lda); - } - sync_check_cuda_error(); -} - -template -__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel) -{ - for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x) - { - T_FAKE tmp = (T_FAKE) (static_cast(src[tid])); - dst[tid] = (T_OUT) (static_cast(tmp)); - } -} - -template -void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream) -{ - fakeQuantize<<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel); - sync_check_cuda_error(); -} - -template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>( - float* dst, float const* src, const int64_t numel, cudaStream_t stream); -template void 
invokeFakeQuantize( - float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream); -template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>( - half* dst, half const* src, const int64_t numel, cudaStream_t stream); -template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>( - __nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream); - -template void invokeFakeQuantize( - half* dst, float const* src, const int64_t numel, cudaStream_t stream); - -__device__ float atomicMaxExtd(float* address, float val) -{ - assert(val >= 0); - unsigned int* address_as_u = reinterpret_cast(address); - unsigned int old = atomicMax(address_as_u, __float_as_uint(val)); - return __uint_as_float(old); -} - -template -inline __device__ T atomicMaxExtdV2(T* address, T val) -{ -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - static_assert(std::is_same_v | std::is_same_v, "T needs to be either half or bfloat16"); - // The address in 64 bits. - uint64_t address_u64 = reinterpret_cast(address); - - // Pack the input value into 32 bits. - union - { - T v[2]; - uint16_t u[2]; - } old, tmp = {}; - - int const loc = (address_u64 & 0x2) >> 1; - tmp.v[loc] = val; - - // 4B aligned pointer. - auto aligned_address = reinterpret_cast(address_u64 & ~0x3ull); - - if constexpr (std::is_same_v) - { - asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};" - : "=h"(old.u[0]), "=h"(old.u[1]) - : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1])); - } - if constexpr (std::is_same_v) - { - asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};" - : "=h"(old.u[0]), "=h"(old.u[1]) - : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1])); - } - - // Return the correct half. - return old.v[loc]; -#endif -} - -__device__ half atomicMaxExtd(half* address, half val) -{ - unsigned short int* address_as_u = reinterpret_cast(address); - unsigned short int old = *address_as_u, assumed; - - while (val > __ushort_as_half(old)) - { - assumed = old; - old = atomicCAS(address_as_u, assumed, __half_as_ushort(val)); - } - - return __ushort_as_half(old); -} - -__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val) -{ -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - unsigned short int* address_as_u = reinterpret_cast(address); - unsigned short int old = *address_as_u, assumed; - - while (val > __ushort_as_bfloat16(old)) - { - assumed = old; - old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val)); - } - - return __ushort_as_bfloat16(old); -#else - assert(0); - asm volatile("brkpt;\n" ::); - return __nv_bfloat16(0); -#endif -} - -template -__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n) -{ - constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); - if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL) - { - for (int64_t col = threadIdx.x; col < n; col += blockDim.x) - { - float max = 0.f; - for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n) - { - auto val = fabs(static_cast(weights[i])); - max = max > val ? 
max : val; - } - auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - if constexpr (std::is_same_v) - { - atomicMaxExtd(quant_ptr + col, scale); - } - else - { - auto const address_u64 = reinterpret_cast(quant_ptr + col); - if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0)) - atomicMaxExtd(quant_ptr + col, scale); - else - atomicMaxExtdV2(quant_ptr + col, scale); - } -#else // Vector atomics require __CUDA_ARCH__ >= 900 - atomicMaxExtd(quant_ptr + col, scale); -#endif - } - } - else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) - { - auto const nrows = size / n; - for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) - { - float max = 0.f; - for (int64_t i = threadIdx.x; i < n; i += blockDim.x) - { - auto val = fabs(static_cast(weights[row * n + i])); - max = max > val ? max : val; - } - max = blockReduceMax(max); - if (threadIdx.x == 0) - { - auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); - quant_ptr[row] = scale; - } - } - } - else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR) - { - float max = 0.f; - for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x) - { - auto val = fabs(static_cast(weights[i])); - max = max > val ? max : val; - } - max = blockReduceMax(max); - if (threadIdx.x == 0) - { - auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); - atomicMaxExtd(quant_ptr, scale); - } - } -} - -template -void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream) -{ - if (quantize_mode == QuantizeMode::PER_TOKEN) - { - dim3 block(CTA_SIZE); - dim3 grid(numel / lda); - computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_CHANNEL) - { - dim3 block(CTA_SIZE); - dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); - cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream); - sync_check_cuda_error(); - computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_TENSOR) - { - dim3 block(1024); - dim3 grid(1024); - cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream); - sync_check_cuda_error(); - computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); - } - sync_check_cuda_error(); -} - -#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \ - template void invokeComputeFP8QuantizeScale(type_scale * input_scale, type_in const* weights, \ - int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream); - -DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half); -DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half); -DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float); -#ifdef ENABLE_BF16 -DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16); -DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16); -#endif - -template -__global__ void dynamicQuantizeMatrixPerToken( - T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda) -{ - extern __shared__ __align__(sizeof(float)) char _shmem[]; - T_IN* shmem = reinterpret_cast(_shmem); - constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); - auto const nrows = numel / lda; - for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) - { - float max = 0.f; - for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) - { - auto const in = input[row * lda + i]; 
- shmem[i] = in; - auto val = fabs(static_cast(in)); - max = max > val ? max : val; - } - max = blockAllReduceMax(max); // __syncthreads() called so we can read shmem - auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); - for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) - { - // true means we are quantizing - output[row * lda + i] = (T_OUT) scale(static_cast(shmem[i]), static_cast(s)); - } - if (threadIdx.x == 0) - { - quant_ptr[row] = s; - } - } -} - -template -void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel, - const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream) -{ - if (quantize_mode == QuantizeMode::PER_TOKEN) - { - dim3 grid(numel / lda); - bool use_shmem = true; - auto const shmem_size = lda * sizeof(T_IN); - if (shmem_size >= (48 << 10)) - { - cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); - use_shmem = ret == cudaSuccess; - } - if (use_shmem) - { - // ensure the threadblock is as large as possible to increase occupancy - dim3 block(std::min((lda + 31) / 32 * 32, static_cast(1024))); - dynamicQuantizeMatrixPerToken<<>>(output, quant_ptr, input, numel, lda); - } - else - { - dim3 block(CTA_SIZE); - computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); - sync_check_cuda_error(); - invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); - } - } - else if (quantize_mode == QuantizeMode::PER_CHANNEL) - { - dim3 block(CTA_SIZE); - dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); - cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream); - sync_check_cuda_error(); - computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); - sync_check_cuda_error(); - invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); - } - else if (quantize_mode == QuantizeMode::PER_TENSOR) - { - dim3 block(1024); - dim3 grid(1024); - cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream); - sync_check_cuda_error(); - computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); - sync_check_cuda_error(); - invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); - } - sync_check_cuda_error(); -} - -#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \ - template void invokeQuantizeMatrix(type_out * output, \ - type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ - cudaStream_t stream); \ - template void invokeDequantizeMatrix(type_out * output, \ - type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ - cudaStream_t stream); \ - template void invokeComputeScalesAndQuantizeMatrix(type_out * output, \ - type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ - cudaStream_t stream); - -#ifdef ENABLE_FP8 -DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float); -DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half); -DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half); -DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3); -DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3); -DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3); -#ifdef ENABLE_BF16 -DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16); -DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3); -#endif -#endif - -#endif // ENABLE_FP8 -} // namespace 
common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h deleted file mode 100644 index aa93b55a57..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef ENABLE_FP8 -#include -#include -#include - -#define FP8_MHA -#define FUSE_GEMM_ACT -#define FP8_GEMM_OUTPUT_QUANT_DISABLE - -#ifdef FUSE_GEMM_ACT -#define USE_QGMMA -#endif - -namespace tensorrt_llm -{ -namespace common -{ - -constexpr float FP8_E4M3_MAX = 448.0f; - -enum QuantizeMode -{ - PER_CHANNEL, - PER_TENSOR, - PER_CHANNEL_WEIGHT_PER_TENSOR_ACT, - PER_TOKEN, -}; - -// Packed Data Type -typedef struct __CUDA_ALIGN__(32) -{ - float array[8]; -} float8; - -typedef struct __CUDA_ALIGN__(16) -{ - half array[8]; -} half8; - -typedef struct __CUDA_ALIGN__(8) -{ - half2 array[2]; -} half2_2; - -typedef struct __CUDA_ALIGN__(8) -{ - half array[4]; -} half_4; - -#ifdef ENABLE_BF16 -typedef struct __CUDA_ALIGN__(4) -{ - __nv_bfloat16 array[2]; -} __nv_bfloat16_2; - -typedef struct __CUDA_ALIGN__(8) -{ - __nv_bfloat162 x, y; -} __nv_bfloat162_2_xy; - -typedef struct __CUDA_ALIGN__(8) -{ - __nv_bfloat16 array[4]; -} __nv_bfloat164; - -typedef struct __CUDA_ALIGN__(8) -{ - __nv_bfloat162 array[2]; -} __nv_bfloat162_2; - -typedef struct __CUDA_ALIGN__(16) -{ - __nv_bfloat16 array[8]; -} __nv_bfloat168; - -typedef struct __CUDA_ALIGN__(16) -{ - __nv_bfloat162 array[4]; -} __nv_bfloat162_4; - -typedef struct __CUDA_ALIGN__(32) -{ - __nv_bfloat16 array[16]; -} __nv_bfloat1616; -#endif - -#ifdef ENABLE_FP8 -typedef struct __CUDA_ALIGN__(2) -{ - __nv_fp8_e4m3 array[2]; -} __nv_fp8_2_e4m3; - -typedef struct __CUDA_ALIGN__(4) -{ - __nv_fp8_e4m3 array[4]; -} __nv_fp8_4_e4m3; - -typedef struct __CUDA_ALIGN__(4) -{ - __nv_fp8x2_e4m3 array[2]; -} __nv_fp8x2_x2_e4m3; - -typedef struct __CUDA_ALIGN__(8) -{ - __nv_fp8_e4m3 array[8]; -} __nv_fp8_8_e4m3; - -typedef struct __CUDA_ALIGN__(8) -{ - __nv_fp8x2_e4m3 array[4]; -} __nv_fp8x2_x4_e4m3; - -typedef struct __CUDA_ALIGN__(16) -{ - __nv_fp8_e4m3 array[16]; -} __nv_fp8x16_e4m3; -#endif - -// only BF16 and FP8 -template -struct PackType -{ - using type = float; -}; - -#ifdef ENABLE_BF16 -template <> -struct PackType<__nv_bfloat16, 2> -{ - using type = __nv_bfloat16_2; -}; - -template <> -struct PackType<__nv_bfloat16, 4> -{ - using type = __nv_bfloat164; -}; - -template <> -struct PackType<__nv_bfloat16, 8> -{ - using type = __nv_bfloat168; -}; -#endif - -#ifdef ENABLE_FP8 -template <> -struct PackType<__nv_fp8_e4m3, 2> -{ - using type = __nv_fp8_2_e4m3; -}; - -template <> -struct PackType<__nv_fp8_e4m3, 4> -{ - using type = __nv_fp8_4_e4m3; -}; - -template <> -struct PackType<__nv_fp8_e4m3, 8> -{ - using type = __nv_fp8_8_e4m3; -}; -#endif - -__inline__ __device__ void 
fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in) -{ - const char4 tmp_val = reinterpret_cast(in)[0]; - *out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); - *out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]); -} - -__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in) -{ - const char2 tmp_val = reinterpret_cast(in)[0]; - __nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); - return out; -} - -__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in) -{ - const char4 tmp_val = reinterpret_cast(in)[0]; - *out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); - *out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]); -} - -__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in) -{ - const char2 tmp_val = reinterpret_cast(in)[0]; - half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); - return out; -} - -template -void invokeQuantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream); - -template -void invokeDequantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream); - -template -void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream); - -template -void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t k, const int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream); - -template -void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* weights, const int64_t numel, - const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream); - -} // namespace common -} // namespace tensorrt_llm -#endif // ENABLE_FP8 diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh deleted file mode 100644 index a0463a3a49..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh +++ /dev/null @@ -1,752 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" -#include "tensorrt_llm/common/cudaBf16Wrapper.h" -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include -#include -#include -#if ENABLE_BF16 -#include -#endif - -namespace tensorrt_llm -{ -namespace common -{ - -template -inline __device__ T ldg(T const* val) -{ - return __ldg(val); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return val[0]; -#else - return __ldg(val); -#endif -} - -template <> -inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return val[0]; -#else - return __ldg(val); -#endif -} -#endif // ENABLE_BF16 - -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter -{ - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter -{ - using Type = half; -}; - -template <> -struct TypeConverter -{ - using Type = half2; -}; - -#if ENABLE_BF16 -template <> -struct TypeConverter<__nv_bfloat162> -{ - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> -{ - using Type = __nv_bfloat162; -}; -#endif // ENABLE_BF16 - -// Defined math operations (bfloat16 fallback to fp32 when it is not supported) -template -inline __device__ T hadd2(T a, T b) -{ - return __hadd2(a, b); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} -#endif // ENABLE_BF16 - -template -inline __device__ T add(T a, T b) -{ - return a + b; -} - -template <> -inline __device__ half2 add(half2 a, half2 b) -{ - return __hadd2(a, b); -} - -template <> -inline __device__ half add(half a, half b) -{ - return __hadd(a, b); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} - -template <> -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) -{ - return bf16hadd(a, b); -} - -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b) -{ - return bf16hadd(a, __float2bfloat16(b)); -} -#endif // ENABLE_BF16 - -// applies to all 4 values addition -template -inline __device__ T add(T a, T b, T c) -{ - return a + b + c; -} - -#if ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) -{ - return bf16hadd(a, b, c); -} - -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hadd2(a, b, c); -} -#endif // ENABLE_BF16 - -// applies to all 4 values addition -template -inline __device__ T add(T a, T b, T c, T d) -{ - return (T) ((float) a + (float) b + (float) c + (float) d); -} - -#if ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) -{ - return bf16hadd(a, b, c, d); -} -#endif // ENABLE_BF16 - -template -inline __device__ T hsub2(T a, T b) -{ - return __hsub2(a, b); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hsub2(a, b); -} -#endif // ENABLE_BF16 - -template -inline __device__ T hmul2(T a, T b) -{ - return __hmul2(a, b); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hmul2(a, b); -} -#endif // ENABLE_BF16 - -template -inline __device__ T hmul2(T a, T b, T c) -{ - 
return a * b * c; -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hmul2(a, b, c); -} -#endif // ENABLE_BF16 - -template -inline __device__ T mul(T a, T b, T c) -{ - return a * b * c; -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) -{ - return bf16hmul(a, b, c); -} - -inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hmul2(a, b, c); -} -#endif // ENABLE_BF16 - -template -inline __device__ T fma(T a, T b, T c, T d) -{ - return a * b * c + d; -} - -#if ENABLE_BF16 -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) -{ - return bf16hfma2(a, b, c, d); -} -#endif // ENABLE_BF16 - -template -inline __device__ T fma(T a, T b, T c) -{ - return a * b + c; -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(a, b, c); -} - -template <> -inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) -{ - return bf16hfma(a, b, c); -} -#endif // ENABLE_BF16 - -template -inline __device__ T hexp2(T a) -{ - return h2exp(a); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a) -{ - return bf16exp2(a); -} -#endif // ENABLE_BF16 - -template -__device__ inline T_OUT cuda_cast(T_IN val) -{ - return val; -} - -template <> -__device__ inline float2 cuda_cast(int2 val) -{ - return make_float2(val.x, val.y); -} - -template <> -__device__ inline float2 cuda_cast(float val) -{ - return make_float2(val, val); -} - -template <> -__device__ inline float2 cuda_cast(half2 val) -{ - return __half22float2(val); -} - -template <> -__device__ inline half2 cuda_cast(float2 val) -{ - return __float22half2_rn(val); -} - -template <> -__device__ inline half2 cuda_cast(float val) -{ - return __float2half2_rn(val); -} - -template <> -__device__ inline half2 cuda_cast(half val) -{ - return __half2half2(val); -} - -template <> -__device__ inline int8_t cuda_cast(half val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - union - { - half fp16; - int16_t int16_in; - }; - - fp16 = val; - asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); - return int8[0]; -} - -template <> -__device__ inline int16_t cuda_cast(half2 val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - int8[0] = cuda_cast(val.x); - int8[1] = cuda_cast(val.y); - return int16; -} - -template <> -__device__ inline int8_t cuda_cast(float val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); - return int8[0]; -} - -template <> -__device__ inline int16_t cuda_cast(float2 val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - int8[0] = cuda_cast(val.x); - int8[1] = cuda_cast(val.y); - return int16; -} - -template <> -__device__ inline half2 cuda_cast(int16_t val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - int16 = val; - return make_half2(int8[0], int8[1]); -} - -template <> -__device__ inline float2 cuda_cast(int16_t val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - int16 = val; - return make_float2(int8[0], int8[1]); -} - -#ifdef ENABLE_BF16 -template <> -__device__ inline __nv_bfloat16 cuda_cast(int32_t val) -{ - return static_cast(val); -} - -template <> -__device__ inline 
__nv_bfloat16 cuda_cast(int8_t val) -{ - return static_cast(val); -} - -template <> -__device__ inline int8_t cuda_cast(__nv_bfloat16 val) -{ - return static_cast(val); -} - -template <> -__device__ inline float cuda_cast(__nv_bfloat16 val) -{ - return __bfloat162float(val); -} - -template <> -__device__ inline float2 cuda_cast(__nv_bfloat162 val) -{ - return bf1622float2(val); -} - -template <> -__device__ inline half cuda_cast(__nv_bfloat16 val) -{ - return __float2half(__bfloat162float(val)); -} - -template <> -__device__ inline int16_t cuda_cast(__nv_bfloat162 val) -{ - return bf1622int16(val); -} - -template <> -__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) -{ - return __float2bfloat16(val); -} - -template <> -__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) -{ - return __float2bfloat16(__half2float(val)); -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) -{ - return bf162bf162(val); -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) -{ - return __float2bfloat162_rn(val); -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) -{ - return float22bf162(val); -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - int16 = val; - __nv_bfloat162 res; - res.x = cuda_cast<__nv_bfloat16>(int8[0]); - res.y = cuda_cast<__nv_bfloat16>(int8[1]); - return res; -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) -{ - return float22bf162(__half22float2(val)); -} - -#endif // ENABLE BF16 - -template -__device__ inline T cuda_abs(T val) -{ - assert(false); - return {}; -} - -template <> -__device__ inline float cuda_abs(float val) -{ - return fabs(val); -} - -template <> -__device__ inline float2 cuda_abs(float2 val) -{ - return make_float2(fabs(val.x), fabs(val.y)); -} - -template <> -__device__ inline half cuda_abs(half val) -{ - return __habs(val); -} - -template <> -__device__ inline half2 cuda_abs(half2 val) -{ - return __habs2(val); -} - -#ifdef ENABLE_BF16 - -#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) -template <> -__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) -{ - return __habs(val); -} - -template <> -__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) -{ - return __habs2(val); -} -#endif - -#endif // ENABLE_FP16 - -template -__device__ inline To cuda_sum(Ti val) -{ - return cuda_cast(val); -}; - -template -__device__ inline To cuda_sum(float2 val) -{ - return cuda_cast(val.x + val.y); -}; - -// Unary maximum: compute the max of a vector type -template -__device__ inline To cuda_max(Ti val) -{ - return cuda_cast(val); -}; - -template <> -__device__ inline float cuda_max(float2 val) -{ - return fmaxf(val.x, val.y); -} - -template <> -__device__ inline half cuda_max(half2 val) -{ - return __hmax(val.x, val.y); -} - -#ifdef ENABLE_BF16 -template <> -__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val) -{ -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - return __hmax(val.x, val.y); -#else - assert(0); - asm volatile("brkpt;\n" ::); - return __nv_bfloat16(0); -#endif -} -#endif - -// Binary maximum: compute the max of two values. -template -__device__ inline T cuda_max(T val1, T val2) -{ - return (val1 > val2) ? 
val1 : val2; -} - -template <> -__device__ inline float2 cuda_max(float2 val1, float2 val2) -{ - float2 out; - out.x = fmaxf(val1.x, val2.x); - out.y = fmaxf(val1.y, val2.y); - return out; -} - -template <> -__device__ inline half2 cuda_max(half2 val1, half2 val2) -{ - return __hmax2(val1, val2); -} - -#ifdef ENABLE_BF16 -template <> -__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2) -{ - return __hmax2(val1, val2); -} -#endif // ENABLE_BF16 - -// Binary maximum: compute the min of two values. -template -__device__ inline T cuda_min(T val1, T val2) -{ - return (val1 < val2) ? val1 : val2; -} - -template <> -__device__ inline float2 cuda_min(float2 val1, float2 val2) -{ - float2 out; - out.x = fminf(val1.x, val2.x); - out.y = fminf(val1.y, val2.y); - return out; -} - -template <> -__device__ inline half2 cuda_min(half2 val1, half2 val2) -{ - return __hmin2(val1, val2); -} - -#ifdef ENABLE_BF16 -template <> -__device__ inline __nv_bfloat162 cuda_min(__nv_bfloat162 val1, __nv_bfloat162 val2) -{ - return __hmin2(val1, val2); -} -#endif // ENABLE_BF16 - -// Helper function of clamping the val into the given range. -template -inline __device__ T cuda_clamp(T val, T minVal, T maxVal) -{ - return cuda_min(cuda_max(val, minVal), maxVal); -} - -#ifdef ENABLE_FP8 -template <> -__device__ inline float2 cuda_cast(__nv_fp8x2_e4m3 val) -{ - return bf1622float2(fp8x2_e4m3_to_bfloat2(&val)); -} - -template <> -__device__ inline half2 cuda_cast(__nv_fp8x2_e4m3 val) -{ - return fp8x2_e4m3_to_half2(&val); -} - -template <> -__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val) -{ - return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val))); -} - -template <> -__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, half2>(half2 val) -{ - return __nv_fp8x2_e4m3(cuda_cast(val)); -} - -template <> -__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, __nv_bfloat162>(__nv_bfloat162 val) -{ - return __nv_fp8x2_e4m3(cuda_cast(val)); -} - -template <> -__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val) -{ - return __nv_fp8_e4m3(val); -} - -template <> -__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val) -{ - return __nv_fp8_e4m3(val); -} - -template <> -__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val) -{ - return __nv_fp8_e4m3(val); -} - -template <> -__device__ inline float cuda_cast(__nv_fp8_e4m3 val) -{ - return (float) val; -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val) -{ - return fp8x2_e4m3_to_bfloat2(&val); -} - -template <> -__device__ inline int8_t cuda_cast(__nv_fp8_e4m3 val) -{ - // no impl - return 0; -} - -template <> -__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val) -{ - return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast(val))); -} - -#endif // ENABLE_FP8 - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h deleted file mode 100644 index 13ee3367e9..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h +++ /dev/null @@ -1,641 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "tensorrt_llm/common/cudaBf16Wrapper.h" -#include "tensorrt_llm/common/cudaDriverWrapper.h" -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/common/tllmException.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifndef _WIN32 // Linux -#include -#endif // not WIN32 -#include -#ifdef _WIN32 // Windows -#include -#undef ERROR // A Windows header file defines ERROR as 0, but it's used in our logger.h enum. Logging breaks without - // this undef. -#endif // WIN32 - -namespace tensorrt_llm::common -{ - -// workspace for cublas gemm : 32MB -#define CUBLAS_WORKSPACE_SIZE 33554432 - -typedef struct __align__(4) -{ - half x, y, z, w; -} - -half4; - -/* **************************** type definition ***************************** */ - -enum CublasDataType -{ - FLOAT_DATATYPE = 0, - HALF_DATATYPE = 1, - BFLOAT16_DATATYPE = 2, - INT8_DATATYPE = 3, - FP8_DATATYPE = 4 -}; - -enum TRTLLMCudaDataType -{ - FP32 = 0, - FP16 = 1, - BF16 = 2, - INT8 = 3, - FP8 = 4 -}; - -enum class OperationType -{ - FP32, - FP16, - BF16, - INT8, - FP8 -}; - -/* **************************** debug tools ********************************* */ -static char const* _cudaGetErrorEnum(cudaError_t error) -{ - return cudaGetErrorString(error); -} - -static char const* _cudaGetErrorEnum(cublasStatus_t error) -{ - switch (error) - { - case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; - - case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; - - case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; - - case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; - - case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; - - case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; - - case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; - - case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; - - case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; - - case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return ""; -} - -template -void check(T result, char const* const func, char const* const file, int const line) -{ - if (result) - { - throw TllmException( - file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(result))); - } -} - -template -void checkEx(T result, std::initializer_list const& validReturns, char const* const func, char const* const file, - int const line) -{ - if (std::all_of(std::begin(validReturns), std::end(validReturns), [&result](T const& t) { return t != result; })) - { - throw TllmException( - file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(result))); - } -} - -#define check_cuda_error(val) 
check((val), #val, __FILE__, __LINE__) -#define check_cuda_error_2(val, file, line) check((val), #val, file, line) - -inline std::optional isCudaLaunchBlocking() -{ - static bool firstCall = true; - static std::optional result = std::nullopt; - - if (firstCall) - { - char const* env = std::getenv("CUDA_LAUNCH_BLOCKING"); - if (env != nullptr && std::string(env) == "1") - { - result = true; - } - else if (env != nullptr && std::string(env) == "0") - { - result = false; - } - firstCall = false; - } - - return result; -} - -inline bool doCheckError() -{ - auto const cudaLaunchBlocking = isCudaLaunchBlocking(); -#ifndef NDEBUG - bool const checkError = cudaLaunchBlocking.value_or(true); -#else - bool const checkError = cudaLaunchBlocking.value_or(false); -#endif - - return checkError; -} - -inline void syncAndCheck(char const* const file, int const line) -{ - if (doCheckError()) - { - cudaDeviceSynchronize(); - check(cudaGetLastError(), "cudaGetLastError", file, line); - } -} - -#define sync_check_cuda_error() tensorrt_llm::common::syncAndCheck(__FILE__, __LINE__) - -#define PRINT_FUNC_NAME_() \ - do \ - { \ - std::cout << "[TensorRT-LLM][CALL] " << __FUNCTION__ << " " << std::endl; \ - } while (0) - -// clang-format off -template struct packed_type; -template <> struct packed_type { using type = float; }; // we don't need to pack float by default -template <> struct packed_type { using type = half2; }; - -#ifdef ENABLE_BF16 -template<> -struct packed_type<__nv_bfloat16> { - using type = __nv_bfloat162; -}; -#endif - -#ifdef ENABLE_FP8 -template<> -struct packed_type<__nv_fp8_e4m3> { - using type = __nv_fp8x2_e4m3; -}; -#endif - -template struct num_elems; -template <> struct num_elems { static constexpr int value = 1; }; -template <> struct num_elems { static constexpr int value = 2; }; -template <> struct num_elems { static constexpr int value = 4; }; -template <> struct num_elems { static constexpr int value = 1; }; -template <> struct num_elems { static constexpr int value = 2; }; -#ifdef ENABLE_BF16 -template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; }; -template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; }; -#endif -#ifdef ENABLE_FP8 -template <> struct num_elems<__nv_fp8_e4m3> { static constexpr int value = 1; }; -template <> struct num_elems<__nv_fp8x2_e4m3> { static constexpr int value = 2; }; -#endif - -template struct packed_as; -template struct packed_as { using type = T; }; -template<> struct packed_as { using type = half2; }; -template<> struct packed_as { using type = float2; }; -template<> struct packed_as { using type = int16_t; }; -template<> struct packed_as { using type = int2; }; -template<> struct packed_as { using type = half; }; -template<> struct packed_as { using type = float; }; -#ifdef ENABLE_BF16 -template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; }; -template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; }; -#endif -#ifdef ENABLE_FP8 -template<> struct packed_as<__nv_fp8_e4m3, 2> { using type = __nv_fp8x2_e4m3; }; -template<> struct packed_as<__nv_fp8x2_e4m3, 1> { using type = __nv_fp8_e4m3; }; -template<> struct packed_as<__nv_fp8_e5m2, 2> { using type = __nv_fp8x2_e5m2; }; -template<> struct packed_as<__nv_fp8x2_e5m2, 1> { using type = __nv_fp8_e5m2; }; -#endif - -inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } -inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); 
} -inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); } - -inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } -inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); } -inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); } - -// clang-format on - -template -struct CudaDataType -{ -}; - -template <> -struct CudaDataType -{ - static constexpr cudaDataType_t value = cudaDataType::CUDA_R_32F; -}; - -template <> -struct CudaDataType -{ - static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16F; -}; - -#ifdef ENABLE_BF16 -template <> -struct CudaDataType<__nv_bfloat16> -{ - static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16BF; -}; -#endif - -inline int getSMVersion() -{ - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - int sm_major = 0; - int sm_minor = 0; - check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); - check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); - return sm_major * 10 + sm_minor; -} - -inline int getDevice() -{ - int current_dev_id = 0; - check_cuda_error(cudaGetDevice(¤t_dev_id)); - return current_dev_id; -} - -inline int getDeviceCount() -{ - int count = 0; - check_cuda_error(cudaGetDeviceCount(&count)); - return count; -} - -/// @brief Identifies the memory type of the given pointer. -template -cudaMemoryType getPtrCudaMemoryType(T* ptr) -{ - cudaPointerAttributes attributes{}; - check_cuda_error(cudaPointerGetAttributes(&attributes, ptr)); - return attributes.type; -} - -/// Get the memory info -/// \return The free and total amount of memory in bytes -inline std::tuple getDeviceMemoryInfo(bool const useUvm) -{ - if (useUvm) - { - size_t freeSysMem = 0; - size_t totalSysMem = 0; -#ifndef _WIN32 // Linux - struct sysinfo info - { - }; - - sysinfo(&info); - totalSysMem = info.totalram * info.mem_unit; - freeSysMem = info.freeram * info.mem_unit; -#else // Windows - MEMORYSTATUSEX memInfo; - memInfo.dwLength = sizeof(memInfo); - GlobalMemoryStatusEx(&memInfo); - totalSysMem = memInfo.ullTotalPhys; - freeSysMem = memInfo.ullAvailPhys; -#endif // WIN32 - - TLLM_LOG_INFO("Using UVM based system memory for KV cache, total memory %0.2f GB, available memory %0.2f GB", - ((double) totalSysMem / 1e9), ((double) freeSysMem / 1e9)); - return {freeSysMem, totalSysMem}; - } - - size_t free = 0; - size_t total = 0; - check_cuda_error(cudaMemGetInfo(&free, &total)); - TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB", - ((double) total / 1e9), ((double) free / 1e9)); - return {free, total}; -} - -/// @brief Gets the memory allocation granularity for the current device. -/// -/// @return size_t The size of the smallest difference in memory size supported by the current device. 
-inline size_t getAllocationGranularity() -{ - auto const currentDevice = getDevice(); - ::CUmemAllocationProp prop = {}; - - prop.type = ::CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = ::CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = currentDevice; - prop.requestedHandleTypes = ::CU_MEM_HANDLE_TYPE_NONE; - - // Get the minimum granularity supported for allocation with cuMemCreate() - size_t granularity = 0; - TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); - return granularity; -} - -inline int getMultiProcessorCount() -{ - int device_id = 0; - int multi_processor_count = 0; - check_cuda_error(cudaGetDevice(&device_id)); - check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id)); - return multi_processor_count; -} - -inline int getMaxSharedMemoryPerBlockOptin() -{ - int device_id = 0; - int max_shared_memory_per_block = 0; - check_cuda_error(cudaGetDevice(&device_id)); - check_cuda_error( - cudaDeviceGetAttribute(&max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id)); - return max_shared_memory_per_block; -} - -template -inline size_t divUp(const T1& a, const T2& n) -{ - auto const tmp_a = static_cast(a); - auto const tmp_n = static_cast(n); - return (tmp_a + tmp_n - 1) / tmp_n; -} - -inline int roundUp(int a, int n) -{ - return divUp(a, n) * n; -} - -template ::value>, - typename = std::enable_if_t::value>> -auto constexpr ceilDiv(T numerator, U denominator) -{ - return (numerator + denominator - 1) / denominator; -} - -template -void printAbsMean(T const* buf, uint64_t size, cudaStream_t stream, std::string name = "") -{ - if (buf == nullptr) - { - TLLM_LOG_WARNING("%s is an nullptr, skip!", name.c_str()); - return; - } - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - T* h_tmp = new T[size]; - cudaMemcpyAsync(h_tmp, buf, sizeof(T) * size, cudaMemcpyDeviceToHost, stream); - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - double sum = 0.0f; - uint64_t zero_count = 0; - float max_val = -1e10; - bool find_inf = false; - for (uint64_t i = 0; i < size; i++) - { - if (std::isinf((float) (h_tmp[i]))) - { - find_inf = true; - continue; - } - sum += abs((double) h_tmp[i]); - if ((float) h_tmp[i] == 0.0f) - { - zero_count++; - } - max_val = max_val > abs(float(h_tmp[i])) ? max_val : abs(float(h_tmp[i])); - } - TLLM_LOG_INFO("%20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s", name.c_str(), size, sum / size, - sum, max_val, find_inf ? "true" : "false"); - delete[] h_tmp; - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -} - -template -void printToStream(T const* result, int const size, FILE* strm) -{ - bool const split_rows = (strm == stdout); - if (result == nullptr) - { - TLLM_LOG_WARNING("It is an nullptr, skip! 
\n"); - return; - } - T* tmp = reinterpret_cast(malloc(sizeof(T) * size)); - check_cuda_error(cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost)); - for (int i = 0; i < size; ++i) - { - fprintf(strm, "%f, ", static_cast(tmp[i])); - if (split_rows && ((i + 1) % 10) == 0) - fprintf(strm, "\n"); - } - if (!split_rows || (size % 10) != 0) - { - fprintf(strm, "\n"); - } - free(tmp); -} - -template -void printToScreen(T const* result, int const size) -{ - printToStream(result, size, stdout); -} - -template -void print2dToStream(T const* result, int const r, int const c, int const stride, FILE* strm) -{ - if (result == nullptr) - { - TLLM_LOG_WARNING("It is an nullptr, skip! \n"); - return; - } - for (int ri = 0; ri < r; ++ri) - { - T const* ptr = result + ri * stride; - printToStream(ptr, c, strm); - } - fprintf(strm, "\n"); -} - -template -void print2dToScreen(T const* result, int const r, int const c, int const stride) -{ - print2dToStream(result, r, c, stride, stdout); -} - -template -void print2dToFile(std::string fname, T const* result, int const r, int const c, int const stride) -{ - FILE* fp = fopen(fname.c_str(), "wt"); - if (fp != nullptr) - { - print2dToStream(result, r, c, stride, fp); - fclose(fp); - } -} - -inline void print_float_(float x) -{ - printf("%7.3f ", x); -} - -inline void print_element_(float x) -{ - print_float_(x); -} - -inline void print_element_(half x) -{ - print_float_((float) x); -} - -#ifdef ENABLE_BF16 -inline void print_element_(__nv_bfloat16 x) -{ - print_float_((float) x); -} -#endif - -#ifdef ENABLE_FP8 -inline void print_element_(__nv_fp8_e4m3 x) -{ - print_float_((float) x); -} -#endif - -inline void print_element_(uint32_t ul) -{ - printf("%7" PRIu32, ul); -} - -inline void print_element_(uint64_t ull) -{ - printf("%7" PRIu64, ull); -} - -inline void print_element_(int32_t il) -{ - printf("%7" PRId32, il); -} - -inline void print_element_(int64_t ill) -{ - printf("%7" PRId64, ill); -} - -template -inline void printMatrix(T const* ptr, int m, int k, int stride, bool is_device_ptr) -{ - T* tmp; - if (is_device_ptr) - { - // k < stride ; stride = col-dimension. 
- tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); - check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); - cudaDeviceSynchronize(); - } - else - { - tmp = const_cast(ptr); - } - - for (int ii = -1; ii < m; ++ii) - { - if (ii >= 0) - { - printf("%07d ", ii); - } - else - { - printf(" "); - } - - for (int jj = 0; jj < k; jj += 1) - { - if (ii >= 0) - { - print_element_(tmp[ii * stride + jj]); - } - else - { - printf("%7d ", jj); - } - } - printf("\n"); - } - if (is_device_ptr) - { - free(tmp); - } -} - -template void printMatrix(float const* ptr, int m, int k, int stride, bool is_device_ptr); -template void printMatrix(half const* ptr, int m, int k, int stride, bool is_device_ptr); -#ifdef ENABLE_BF16 -template void printMatrix(__nv_bfloat16 const* ptr, int m, int k, int stride, bool is_device_ptr); -#endif -#ifdef ENABLE_FP8 -template void printMatrix(__nv_fp8_e4m3 const* ptr, int m, int k, int stride, bool is_device_ptr); -#endif -template void printMatrix(uint32_t const* ptr, int m, int k, int stride, bool is_device_ptr); -template void printMatrix(uint64_t const* ptr, int m, int k, int stride, bool is_device_ptr); -template void printMatrix(int const* ptr, int m, int k, int stride, bool is_device_ptr); - -} // namespace tensorrt_llm::common - -/* - * Macros compliant with TensorRT coding conventions - */ -#define TLLM_CUDA_CHECK(stat) \ - do \ - { \ - tensorrt_llm::common::check((stat), #stat, __FILE__, __LINE__); \ - } while (0) - -// We use singleton memory pool and the order of destructors depends on the compiler implementation. We find that the -// cudaFree/cudaFreeHost is called after cudaruntime destruction on Windows. There will be an cudaErrorCudartUnloading -// error. However, it is safe to ignore this error because the cuda runtime is already exited, we are no more worried -// about the memory leaks. -#define TLLM_CUDA_CHECK_FREE_RESOURCE(stat) \ - do \ - { \ - tensorrt_llm::common::checkEx((stat), {cudaSuccess, cudaErrorCudartUnloading}, #stat, __FILE__, __LINE__); \ - } while (0) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp deleted file mode 100644 index 334ad23690..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/common/tllmException.h" -#include - -namespace tensorrt_llm::common -{ - -Logger::Logger() -{ - char* isFirstRankOnlyChar = std::getenv("TLLM_LOG_FIRST_RANK_ONLY"); - bool isFirstRankOnly = (isFirstRankOnlyChar != nullptr && std::string(isFirstRankOnlyChar) == "ON"); - - auto const* levelName = std::getenv("TLLM_LOG_LEVEL"); - if (levelName != nullptr) - { - auto level = [levelName = std::string(levelName)]() - { - if (levelName == "TRACE") - return TRACE; - if (levelName == "DEBUG") - return DEBUG; - if (levelName == "INFO") - return INFO; - if (levelName == "WARNING") - return WARNING; - if (levelName == "ERROR") - return ERROR; - TLLM_THROW("Invalid log level: %s", levelName.c_str()); - }(); - // If TLLM_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR - if (isFirstRankOnly) - { - auto const deviceId = getDevice(); - if (deviceId != 1) - { - level = ERROR; - } - } - setLevel(level); - } -} - -void Logger::log(std::exception const& ex, Logger::Level level) -{ - log(level, "%s: %s", TllmException::demangle(typeid(ex).name()).c_str(), ex.what()); -} - -Logger* Logger::getLogger() -{ - thread_local Logger instance; - return &instance; -} -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h deleted file mode 100644 index df84e22638..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/stringUtils.h" - -namespace tensorrt_llm::common -{ - -class Logger -{ - -// On Windows, the file wingdi.h is included which has -// #define ERROR 0 -// This breaks everywhere ERROR is used in the Level enum -#ifdef _WIN32 -#undef ERROR -#endif // _WIN32 - -public: - enum Level - { - TRACE = 0, - DEBUG = 10, - INFO = 20, - WARNING = 30, - ERROR = 40 - }; - - static Logger* getLogger(); - - Logger(Logger const&) = delete; - void operator=(Logger const&) = delete; - -#if defined(_MSC_VER) - template - void log(Level level, char const* format, Args const&... args); - - template - void log(Level level, int rank, char const* format, Args const&... args); -#else - template - void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0))); - - template - void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0))); -#endif - - template - void log(Level level, std::string const& format, Args const&... args) - { - return log(level, format.c_str(), args...); - } - - template - void log(Level const level, int const rank, std::string const& format, Args const&... 
args) - { - return log(level, rank, format.c_str(), args...); - } - - void log(std::exception const& ex, Level level = Level::ERROR); - - Level getLevel() const - { - return level_; - } - - void setLevel(Level const level) - { - level_ = level; - log(INFO, "Set logger level to %s", getLevelName(level)); - } - - bool isEnabled(Level const level) const - { - return level_ <= level; - } - -private: - static auto constexpr kPREFIX = "[TensorRT-LLM]"; - -#ifndef NDEBUG - Level const DEFAULT_LOG_LEVEL = DEBUG; -#else - Level const DEFAULT_LOG_LEVEL = INFO; -#endif - Level level_ = DEFAULT_LOG_LEVEL; - - Logger(); // NOLINT(modernize-use-equals-delete) - - static inline char const* getLevelName(Level const level) - { - switch (level) - { - case TRACE: return "TRACE"; - case DEBUG: return "DEBUG"; - case INFO: return "INFO"; - case WARNING: return "WARNING"; - case ERROR: return "ERROR"; - } - - TLLM_THROW("Unknown log level: %d", level); - } - - static inline std::string getPrefix(Level const level) - { - return fmtstr("%s[%s] ", kPREFIX, getLevelName(level)); - } - - static inline std::string getPrefix(Level const level, int const rank) - { - return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank); - } -}; - -template -void Logger::log(Logger::Level level, char const* format, Args const&... args) -{ - if (isEnabled(level)) - { - auto const fmt = getPrefix(level) + format; - auto& out = level_ < WARNING ? std::cout : std::cerr; - if constexpr (sizeof...(args) > 0) - { - out << fmtstr(fmt.c_str(), args...); - } - else - { - out << fmt; - } - out << std::endl; - } -} - -template -void Logger::log(Logger::Level const level, int const rank, char const* format, Args const&... args) -{ - if (isEnabled(level)) - { - auto const fmt = getPrefix(level, rank) + format; - auto& out = level_ < WARNING ? std::cout : std::cerr; - if constexpr (sizeof...(args) > 0) - { - out << fmtstr(fmt.c_str(), args...); - } - else - { - out << fmt; - } - out << std::endl; - } -} - -#define TLLM_LOG(level, ...) \ - do \ - { \ - auto* const logger = tensorrt_llm::common::Logger::getLogger(); \ - if (logger->isEnabled(level)) \ - { \ - logger->log(level, __VA_ARGS__); \ - } \ - } while (0) - -#define TLLM_LOG_TRACE(...) TLLM_LOG(tensorrt_llm::common::Logger::TRACE, __VA_ARGS__) -#define TLLM_LOG_DEBUG(...) TLLM_LOG(tensorrt_llm::common::Logger::DEBUG, __VA_ARGS__) -#define TLLM_LOG_INFO(...) TLLM_LOG(tensorrt_llm::common::Logger::INFO, __VA_ARGS__) -#define TLLM_LOG_WARNING(...) TLLM_LOG(tensorrt_llm::common::Logger::WARNING, __VA_ARGS__) -#define TLLM_LOG_ERROR(...) TLLM_LOG(tensorrt_llm::common::Logger::ERROR, __VA_ARGS__) -#define TLLM_LOG_EXCEPTION(ex, ...) tensorrt_llm::common::Logger::getLogger()->log(ex, ##__VA_ARGS__) -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh deleted file mode 100644 index a228d3f9fc..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ - -template -struct QuantTypeStaticVals; - -template <> -struct QuantTypeStaticVals -{ - static constexpr float MAX_VAL = 127.f; - static constexpr float MIN_SCALING_FACTOR = 0.f; - static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX; -}; - -#ifdef ENABLE_FP8 - -template <> -struct QuantTypeStaticVals<__nv_fp8_e4m3> -{ - static constexpr float MAX_VAL = 448.f; - // Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720 - static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f); - static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f); -}; - -#endif // ENABLE_FP8 - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h b/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h deleted file mode 100644 index 052d9c8c81..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h +++ /dev/null @@ -1,358 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ - -class QuantMode -{ - // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py -public: - using BaseType = std::uint32_t; - - explicit constexpr QuantMode(BaseType value) noexcept - : mValue{value} - { - } - - QuantMode() noexcept = default; - - constexpr QuantMode(QuantMode const&) noexcept = default; - - constexpr QuantMode& operator=(QuantMode const& other) noexcept = default; - - static constexpr QuantMode none() noexcept - { - return QuantMode(BaseType(0)); - } - - static constexpr QuantMode int4Weights() noexcept - { - return QuantMode(BaseType(1u) << 0); - } - - static constexpr QuantMode int8Weights() noexcept - { - return QuantMode(BaseType(1u) << 1); - } - - static constexpr QuantMode activations() noexcept - { - return QuantMode(BaseType(1u) << 2); - } - - static constexpr QuantMode perChannelScaling() noexcept - { - return QuantMode(BaseType(1u) << 3); - } - - static constexpr QuantMode perTokenScaling() noexcept - { - return QuantMode(BaseType(1u) << 4); - } - - static constexpr QuantMode perGroupScaling() noexcept - { - return QuantMode(BaseType(1u) << 5); - } - - static constexpr QuantMode int8KvCache() noexcept - { - return QuantMode(BaseType(1u) << 6); - } - - static constexpr QuantMode fp8KvCache() noexcept - { - return QuantMode(BaseType(1u) << 7); - } - - static constexpr QuantMode fp8Qdq() noexcept - { - return QuantMode(BaseType(1u) << 8); - } - - static constexpr QuantMode fp8RowWise() noexcept - { - return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9); - } - - static constexpr QuantMode w4a8QServe() noexcept - { - return QuantMode(BaseType(1u) << 10); - } - - constexpr BaseType value() const noexcept - { - return mValue; - } - - constexpr bool isSet(QuantMode const& mode) const noexcept - { - return (mValue & mode.value()) == mode.value(); - } - - constexpr bool hasInt4Weights() const noexcept - { - return isSet(int4Weights()); - } - - constexpr bool hasInt8Weights() const noexcept - { - return isSet(int8Weights()); - } - - constexpr bool hasActivations() const noexcept - { - return isSet(activations()); - } - - constexpr bool hasPerChannelScaling() const noexcept - { - return isSet(perChannelScaling()); - } - - constexpr bool hasPerTokenScaling() const noexcept - { - return isSet(perTokenScaling()); - } - - constexpr bool hasPerGroupScaling() const noexcept - { - return isSet(perGroupScaling()); - } - - constexpr bool hasStaticActivationScaling() const noexcept - { - return !hasPerTokenScaling(); - } - - constexpr bool hasInt8KvCache() const noexcept - { - return isSet(int8KvCache()); - } - - constexpr bool hasFp8KvCache() const noexcept - { - return isSet(fp8KvCache()); - } - - constexpr bool hasFp8Qdq() const noexcept - { - return isSet(fp8Qdq()); - } - - constexpr bool hasFp8RowWise() const noexcept - { - return isSet(fp8RowWise()); - } - - constexpr bool hasKvCacheQuant() const noexcept - { - return hasInt8KvCache() || hasFp8KvCache(); - } - - static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false, - bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false, - bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false, - bool useW4a8QServe = false) - { - QuantMode quantMode{}; - if (quantizeWeights) - { - if (useInt4Weights) - quantMode += int4Weights(); - else - quantMode 
+= int8Weights(); - } - - if (quantizeActivations) - { - quantMode += activations(); - } - - if (perChannel) - { - quantMode += QuantMode::perChannelScaling(); - } - if (perToken) - { - quantMode += QuantMode::perTokenScaling(); - } - if (perGroup) - { - quantMode += QuantMode::perGroupScaling(); - } - - if (useInt8KvCache) - { - quantMode += int8KvCache(); - } - - if (useFp8KvCache) - { - quantMode += fp8KvCache(); - } - - if (useFp8Qdq) - { - quantMode += fp8Qdq(); - } - - if (useFp8RowWise) - { - quantMode += fp8RowWise(); - } - - if (useW4a8QServe) - { - quantMode += w4a8QServe(); - } - - return quantMode; - } - - static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false) - { - return fromDescription(true, true, perToken, perChannel); - } - - static constexpr QuantMode useQServe(bool perGroup) - { - return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true); - } - - static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false) - { - return fromDescription(true, false, false, false, perGroup, useInt4Weights); - } - - static QuantMode const fromQuantAlgo( - std::optional quantAlgo = std::nullopt, std::optional kvCacheQuantAlgo = std::nullopt) - { - QuantMode quantMode{}; - if (quantAlgo == "W8A16") - { - quantMode = useWeightOnly(false, false); - } - else if (quantAlgo == "W4A16") - { - quantMode = useWeightOnly(true, false); - } - else if (quantAlgo == "W4A16_AWQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W4A8_AWQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W4A8_QSERVE_PER_GROUP") - { - quantMode = useQServe(false); - } - else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL") - { - quantMode = useQServe(true); - } - else if (quantAlgo == "W4A16_GPTQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL") - { - quantMode = useSmoothQuant(false, true); - } - else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN") - { - quantMode = useSmoothQuant(false, false); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN") - { - quantMode = useSmoothQuant(true, true); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN") - { - quantMode = useSmoothQuant(false, true); - } - else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN") - { - quantMode = useSmoothQuant(true, false); - } - else if (quantAlgo == "FP8") - { - quantMode = fromDescription(false, false, false, false, false, false, false, false, true); - } - else if (quantAlgo == "FP8_ROWWISE") - { - quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true); - } - - if (kvCacheQuantAlgo == "INT8") - { - quantMode += int8KvCache(); - } - else if (kvCacheQuantAlgo == "FP8") - { - quantMode += fp8KvCache(); - } - - return quantMode; - } - - constexpr QuantMode operator+(QuantMode const& other) const noexcept - { - return QuantMode(mValue | other.mValue); - } - - constexpr QuantMode& operator+=(QuantMode const& other) noexcept - { - return *this = *this + other; - } - - constexpr QuantMode operator-(QuantMode const& other) const noexcept - { - return QuantMode(mValue & ~other.mValue); - } - - constexpr QuantMode& operator-=(QuantMode const& other) noexcept - { - return *this = *this - other; - } - - constexpr bool operator==(QuantMode const& other) const noexcept - { - return mValue == other.mValue; - } - - constexpr bool operator!=(QuantMode const& other) const noexcept - { - return 
!(*this == other); - } - -private: - BaseType mValue{0}; -}; - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh deleted file mode 100644 index c5a4fe0e24..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh +++ /dev/null @@ -1,399 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) -#include -#else -#include -#endif -#include "tensorrt_llm/common/cudaTypeUtils.cuh" -#include -#include -#include -#include -#include - -namespace cg = cooperative_groups; - -namespace tensorrt_llm -{ -namespace common -{ - -template -struct BytesToType; - -template <> -struct BytesToType<1> -{ - using type = uint8_t; -}; - -template <> -struct BytesToType<2> -{ - using type = uint16_t; -}; - -template <> -struct BytesToType<4> -{ - using type = uint32_t; -}; - -template <> -struct BytesToType<8> -{ - using type = uint64_t; -}; - -template <> -struct BytesToType<16> -{ - using type = float4; -}; - -template -__device__ inline void copy(void const* local, void* data) -{ - using T = typename BytesToType::type; - - T const* in = static_cast(local); - T* out = static_cast(data); - *out = *in; -} - -static float constexpr HALF_FLT_MAX = 65504.F; -#define FINAL_MASK 0xffffffff - -template -__inline__ __device__ T warpReduceSum(T val) -{ -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80 - return val; -} - -/* Calculate the sum of all elements in a block */ -template -__inline__ __device__ T blockReduceSum(T val) -{ - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - val = warpReduceSum(val); - - if (lane == 0) - shared[wid] = val; - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T) (0.0f); - val = warpReduceSum(val); - - return val; -} - -template -__inline__ __device__ T warpReduceMax(T val) -{ -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); - return val; -} - -/* Calculate the maximum of all elements in a block */ -template -__inline__ __device__ T blockReduceMax(T val) -{ - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - val = warpReduceMax(val); // get maxx in each warp - - if (lane == 0) // record in-warp maxx by warp Idx - shared[wid] = val; - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. 
to prevent - // blockDim.x is not divided by 32 - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; - val = warpReduceMax(val); - - return val; -} - -/* Calculate the maximum of all elements in a block */ -template -__inline__ __device__ T blockAllReduceMax(T val) -{ - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - val = warpReduceMax(val); // get maxx in each warp - - if (lane == 0) // record in-warp maxx by warp Idx - shared[wid] = val; - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; - val = warpReduceMax(val); - - return val; -} - -template -__inline__ __device__ T warpReduceSumV2(T* val) -{ -#pragma unroll - for (int i = 0; i < NUM; i++) - { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); - } - return (T) (0.0f); -} - -template -__inline__ __device__ T blockReduceSumV2(T* val) -{ - static __shared__ T shared[NUM][33]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduceSumV2(val); - - if (lane == 0) - { -#pragma unroll - for (int i = 0; i < NUM; i++) - { - shared[i][wid] = val[i]; - } - } - - __syncthreads(); - - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) - { - val[i] = is_mask ? shared[i][lane] : (T) (0.0f); - } - warpReduceSumV2(val); - return (T) 0.0f; -} - -template -__inline__ __device__ T warpReduceMaxV2(T* val) -{ -#pragma unroll - for (int i = 0; i < NUM; i++) - { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); - } - return (T) (0.0f); -} - -template -__inline__ __device__ T blockReduceMaxV2(T* val) -{ - static __shared__ T shared[32][NUM]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - warpReduceMaxV2(val); // get maxx in each warp - - if (lane == 0) // record in-warp maxx by warp Idx - { -#pragma unroll - for (int i = 0; i < NUM; i++) - { - shared[wid][i] = val[i]; - } - } - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) - { - val[i] = is_mask ? 
shared[lane][i] : (T) -1e20f; - } - warpReduceMaxV2(val); - - return (T) 0.0f; -} - -template -__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm) -{ - cg::thread_block cta = cg::this_thread_block(); - cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta); - - int const tid = cta.thread_rank(); - int const blockz = blockDim.x; - for (int i = 0; i < NUM; i++) - { -#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) - cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus()); -#else - // TODO Add implementation here - if (threadIdx.x == 0 && blockIdx.x == 0) - { - printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n"); - assert(false); - } -#endif - } - cg::sync(cta); - if (tid == 0) - { -#pragma unroll - for (int i = 0; i < NUM; i++) - { - float beta = 0.0f; - for (int j = 0; j < blockz; j += 32) - { - beta += cgBlockReduceSumElements_shm[i * blockz + j]; - } - element_list[i] = beta; - } - } -} - -template -struct TopK -{ - int p[MAX_K]; // index, being -1 at the tail if the array is not full - T u[MAX_K]; // value in descend order, being -MAX_T_VAL if the element is invalid - - __device__ __forceinline__ void insert(T const elem, int const elem_id) - { - if (elem_id < 0) - { - return; - } - // Condition of updating the array - // 1. array is not full - // 2. elem is greater than the smallest (last) element in the array - // 3. elem is equal to the smallest (last) element in the array but its elem_id is smaller - bool const need_update - = (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]); - if (!need_update) - { - return; - } - // Find suitable index for the new element - int i; - for (i = MAX_K - 2; i >= 0; --i) - { - bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]); - if (!need_decrease) - break; - } - // Move elements to correct positions - for (int k = MAX_K - 2; k >= i; --k) - { - p[k + 1] = p[k]; - u[k + 1] = u[k]; - } - p[i] = elem_id; - u[i] = elem; - } - - __device__ __forceinline__ void init() - { - T const MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; - for (int i = 0; i < MAX_K; i++) - { - p[i] = -1; - u[i] = -MAX_T_VAL; - } - } -}; - -template -__device__ __forceinline__ TopK reduce_topk_op(TopK const& a, TopK const& b) -{ - TopK res = a; - for (int i = 0; i < MAX_K; ++i) - res.insert(b.u[i], b.p[i]); - return res; -} - -template -struct TopK_2 -{ - int p = -1; - T u = -((std::is_same::value) ? HALF_FLT_MAX : FLT_MAX); - - __device__ __forceinline__ void insert(T elem, int elem_id) - { - if (elem > u) - { - u = elem; - p = elem_id; - } - } - - __device__ __forceinline__ void init() - { - u = -((std::is_same::value) ? HALF_FLT_MAX : FLT_MAX); - p = -1; - } -}; - -template -__device__ __forceinline__ TopK_2 reduce_topk_op_2(TopK_2 const& a, TopK_2 const& b) -{ - return a.u > b.u ? a : b; -} - -template -__device__ __forceinline__ T clamp_inf_for_half(float const input) -{ - return input; -} - -template <> -__device__ __forceinline__ half clamp_inf_for_half(float const input) -{ - // clamp inf values to enable fp16 training - return input > 0.0f ? 
(half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000); -} - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp deleted file mode 100644 index f1c6f88b43..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/stringUtils.h" -#include "tensorrt_llm/common/assert.h" - -#include -#include -#include -#include -#include - -namespace tensorrt_llm::common -{ - -namespace -{ -std::string vformat(char const* fmt, va_list args) -{ - va_list args0; - va_copy(args0, args); - auto const size = vsnprintf(nullptr, 0, fmt, args0); - if (size <= 0) - return ""; - - std::string stringBuf(size, char{}); - auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args); - - TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno))); - - return stringBuf; -} - -} // namespace - -std::string fmtstr(char const* format, ...) -{ - va_list args; - va_start(args, format); - std::string result = vformat(format, args); - va_end(args); - return result; -}; - -std::unordered_set str2set(std::string const& input, char delimiter) -{ - std::unordered_set values; - if (!input.empty()) - { - std::stringstream valStream(input); - std::string val; - while (std::getline(valStream, val, delimiter)) - { - if (!val.empty()) - { - values.insert(val); - } - } - } - return values; -}; - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h deleted file mode 100644 index 9c5ecde98c..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#if ENABLE_BF16 -#include -#endif // ENABLE_BF16 -#include - -#include // std::make_unique -#include // std::stringstream -#include -#include -#include - -namespace tensorrt_llm::common -{ -#if ENABLE_BF16 -static inline std::basic_ostream& operator<<(std::basic_ostream& stream, __nv_bfloat16 const& val) -{ - stream << __bfloat162float(val); - return stream; -} -#endif // ENABLE_BF16 - -static inline std::basic_ostream& operator<<(std::basic_ostream& stream, __half const& val) -{ - stream << __half2float(val); - return stream; -} - -inline std::string fmtstr(std::string const& s) -{ - return s; -} - -inline std::string fmtstr(std::string&& s) -{ - return s; -} - -#if defined(_MSC_VER) -std::string fmtstr(char const* format, ...); -#else -std::string fmtstr(char const* format, ...) __attribute__((format(printf, 1, 2))); -#endif - -// __PRETTY_FUNCTION__ is used for neat debugging printing but is not supported on Windows -// The alternative is __FUNCSIG__, which is similar but not identical -#if defined(_WIN32) -#define __PRETTY_FUNCTION__ __FUNCSIG__ -#endif - -auto constexpr kDefaultDelimiter = ", "; - -template -inline TStream& arr2outCasted(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter) -{ - out << "("; - if (size > 0) - { - for (size_t i = 0; i < size - 1; ++i) - { - out << static_cast(arr[i]) << delim; - } - out << static_cast(arr[size - 1]); - } - out << ")"; - return out; -} - -template -inline TStream& arr2out(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter) -{ - return arr2outCasted(out, arr, size, delim); -} - -template -inline std::string arr2str(T* arr, size_t size, char const* delim = kDefaultDelimiter) -{ - std::stringstream ss; - return arr2out(ss, arr, size, delim).str(); -} - -template -inline std::string vec2str(std::vector const& vec, char const* delim = kDefaultDelimiter) -{ - return arr2str(vec.data(), vec.size(), delim); -} - -inline bool strStartsWith(std::string const& str, std::string const& prefix) -{ - return str.rfind(prefix, 0) == 0; -} - -/// @brief Split a string into a set of strings using a delimiter -std::unordered_set str2set(std::string const& input, char delimiter); - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp deleted file mode 100644 index b410613d05..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/common/tllmException.h" -#include "tensorrt_llm/common/stringUtils.h" - -#include -#if !defined(_MSC_VER) -#include -#include -#include -#endif -#include - -namespace tensorrt_llm::common -{ - -namespace -{ -int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2; -} - -#if !defined(_MSC_VER) - -TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) - : std::runtime_error{""} -{ - mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES); - auto const trace = getTrace(); - std::runtime_error::operator=( - std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())}); -} -#else -TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) - : mNbFrames{} - , std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)} -{ -} -#endif - -TllmException::~TllmException() noexcept = default; - -std::string TllmException::getTrace() const -{ -#if defined(_MSC_VER) - return ""; -#else - auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames); - std::ostringstream buf; - for (auto i = 1; i < mNbFrames; ++i) - { - Dl_info info; - if (dladdr(mCallstack[i], &info) && info.dli_sname) - { - auto const clearName = demangle(info.dli_sname); - buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(), - static_cast(mCallstack[i]) - static_cast(info.dli_saddr)); - } - else - { - buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]); - } - if (i < mNbFrames - 1) - buf << std::endl; - } - - if (mNbFrames == MAX_FRAMES) - buf << std::endl << "[truncated]"; - - std::free(trace); - return buf.str(); -#endif -} - -std::string TllmException::demangle(char const* name) -{ -#if defined(_MSC_VER) - return name; -#else - std::string clearName{name}; - auto status = -1; - auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status); - if (status == 0) - { - clearName = demangled; - std::free(demangled); - } - return clearName; -#endif -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h deleted file mode 100644 index 47e0e63d3f..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -#define NEW_TLLM_EXCEPTION(...) 
\ - tensorrt_llm::common::TllmException(__FILE__, __LINE__, tensorrt_llm::common::fmtstr(__VA_ARGS__)) - -namespace tensorrt_llm::common -{ - -class TllmException : public std::runtime_error -{ -public: - static auto constexpr MAX_FRAMES = 128; - - explicit TllmException(char const* file, std::size_t line, std::string const& msg); - - ~TllmException() noexcept override; - - [[nodiscard]] std::string getTrace() const; - - static std::string demangle(char const* name); - -private: - std::array mCallstack{}; - int mNbFrames; -}; - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h b/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h deleted file mode 100644 index 1406e82133..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include - -namespace tensorrt_llm::common -{ - -std::uintptr_t constexpr kCudaMemAlign = 128; - -inline int8_t* alignPtr(int8_t* ptr, uintptr_t to) -{ - uintptr_t addr = (uintptr_t) ptr; - if (addr % to) - { - addr += to - addr % to; - } - return (int8_t*) addr; -} - -constexpr size_t alignSize(size_t size, size_t to) -{ - if ((size % to) != 0U) - { - size += to - size % to; - } - return size; -} - -inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment) -{ - uintptr_t addr = (uintptr_t) ptr; - addr += previousWorkspaceSize; - return alignPtr((int8_t*) addr, alignment); -} - -inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize) -{ - return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign); -} - -inline int8_t* nextWorkspacePtr( - int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign) -{ - uintptr_t curr_offset = offset; - uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment; - int8_t* newptr = size == 0 ? 
nullptr : base + curr_offset; - offset = next_offset; - return newptr; -} - -inline int8_t* nextWorkspacePtrWithAlignment( - int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign) -{ - return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment); -} - -inline size_t calculateTotalWorkspaceSize( - size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign) -{ - size_t total = 0; - for (int i = 0; i < count; i++) - { - total += workspaces[i]; - if (workspaces[i] % alignment) - { - total += alignment - (workspaces[i] % alignment); - } - } - return total; -} - -}; // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp deleted file mode 100644 index 61a41031bf..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp +++ /dev/null @@ -1,352 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ -#pragma once - -#include - -#include -#include -#include - -// Config - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10)) -#define CUTE_ARCH_RED_F16_SM70_ENABLED -#endif - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) -#define CUTE_ARCH_RED_VEC_SM90_ENABLED -#define CUTE_ARCH_RED_BF16_SM90_ENABLED -#endif - -namespace cute -{ - -////////////////////////////////// -// Wrapper around CUDA's atomicAdd -////////////////////////////////// - -template -struct TypedAtomicAdd -{ - using SRegisters = T[1]; - using DRegisters = T[1]; - - CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) - { - atomicAdd(&dst, src); - } -}; - -template -struct Copy_Traits> -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout::value>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout::value>>>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -////////////////////////////////// -// F16 ADD PTX -////////////////////////////////// - -struct SM70_RED_ADD_NOFTZ_F16 -{ - using SRegisters = uint16_t[1]; - using DRegisters = uint16_t[1]; - - CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) - asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -struct SM70_RED_ADD_NOFTZ_F16x2 -{ - using SRegisters = uint32_t[1]; - using DRegisters = uint32_t[1]; - - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) - asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -struct SM90_RED_ADD_NOFTZ_F16x2_V2 -{ - using SRegisters = uint32_t[2]; - using DRegisters = uint64_t[1]; - - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) - asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = 
Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -struct SM90_RED_ADD_NOFTZ_F16x2_V4 -{ - using SRegisters = uint32_t[4]; - using DRegisters = uint128_t[1]; - - CUTE_HOST_DEVICE static void copy( - uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) - asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), - "r"(src2), "r"(src3)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -////////////////////////////////// -// BF16 ADD PTX -////////////////////////////////// - -struct SM90_RED_ADD_NOFTZ_BF16 -{ - using SRegisters = uint16_t[1]; - using DRegisters = uint16_t[1]; - - CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -////////////////////////////////// - -struct SM90_RED_ADD_NOFTZ_BF16x2 -{ - using SRegisters = uint32_t[1]; - using DRegisters = uint32_t[1]; - - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -////////////////////////////////// - -struct SM90_RED_ADD_NOFTZ_BF16x2_V2 -{ - using SRegisters = uint32_t[2]; - using DRegisters = uint64_t[1]; - - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - 
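(Editorial aside, not part of the patch.) The copy_red_global.hpp file being removed above wraps the PTX red.global.add.noftz family so CUTLASS copy atoms can add packed f16/bf16 vectors directly into global memory. As a minimal sketch of the same idea, the CUDA kernel below uses the standard atomicAdd overload for packed __half2 values (available on devices of compute capability 6.0 or newer); the kernel and buffer names are hypothetical and belong neither to TensorRT-LLM nor to sgl-kernel.

#include <cuda_fp16.h>

// Each thread adds one packed __half2 (two fp16 lanes) into a global accumulator.
// This is the portable counterpart of SM70_RED_ADD_NOFTZ_F16x2: a single vectorized
// floating-point add resolved in global memory. When the returned old value is
// ignored, the compiler may lower the call to the same red.global.add instruction.
__global__ void accumulate_f16x2(__half2* dst, __half2 const* src, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        atomicAdd(&dst[i], src[i]);
    }
}

A launch such as accumulate_f16x2<<<num_blocks, 256>>>(dst, src, n) accumulates src into dst without staging a read-modify-write in registers, which appears to be the property the fused MoE epilogue later in this patch relies on for its register-to-global copies.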
-////////////////////////////////// - -struct SM90_RED_ADD_NOFTZ_BF16x2_V4 -{ - using SRegisters = uint32_t[4]; - using DRegisters = uint128_t[1]; - - CUTE_HOST_DEVICE static void copy( - uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), - "r"(src2), "r"(src3)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -////////////////////////////////// - -} // end namespace cute diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h deleted file mode 100644 index 2362da4f7f..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h +++ /dev/null @@ -1,120 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! 
\file - \brief Templates exposing architecture support for multiply-add operations -*/ - -#pragma once -#include "cutlass_extensions/weight_only_quant_op.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace arch -{ - -// Tag which triggers MMA which will trigger -struct OpMultiplyAddDequantizeInterleavedBToA; - -/* - Below we have extra tags to signal what kind of dequantization we want to do - (per col, scale only fine grained, finegrained with zero). This still lets us - the existing template infrastructure (incl. that in CUTLASS). However, we - split out the template below into OpMultiplyAddDequantizeInterleavedBToA along - with the quantization op before instantiating the GEMM pieces. - - Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of - code we need to duplicate. - */ -struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; -struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale; -struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; - -// The default just forwards the original operator -template -struct TagOperator -{ - using TaggedOperator = MmaOp; -}; - -// Specializations below attach more information to the operator -template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; -}; - -template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; -}; - -template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; -}; - -// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original -// operator + the extra information. If no extra info was tagged, the dequant op per column scaling -// as a default. -template -struct DetagOperator -{ - using Operator = TaggedMmaOp; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; -}; - -template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; -}; - -template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; -}; - -template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; -}; - -} // namespace arch -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h deleted file mode 100644 index c83a9a074d..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include "cutlass/device_kernel.h" -#include "tensorrt_llm/common/cudaUtils.h" - -namespace tensorrt_llm -{ -namespace cutlass_extensions -{ - -template -inline int compute_occupancy_for_kernel() -{ - - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size > (48 << 10)) - { - cudaFuncAttributes attr; - int device = 0; - int max_smem_per_block = 0; - tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); - tensorrt_llm::common::check_cuda_error( - cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); - if constexpr (enable_cutlass_3x) - { - tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); - } - else - { - tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel)); - } - if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) - { - // This should mean that - // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) - // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this - // configuration. - return 0; - } - - if constexpr (enable_cutlass_3x) - { - tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute( - cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - else - { - tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute( - cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - } - - int max_active_blocks = -1; - if constexpr (enable_cutlass_3x) - { - tensorrt_llm::common::check_cuda_error( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel, - 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); - } - else - { - tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); - } - - return max_active_blocks; -} - -} // namespace cutlass_extensions -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp deleted file mode 100644 index bba25ec23a..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp +++ /dev/null @@ -1,550 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/fast_math.h" - -#include "cute/numeric/numeric_types.hpp" -#include "cute/tensor.hpp" -#include "cutlass/trace.h" - -#include "cutlass_extensions/arch/copy_red_global.hpp" -#include "cutlass_extensions/util/gather_tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace epilogue -{ -namespace collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class EpilogueMoeFusedFinalize -{ -public: - using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized; - using DispatchPolicy = PtrArrayNoSmemWarpSpecialized; - - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementIntermediate = typename ThreadEpilogueOp::ElementD; - - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t; - using ElementD = ElementD_; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t; - - static_assert(!is_same_v, "Stride C must be a pointer"); - static_assert(is_same_v, "Stride D must not be a pointer"); - - using CopyAtomR2S = Copy_Atom; - using CopyAtomS2R = Copy_Atom; - using CopyAtomR2G = Copy_Atom; - static constexpr int AlignmentD = CopyAtomR2G::NumValSrc; - - using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{})); - - constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); - - struct SharedStorage - { - alignas(SmemAlignmentD) cute::ArrayEngine> smem_D; - }; - - struct TensorMapStorage - { - }; - - struct Arguments - { - typename 
ThreadEpilogueOp::Params thread{}; - ElementC const** ptr_C{}; - StrideC dC{}; - ElementD* ptr_D{}; - StrideD dD{}; - ElementBias const* ptr_bias; - StrideBias dBias{}; - ElementScale const* ptr_scale; - StrideScale dScale{}; - int64_t const* group_offset{}; - int32_t const* scatter_index{}; - cutlass::FastDivmod num_rows_in_final_output; - }; - - using Params = Arguments; - - // - // Methods - // - - template - static constexpr Params to_underlying_arguments( - ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace) - { - return args; - } - - template - static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) - { - return 0; - } - - template - static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, - void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) - { - return cutlass::Status::kSuccess; - } - - template - CUTLASS_HOST_DEVICE static bool can_implement( - [[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args) - { - bool implementable = true; - if (problem_shape.is_host_problem_shape_available()) - { - // Check alignment for all problem sizes - for (int i = 0; i < problem_shape.groups(); i++) - { - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); - auto [M, N, K, L] = problem_shape_MNKL; - implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(M, N, L), InternalStrideD{}); - } - } - - if (!implementable) - { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global " - "reduction instruction.\n"); - } - return implementable; - } - - CUTLASS_HOST_DEVICE - EpilogueMoeFusedFinalize(Params const& params_) - : params(params_) - { - } - - CUTLASS_DEVICE - bool is_source_needed() - { - // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. 
- return params.ptr_C != nullptr - && (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0); - } - - template - CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, cute::Tensor const& accumulators, TiledMma tiled_mma, - ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf) - { - using namespace cute; - using X = Underscore; - - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - - auto synchronize = [&]() - { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - auto mma_tile_m = tile_size<0>(tiled_mma); - auto mma_tile_n = tile_size<1>(tiled_mma); - auto epi_tile_m = size<0>(EpilogueTile{}); - auto epi_tile_n = size<1>(EpilogueTile{}); - - CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); - CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); - - // Batches are managed by using appropriate pointers to C and D matrices - int32_t const mock_L = 1; - int32_t const mock_l_coord = 0; - - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - - // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. - // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, - // we get the correct alpha/beta values for the current batch/group using group index. - ThreadEpilogueOp epilogue_op(params.thread, l_coord); - - SharedStorage& storage = *reinterpret_cast(smem_buf); - - Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{}); - Tensor sD = as_position_independent_swizzle_tensor(sD_); - - // Function to scatter output rows - auto& num_rows = params.num_rows_in_final_output; - auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord])); - auto get_scatter_idx = [&](auto i) - { - auto scatter = read_scatter_map(i); - int quot, rem; - num_rows(quot, rem, scatter); - return rem; - }; - - // Represent the full output tensor - ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr; - auto dC = epilogue_op.is_source_needed() ? 
params.dC[l_coord] : InternalStrideC{}; - Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l) - Tensor mD_mnl = make_gather_tensor( - make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l) - - // Use fake shape for bias, it doesn't matter - bool const is_bias_needed = params.ptr_bias != nullptr; - Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias); - Tensor mScale_mnl = make_tensor( - make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale); - - Tensor gC_mnl - = local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl - = local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) - - Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - Tensor gBias_mnl - = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gScale_mnl - = local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N) - Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N) - - Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Get the smallest tiled copy we can use to retile the accumulators - TiledCopy tiled_copy_C_atom - = make_tiled_copy_C_atom(Copy_Atom{}, tiled_mma); - TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom); - - auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) - Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N) - Tensor tRS_rD = make_tensor(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N) - - // Make a tiled copy vectorized along major direction of D - auto tiled_s2r = [&]() - { - if constexpr (cutlass::gemm::detail::is_k_major()) - { - constexpr int NumThreadsMajor = epi_tile_n / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; - return make_tiled_copy(CopyAtomS2R{}, - Layout, Int>, Stride, _1>>{}, - Layout>>{}); - } - else if constexpr (cutlass::gemm::detail::is_mn_major()) - { - constexpr int NumThreadsMajor = epi_tile_m / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; - return make_tiled_copy(CopyAtomS2R{}, - Layout, Int>, Stride<_1, Int>>{}, - Layout, _1>>{}); - } - else - { - static_assert(cute::is_void_v, "Unsupported D gmem layout."); - } - }(); - - auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); - Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // 
((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - - // Allocate intermediate registers for a single subtile - Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rBias = make_tensor(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rScale = make_tensor(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - - // Make an identity coordinate tensor for predicating our output MN tile - Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); - Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - - // epilogue subtile loop - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) - { - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) - { - int mma_m = (epi_m * epi_tile_m) / mma_tile_m; - int mma_n = (epi_n * epi_tile_n) / mma_tile_n; - Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n); - - int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); - int r2s_v = epi_n_in_mma * size(tRS_rD); - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v) - { - tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v); - } - - copy(tiled_r2s, tRS_rD, tRS_sD); - synchronize(); - - copy(tiled_s2r, tSR_sD, tSR_rD); - synchronize(); - - Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n); - Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n); - Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n); - Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n); - Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n); - - if (epilogue_op.is_source_needed()) - { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_rD); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_rD); ++n) - { - if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) - { - copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n)); - if (is_bias_needed) - { - copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); - } - copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rD); ++i) - { - auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n)); - if (is_bias_needed) - { - epi_value += static_cast(tSR_rBias(i, m, n)); - } - tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); - } - copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); - } - } - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_rD); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_rD); ++n) - { - if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) - { - if (is_bias_needed) - { - copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); - } - copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rD); ++i) - { - auto epi_value = epilogue_op(tSR_rD(i, m, n)); - if (is_bias_needed) - { - epi_value += static_cast(tSR_rBias(i, m, n)); - } - tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); - } - copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); - } - } - } - } - } - } - } - -private: - Params params; -}; - -namespace detail -{ - -template -constexpr auto get_vectorized_atomic_add_op() 
-{ - using namespace cute; - - auto constexpr MaxVecSize = size(MaxVec{}); - - if constexpr (is_same_v) - { - if constexpr (MaxVecSize >= 8) - { - return SM90_RED_ADD_NOFTZ_F16x2_V4{}; - } - else if constexpr (MaxVecSize >= 4) - { - return SM90_RED_ADD_NOFTZ_F16x2_V2{}; - } - else if constexpr (MaxVecSize >= 2) - { - return SM70_RED_ADD_NOFTZ_F16x2{}; - } - else - { - return SM70_RED_ADD_NOFTZ_F16{}; - } - } - else if constexpr (is_same_v) - { - if constexpr (MaxVecSize >= 8) - { - return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; - } - else if constexpr (MaxVecSize >= 4) - { - return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; - } - else if constexpr (MaxVecSize >= 2) - { - return SM90_RED_ADD_NOFTZ_BF16x2{}; - } - else - { - return SM90_RED_ADD_NOFTZ_BF16{}; - } - } - else - { - // non-vectorized atomic add for all other types until supported - return TypedAtomicAdd{}; - } -} - -} // namespace detail - -template -struct EpilogueMoeFusedFinalizeBuilder -{ - - // assuming cooperative kernel schedule - using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{})); - using EpilogueTile = Shape<_128, EpiTileN>; - - // Output of linear combination is ElementCompute instead of ElementD - // since we will be doing more computate on it, no need to cast yet. - using ThreadEpilogueOp - = cutlass::epilogue::thread::LinearCombination; - - using SmemLayoutAtomD - = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()); - using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator()); - using CopyAtomS2R = DefaultCopy; - using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op()); - - template - struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter - { - // We need to override this one using declaration because otherwise we double up on the smem - using TensorMapStorage = typename EpilogueOp::TensorMapStorage; - - using Base = detail::Sm90TmaWarpSpecializedAdapter; - - CUTLASS_HOST_DEVICE - Sm90TmaWarpSpecializedAdapterWithSmemStorage( - typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors) - : Base(params) - { - } - - // These functions depend on the type of TensorMapStorage - template - CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch) - { - } - - template - CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate) - { - } - }; - - using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage< - EpilogueMoeFusedFinalize>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h deleted file mode 100644 index f3c622b88a..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h +++ /dev/null @@ -1,105 +0,0 @@ 
-/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing linear combination with a maximum operation used by epilogues. 
-*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/epilogue/thread/linear_combination_generic.h" -#include "cutlass/epilogue/thread/scale_type.h" -#include "cutlass/functional.h" -#include "cutlass/half.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace epilogue -{ -namespace thread -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -__forceinline__ __device__ float copysignf_pos(float a, float b) -{ - float r; - r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); - return r; -} - -__forceinline__ __device__ float tanh_opt(float x) -{ -#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) - float const exp_val = -1.f * fabs(2 * x); - return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); -#else - return fast_tanh(x); -#endif -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -template <> -struct GELU_taylor -{ - static bool const kIsHeavy = true; - - CUTLASS_DEVICE - float operator()(float const& z) const - { - - float k0 = float(0.7978845608028654); - float k1 = float(0.044715); - - return float(cutlass::constants::half() * z - * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); - } - - using Params = LinearCombinationGenericParams; - - CUTLASS_DEVICE - float operator()(float const& scalar, Params const& params_) const - { - return this->operator()(scalar); - } -}; - -} // namespace thread -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h deleted file mode 100644 index d3d4d0a45a..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +++ /dev/null @@ -1,352 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. - - original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h - -*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/arch/memory.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_conversion.h" -#include "tensorrt_llm/common/quantization.h" - -namespace tk = tensorrt_llm::common; - -namespace cutlass -{ -namespace epilogue -{ -namespace threadblock -{ - -template -class EpilogueVisitorPerRowPerCol -{ -public: - using ThreadblockShape = ThreadblockShape_; - static int const kThreadCount = ThreadCount; - - using ScaleTileIterator = ScaleTileIterator_; - using OutputTileIterator = OutputTileIterator_; - using ElementwiseFunctor = ElementwiseFunctor_; - - static int const kIterations = OutputTileIterator::kIterations; - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - using ElementOutput = typename OutputTileIterator::Element; - using LayoutOutput = cutlass::layout::RowMajor; - using ElementAccumulator = ElementAccumulator_; - - using AlphaScaleElementType = typename ScaleTileIterator::Element; - - using ElementCompute = ElementCompute_; - using AccumulatorFragment = Array; - using ComputeFragment = Array; - using OutputVector = Array; - - static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; - static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); - - /// Argument structure - struct Arguments - { - - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t batch_stride_D; - - // - // Methods - // - Arguments() - : batch_stride_alpha(0) - , batch_stride_C(0) - , batch_stride_D(0) - { - } - - Arguments(typename ElementwiseFunctor::Params elementwise_) - : elementwise(elementwise_) - , batch_stride_alpha(0) - , batch_stride_C(0) - , batch_stride_D(0) - { - } - - Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, - int64_t batch_stride_C_, int64_t batch_stride_D_) - : elementwise(elementwise_) - , batch_stride_alpha(batch_stride_alpha_) - , batch_stride_C(batch_stride_C_) - , batch_stride_D(batch_stride_D_) - { - } - }; - - struct Params - { - - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t 
batch_stride_D; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Arguments const& args) - : elementwise(args.elementwise) - , batch_stride_alpha(args.batch_stride_alpha) - , batch_stride_C(args.batch_stride_C) - , batch_stride_D(args.batch_stride_D) - { - } - }; - - /// Shared storage - struct SharedStorage - { - }; - -private: - Params const& params_; - SharedStorage& shared_storage_; - MatrixCoord extent_; - MatrixCoord extent_real_; - ElementwiseFunctor elementwise_; - - bool const per_token_quant_; - bool const per_channel_quant_; - - AlphaScaleElementType* ptr_alpha_row_; - AlphaScaleElementType* ptr_alpha_col_; - ScaleTileIterator iterator_alpha_col_; - OutputTileIterator iterator_C_; - OutputTileIterator iterator_D_; - - AlphaScaleElementType element_alpha_row_ = 1.0f; - AlphaScaleElementType element_alpha_col_ = 1.0f; - typename ScaleTileIterator::Fragment fragment_alpha_col_; - typename OutputTileIterator::Fragment fragment_C_; - typename OutputTileIterator::Fragment fragment_D_; - - ElementAccumulator beta_; - - int column_offset_; - - MatrixCoord thread_offset_; - -public: - CUTLASS_DEVICE - EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, - cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, - typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C, - typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row, - AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, - typename OutputTileIterator::Element* ptr_D, - cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0, - cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) - : params_(params) - , shared_storage_(shared_storage) - , extent_(problem_size) - , elementwise_(params.elementwise) - , per_token_quant_(quant_option.hasPerTokenScaling()) - , per_channel_quant_(quant_option.hasPerChannelScaling()) - , ptr_alpha_row_(ptr_alpha_row) - , ptr_alpha_col_(ptr_alpha_col) - , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset) - , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset) - , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) - , extent_real_(problem_size_real) - { - beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); - - if (beta_ == ElementAccumulator()) - { - iterator_C_.clear_mask(); - } - - if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) - { - element_alpha_col_ = *ptr_alpha_col_; - } - - if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) - { - element_alpha_row_ = *ptr_alpha_row_; - } - } - - /// Helper to indicate split-K behavior - CUTLASS_DEVICE - void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme - int split_k_slices) - { ///< Total number of split-K slices - } - - /// Called to set the batch index - CUTLASS_DEVICE - void set_batch_index(int batch_idx) - { - iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); - iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); - iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); - } - - /// Called at the start of the epilogue just before iterating over accumulator slices - CUTLASS_DEVICE - void begin_epilogue() - { - if (per_channel_quant_) - { - iterator_alpha_col_.load(fragment_alpha_col_); - } - } - - /// Called at the start of one step before starting accumulator exchange - CUTLASS_DEVICE - void begin_step(int step_idx) - { - fragment_D_.clear(); - fragment_C_.clear(); - - if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) - { - iterator_C_.load(fragment_C_); - ++iterator_C_; - } - } - - /// Called at the start of a row - CUTLASS_DEVICE - void begin_row(int row_idx) - { - // load alpha_row in begin_step only when per token(row) scaling is used - if (per_token_quant_) - { - int thread_offset_row - = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); - - arch::global_load( - element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); - } - } - - /// Called after accumulators have been exchanged for each accumulator vector - CUTLASS_DEVICE - void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) - { - - NumericArrayConverter source_converter; - - ComputeFragment result = source_converter(accum); - if (per_channel_quant_) - { - ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; - result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); - } - else - { - result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); - } - - // Convert to the output - NumericArrayConverter output_converter; - OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; - output = output_converter(result); - } - - /// Called at the end of a row - CUTLASS_DEVICE - void end_row(int row_idx) {} - - /// Called after all accumulator elements have been visited - CUTLASS_DEVICE - void end_step(int step_idx) - { - - iterator_D_.store(fragment_D_); - ++iterator_D_; - } - - /// Called after all steps have been completed - CUTLASS_DEVICE - void end_epilogue() {} - -private: - CUTLASS_DEVICE - ComputeFragment per_token_channel_scale_accumulator_( - ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) - { - - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) - { - result[i] = accum[i] * (scale_col[i] * scale_row); - } - - return result; - } - - CUTLASS_DEVICE - ComputeFragment per_token_scale_accumulator_( - ComputeFragment const& accum, AlphaScaleElementType const& 
scale_col, AlphaScaleElementType const& scale_row) - { - - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) - { - result[i] = accum[i] * (scale_col * scale_row); - } - - return result; - } -}; - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h deleted file mode 100644 index 6f26d79017..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +++ /dev/null @@ -1,282 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. 
- - original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h - -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/platform/platform.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/epilogue/thread/linear_combination_clamp.h" -#include "cutlass/epilogue/thread/linear_combination_gelu.h" -#include "cutlass/epilogue/thread/linear_combination_hardswish.h" -#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -#include "cutlass/epilogue/thread/linear_combination_relu.h" -#include "cutlass/epilogue/thread/linear_combination_relu0.h" -#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" - -#include "cutlass/epilogue/thread/conversion_op.h" -#include "cutlass/epilogue/thread/reduction_op.h" - -#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" - -#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" -#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" -#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" - -#include "cutlass/epilogue/threadblock/epilogue.h" -#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" - -#include "cutlass/layout/permute.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace epilogue -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -namespace detail -{ - -/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. -template -struct DefaultIteratorsTensorOp -{ - using WarpTileIterator - = cutlass::epilogue::warp::TileIteratorTensorOpMixed; - - using SharedLoadIterator - = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; - - static int const kFragmentsPerIteration = 2; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load output tile from shared memory in epilogue. 
-/// -/// Satisfies: ReadableTileIterator -/// -template -class SharedLoadIteratorMixed -{ -public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = int32_t; - - using Layout = layout::RowMajor; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - - static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; - - static int const kThreads = ThreadMap::kThreads; - - /// Fragment object - using Fragment = Array; - - /// Memory access size - using AccessType = AlignedArray; - - /// Vector type used for SMEM loads - using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), - const_min(16, kAlignment)>; - - static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; - -private: - // - // Data members - // - - /// Byte-level pointer - LoadType const* pointers_[kLoadsPerAccess]; - - /// Stride along adjacent rows in units of LoadType - int stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - SharedLoadIteratorMixed(TensorRef ref, int thread_idx) - : stride_((ref.stride(0) / LoadType::kElements)) - { - - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); - - // Initialize pointers - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] = reinterpret_cast(ref.data()); - - int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; - int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; - - col_idx += (bank_offset + i) % kLoadsPerAccess; - - pointers_[i] += thread_offset.row() * stride_ + col_idx; - } - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] += pointer_offset / LoadType::kElements; - } - } - - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] - += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const - { - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) - { - - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) - { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) - { - - int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ - + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_ - + pointer_offset / LoadType::kElements; - - int frag_row_idx - = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - LoadType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) - { - - int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kLoadsPerAccess; ++v) - { - - int vector_idx - = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * 
kLoadsPerAccess); - - LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; - - frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; - } - } - } - } - } - } - - /// Loads a fragment - CUTLASS_DEVICE - void load(Fragment& frag) const - { - - load_with_pointer_offset(frag, 0); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h deleted file mode 100644 index 233d633a82..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h +++ /dev/null @@ -1,141 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * @file epilogue_helpers.h - * - * This file includes types for the epilogues. The empty structs exist so we can signal to template - * code the type of epilogue we want to run, and let the underlying code specify the details such as - * element types, accumulator type and elements per vector access. 
- * - */ - -#pragma once - -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/epilogue/thread/linear_combination_generic.h" -#include "cutlass/epilogue/thread/linear_combination_relu.h" -#include "cutlass/epilogue/thread/linear_combination_silu.h" -#include "cutlass_extensions/epilogue/thread/fused_activations.h" -#include - -namespace tensorrt_llm -{ -namespace cutlass_extensions -{ - -struct EpilogueOpBiasSilu -{ -}; - -struct EpilogueOpBiasReLU -{ -}; - -struct EpilogueOpBiasFtGelu -{ -}; - -struct EpilogueOpBias -{ -}; - -struct EpilogueOpDefaultSilu -{ -}; - -struct EpilogueOpDefaultReLU -{ -}; - -struct EpilogueOpDefaultFtGelu -{ -}; - -struct EpilogueOpDefault -{ -}; - -template -struct Epilogue -{ - static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); -}; - -constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationSilu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationRelu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationGeneric; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombination; -}; - -constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationSilu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationRelu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationGeneric; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombination; -}; - -} // namespace cutlass_extensions -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl deleted file mode 100644 index 593eca06e3..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl +++ /dev/null @@ -1,221 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/arch/mma.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/gemm/collective/builders/sm90_common.inl" - -// SM90 Collective Builders should be used only starting CUDA 12.0 -#if (__CUDACC_VER_MAJOR__ >= 12) -#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail -{ - -// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. -template -constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout stage_count) -{ - // 32 bytes to account for barriers etc. 
- constexpr int stage_barrier_bytes = 32; - constexpr int a_bits = static_cast(sizeof_bits::value); - constexpr int b_bits = static_cast(sizeof_bits::value); - constexpr int stage_bytes = [&]() -> int - { - if constexpr (SwapAB) - { - return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 - + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes; - } - else - { - return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 - + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes; - } - }(); - - return (CapacityBytes - carveout_bytes) / stage_bytes; -} - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA_TMA_WS_SS -template class Activation, bool SwapAB> -struct CollectiveBuilderGated - || cute::is_same_v - || cute::is_same_v - || cute::is_same_v) &¬ detail:: - is_use_rmem_A()>> -{ - static_assert(is_static::value); - static_assert(is_static::value); -#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); -#endif - static_assert(detail::is_aligned(), - "Should meet TMA alignment requirement\n"); - - static constexpr bool IsArrayOfPointersGemm - = (cute::is_same_v); - static constexpr bool IsFP8Input = detail::is_input_fp8(); - static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), - "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); - - // For fp32 types, map to tf32 MMA value type - using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; - using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; - - static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - - using AtomLayoutMNK = cute::conditional_t - || IsArrayOfPointersGemm, - Layout>, Layout>>; - - using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector(), - AtomLayoutMNK{})); - - using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); - using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); - - using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - - static constexpr int PipelineStages - = detail::compute_stage_count_or_override_gated(StageCountType{}); - using DispatchPolicy = cute::conditional_t, - /* For FP8 use a separate mainloop compared to other datatypes */ - cute::conditional_t, - MainloopSm90TmaGmmaWarpSpecialized>>; - - using SmemCopyAtomA = void; - using SmemCopyAtomB = void; - - using CollectiveOp = CollectiveMmaGated, - ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA_TMA_WS_FP8_FAST_ACCUM_SS -template class Activation, bool SwapAB> -struct CollectiveBuilderGated - || cute::is_same_v - || cute::is_same_v - || cute::is_same_v>> -{ - static_assert(is_static::value); - 
static_assert(is_static::value); - static_assert(detail::is_aligned(), - "Not meet TMA alignment requirement yet\n"); - static_assert( - detail::is_input_fp8(), "Only FP8 datatypes are compatible with these kernel schedules\n"); - // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder - static_assert(!detail::is_use_rmem_A(), - "Not supported for fp8 non-TN warp specialized kernels yet\n"); -#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); -#endif - - static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - - static constexpr bool IsArrayOfPointersGemm - = (cute::is_same_v); - using AtomLayoutMNK - = cute::conditional_t - || IsArrayOfPointersGemm, - Layout>, Layout>>; - - using TiledMma = decltype(cute::make_tiled_mma( - cute::GMMA::ss_op_selector(), - AtomLayoutMNK{})); - - using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); - using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); - - using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - - static constexpr int PipelineStages - = detail::compute_stage_count_or_override_gated(StageCountType{}); - using DispatchPolicy = cute::conditional_t, - MainloopSm90TmaGmmaWarpSpecialized>; - - using SmemCopyAtomA = void; - using SmemCopyAtomB = void; - - using CollectiveOp = CollectiveMmaGated, - ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp deleted file mode 100644 index 2f2422c991..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp +++ /dev/null @@ -1,58 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp" - -namespace cutlass::gemm::collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template class Activation, - bool SwapAB = false, class Enable = void> -struct CollectiveBuilderGated -{ - static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl" -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp deleted file mode 100644 index d850f36df5..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp +++ /dev/null @@ -1,59 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/detail/dependent_false.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template class Activation, bool SwapAB = false> -struct CollectiveMmaGated -{ - static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp" -#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp" -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp deleted file mode 100644 index dcba6ee637..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp +++ /dev/null @@ -1,642 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cute/arch/cluster_sm90.hpp" -#include "cute/arch/copy_sm90.hpp" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" - -#include "cute/algorithm/functional.hpp" -#include "cute/algorithm/gemm.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/numeric/arithmetic_tuple.hpp" -#include "cute/tensor_predicate.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/trace.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective -{ -using namespace cute; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// WarpSpecialized Mainloop -template class Activation_, bool SwapAB_> -struct CollectiveMmaGated, TileShape_, - ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, - GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> -{ - static constexpr bool isGated = true; - static constexpr bool SwapAB = SwapAB_; - - // - // Type Aliases - // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized; - using TileShape = TileShape_; - using ElementA = ElementA_; - using StrideA = StrideA_; - using ElementB = ElementB_; - using StrideB = StrideB_; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - using Activation = Activation_; - - using ElementAux = cute::conditional_t; - using ValTypeAux = cute::conditional_t; - - using MainloopPipeline = cutlass::PipelineTmaAsync; - using PipelineState = cutlass::PipelineState; - - using PipelineParams = typename MainloopPipeline::Params; - - static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert( - (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert( - (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert( - (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly 
divide tile shape."); - static_assert( - (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - // Tile along modes in a way that maximizes the TMA box size. - using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - using SmemLayoutAux = cute::conditional_t; - - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); - static_assert(cute::is_base_of::value - && cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert( - cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert( - cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - - // TMA converts f32 input to tf32 when copying from GMEM to SMEM - // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. - static constexpr bool ConvertF32toTF32A = cute::is_same_v; - static constexpr bool ConvertF32toTF32B = cute::is_same_v; - using InternalElementA = cute::conditional_t>>; - using InternalElementB = cute::conditional_t>>; - using InternalElementAux = cute::conditional_t; - - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> - { - cute::array_aligned> smem_A; - cute::array_aligned> smem_B; - cute::array_aligned> smem_Aux; - } tensors; - - using PipelineStorage = typename MainloopPipeline::SharedStorage; - PipelineStorage pipeline; - }; - - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Host side kernel arguments - struct Arguments - { - ElementA const* ptr_A; - StrideA dA; - ElementB const* ptr_B; - StrideB dB; - float scale_d0 = 1.0f; - float scale_d1 = 1.0f; - uint32_t mma_promotion_interval = 4; - }; - - // Device side kernel params - struct Params - { - // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), - SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any - // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), - SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - using TMA_Aux = cute::conditional_t; - TMA_A tma_load_a; - TMA_B tma_load_b; - TMA_Aux tma_load_aux; - float scale_d0 = 1.0f; - float scale_d1 = 1.0f; - }; - - // - // Methods - // - - template - static constexpr Params to_underlying_arguments( - ProblemShape const& problem_shape, Arguments const& args, void* workspace) - { - (void) workspace; - - // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto 
problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - auto ptr_A = reinterpret_cast(args.ptr_A); - auto ptr_B = reinterpret_cast(args.ptr_B); - - Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); - Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, - SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, - SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - - if constexpr (SwapAB) - { - auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); - Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); - typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, - SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; - } - else - { - auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); - Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); - typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, - SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; - } - } - - template - static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) - { - constexpr int tma_alignment_bits = 128; - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - bool implementable = true; - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); - - if (!implementable) - { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - return implementable; - } - - static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytes - = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 - + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 - + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) - / 8; - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - 
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); - } - - /// Set up the data needed by this collective for load and mma. - /// Returns a tuple of tensors. The collective and the kernel layer have the contract - /// Returned tuple must contain at least two elements, with the first two elements being: - /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) - /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) - /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) - /// The rest of the tensors can be specified as needed by this collective. - template - CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const - { - using X = Underscore; - // Separate out problem shape for convenience - auto [M, N, K, L] = problem_shape_MNKL; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) - Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) - - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) - - if constexpr (SwapAB) - { - Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) - Tensor gAux_xkl - = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); - } - else - { - Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) - Tensor gAux_xkl - = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) - return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Producer Perspective - template - CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& load_inputs, BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) - { - int lane_predicate = cute::elect_one_sync(); - - if (lane_predicate) - { - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); - - // - // Prepare the TMA loads for A and B - // - - constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); - uint2 cluster_local_block_id - = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - Tensor gAux_xkl = get<2>(load_inputs); - - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - auto block_tma_aux = SwapAB ? 
mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) - : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); - // Partition the inputs based on the current block coordinates. - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) - Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); - - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) - - Tensor tAuxgAux = block_tma_aux.partition_S(gAux); - Tensor tAuxsAux = block_tma_aux.partition_D(sAux); - - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; - uint16_t mcast_mask_aux = 0; - - // Issue TmaLoads - // Maps the tile -> block, value - if constexpr (cute::is_same_v) - { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) - { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); - } - } - - if constexpr (cute::is_same_v) - { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) - { - mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); - } - } - - if constexpr (SwapAB) - { - mcast_mask_aux = mcast_mask_a; - } - else - { - mcast_mask_aux = mcast_mask_b; - } - - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - // LOCK smem_pipe_write for _writing_ - pipeline.producer_acquire(smem_pipe_write); - - // - // Copy gmem to smem for *k_tile_iter - // - - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); - - int write_stage = smem_pipe_write.index(); - copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), - tAsA(_, _, _, write_stage)); - copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), - tBsB(_, _, _, write_stage)); - copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), - tAuxsAux(_, _, _, write_stage)); - ++k_tile_iter; - - // Advance smem_pipe_write - ++smem_pipe_write; - } - } - } - - /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) - { - int lane_predicate = cute::elect_one_sync(); - - // Issue the epilogue waits - if (lane_predicate) - { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all - * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was - * still inverted from make_producer_start_state - */ - pipeline.producer_tail(smem_pipe_write); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Consumer Perspective - template - CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, - FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, - Params const& mainloop_params) - { - static_assert(is_rmem::value, "C tensor must be rmem 
resident."); - static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); - - // - // Define C accumulators and A/B partitioning - // - - TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - - auto tCsAux = [&]() -> auto - { - if constexpr (SwapAB) - { - return thread_mma.partition_A(sAux); - } - else - { - return thread_mma.partition_B(sAux); - } - }(); - auto tCrAux = [&]() -> auto - { - if constexpr (SwapAB) - { - return thread_mma.make_fragment_A(tCsAux); - } - else - { - return thread_mma.make_fragment_B(tCsAux); - } - }(); - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - if constexpr (SwapAB) - { - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE - } - else - { - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE - } - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE - - // - // PIPELINED MAIN LOOP - // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); - - // We release buffers to producer warps(dma load) with some mmas in flight - PipelineState smem_pipe_release = smem_pipe_read; - - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - warpgroup_fence_operand(accum0); - warpgroup_fence_operand(accum1); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - int read_stage = smem_pipe_read.index(); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 
1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) - { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); - if constexpr (SwapAB) - { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); - } - else - { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); - } - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - - warpgroup_commit_batch(); - - ++smem_pipe_read; - } - - warpgroup_fence_operand(accum0); - warpgroup_fence_operand(accum1); - // Mainloop GMMAs - k_tile_count -= prologue_mma_count; - - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // - // Compute on k_tile - // - - int read_stage = smem_pipe_read.index(); - warpgroup_fence_operand(accum0); - warpgroup_fence_operand(accum1); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) - { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); - if constexpr (SwapAB) - { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); - } - else - { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); - } - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_wait(); - warpgroup_fence_operand(accum0); - warpgroup_fence_operand(accum1); - - // UNLOCK smem_pipe_release, done _computing_ on it - pipeline.consumer_release(smem_pipe_release); - - // Advance smem_pipe_read and smem_pipe_release - ++smem_pipe_read; - ++smem_pipe_release; - } - - warpgroup_fence_operand(accum0); - warpgroup_fence_operand(accum1); - } - - /// Perform a Consumer Epilogue to release all buffers - CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) - { - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - k_tile_count -= prologue_mma_count; - - smem_pipe_release.advance(k_tile_count); - - // Wait on all GMMAs to complete - warpgroup_wait<0>(); - - for (int count = 0; count < prologue_mma_count; ++count) - { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - ++smem_pipe_release; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp deleted file mode 100644 index 72c1adf293..0000000000 --- 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp +++ /dev/null @@ -1,665 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ - -#pragma once - -#include "cute/arch/cluster_sm90.hpp" -#include "cute/arch/copy_sm90.hpp" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" - -#include "cute/algorithm/functional.hpp" -#include "cute/algorithm/gemm.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/numeric/arithmetic_tuple.hpp" -#include "cute/tensor_predicate.hpp" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/gemm/collective/fp8_accumulation.hpp" -#include "cutlass/trace.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective -{ -using namespace cute; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// WarpSpecialized Mainloop -template class Activation_, bool SwapAB_> -struct CollectiveMmaGated, TileShape_, - ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, - GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> -{ - static constexpr bool isGated = true; - static constexpr bool SwapAB = SwapAB_; - - // - // Type Aliases - // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8; - using TileShape = TileShape_; - using ElementA = ElementA_; - using StrideA = StrideA_; - using ElementB = ElementB_; - using StrideB = StrideB_; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - using Activation = Activation_; - - using ElementAux = cute::conditional_t; - using ValTypeAux = cute::conditional_t; - - using MainloopPipeline = cutlass::PipelineTmaAsync; - using PipelineState = cutlass::PipelineState; - - using PipelineParams = typename MainloopPipeline::Params; - - static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert( - (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert( - (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert( - (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert( - (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - // Tile along modes in a way that maximizes the TMA box size. 
- using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - using SmemLayoutAux = cute::conditional_t; - - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value - && cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert( - cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert( - cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> - { - cute::array_aligned> smem_A; - cute::array_aligned> smem_B; - cute::array_aligned> smem_Aux; - } tensors; - - using PipelineStorage = typename MainloopPipeline::SharedStorage; - PipelineStorage pipeline; - }; - - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Host side kernel arguments - struct Arguments - { - ElementA const* ptr_A; - StrideA dA; - ElementB const* ptr_B; - StrideB dB; - float scale_d0 = 1.0f; - float scale_d1 = 1.0f; - uint32_t mma_promotion_interval = 4; - }; - - // Device side kernel params - struct Params - { - // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), - SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any - // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), - SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - using TMA_Aux = cute::conditional_t; - TMA_A tma_load_a; - TMA_B tma_load_b; - TMA_Aux tma_load_aux; - float scale_d0 = 1.0f; - float scale_d1 = 1.0f; - uint32_t mma_promotion_interval = 4; - }; - - // - // Methods - // - - template - static constexpr Params to_underlying_arguments( - ProblemShape const& problem_shape, Arguments const& args, void* workspace) - { - (void) workspace; - - // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - auto ptr_A = reinterpret_cast(args.ptr_A); - auto ptr_B = reinterpret_cast(args.ptr_B); - - Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); - Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, - SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - typename Params::TMA_B tma_load_b 
= make_tma_copy(GmemTiledCopyB{}, tensor_b, - SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - if constexpr (SwapAB) - { - auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); - Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); - typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, - SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; - } - else - { - auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); - Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); - typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, - SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; - } - } - - template - static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) - { - constexpr int tma_alignment_bits = 128; - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - bool implementable = true; - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); - /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA - * instructions. */ - implementable = implementable && (args.mma_promotion_interval % 4 == 0); - - if (!implementable) - { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - return implementable; - } - - static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytes - = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 - + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 - + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) - / 8; - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); - } - - /// Set up the data needed by this collective for load and mma. - /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract - /// Returned tuple must contain at least two elements, with the first two elements being: - /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) - /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) - /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) - template - CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const - { - using X = Underscore; - // Separate out problem shape for convenience - auto [M, N, K, L] = problem_shape_MNKL; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) - Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) - - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) - - if constexpr (SwapAB) - { - Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) - Tensor gAux_xkl - = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); - } - else - { - Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) - Tensor gAux_xkl - = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) - return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Producer Perspective - template - CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& load_inputs, BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) - { - int lane_predicate = cute::elect_one_sync(); - - if (lane_predicate) - { - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); - - // - // Prepare the TMA loads for A and B - // - - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id - = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - Tensor gAux_xkl = get<2>(load_inputs); - - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) - : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); - - // Partition the inputs based on the current block coordinates. 
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) - Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); - - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) - - Tensor tAuxgAux = block_tma_aux.partition_S(gAux); - Tensor tAuxsAux = block_tma_aux.partition_D(sAux); - - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; - uint16_t mcast_mask_aux = 0; - - // Issue TmaLoads - // Maps the tile -> block, value - if constexpr (cute::is_same_v) - { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) - { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); - } - } - - if constexpr (cute::is_same_v) - { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) - { - mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); - } - } - - if constexpr (SwapAB) - { - mcast_mask_aux = mcast_mask_a; - } - else - { - mcast_mask_aux = mcast_mask_b; - } - - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - // LOCK smem_pipe_write for _writing_ - pipeline.producer_acquire(smem_pipe_write); - - // - // Copy gmem to smem for *k_tile_iter - // - - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); - - int write_stage = smem_pipe_write.index(); - copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), - tAsA(_, _, _, write_stage)); - copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), - tBsB(_, _, _, write_stage)); - copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), - tAuxsAux(_, _, _, write_stage)); - ++k_tile_iter; - - // Advance smem_pipe_write - ++smem_pipe_write; - } - } - } - - /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) - { - int lane_predicate = cute::elect_one_sync(); - - // Issue the epilogue waits - if (lane_predicate) - { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all - * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was - * still inverted from make_producer_start_state - */ - pipeline.producer_tail(smem_pipe_write); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Consumer Perspective - template - CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, - FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, - Params const& mainloop_params) - { - - static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - 
static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); - - // - // Define C accumulators and A/B partitioning - // - - TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - - auto tCsAux = [&]() -> auto - { - if constexpr (SwapAB) - { - return thread_mma.partition_A(sAux); - } - else - { - return thread_mma.partition_B(sAux); - } - }(); - auto tCrAux = [&]() -> auto - { - if constexpr (SwapAB) - { - return thread_mma.make_fragment_A(tCsAux); - } - else - { - return thread_mma.make_fragment_B(tCsAux); - } - }(); - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - if constexpr (SwapAB) - { - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE - } - else - { - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE - } - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE - - // - // PIPELINED MAIN LOOP - // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); - - // We release buffers to producer warps(dma load) with some mmas in flight - PipelineState smem_pipe_release = smem_pipe_read; - - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA)); - GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA)); - warpgroup_fence_operand(accumulation0()); - warpgroup_fence_operand(accumulation1()); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - if (accumulation0.prepare_if_needed()) - { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - - int read_stage = smem_pipe_read.index(); - 
warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) - { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm( - tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); - if constexpr (SwapAB) - { - cute::gemm( - tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); - } - else - { - cute::gemm( - tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); - } - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - accumulation0.promote_if_needed(); - accumulation1.promote_if_needed(); - - ++smem_pipe_read; - } - - warpgroup_fence_operand(accumulation0()); - warpgroup_fence_operand(accumulation1()); - // Mainloop GMMAs - k_tile_count -= prologue_mma_count; - - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // - // Compute on k_tile - // - - int read_stage = smem_pipe_read.index(); - - if (accumulation0.prepare_if_needed()) - { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - - warpgroup_fence_operand(accumulation0()); - warpgroup_fence_operand(accumulation1()); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) - { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm( - tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); - if constexpr (SwapAB) - { - cute::gemm( - tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); - } - else - { - cute::gemm( - tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); - } - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_wait(); - warpgroup_fence_operand(accumulation0()); - warpgroup_fence_operand(accumulation1()); - - accumulation0.promote_if_needed(); - accumulation1.promote_if_needed(); - - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - - // Advance smem_pipe_read and smem_pipe_release - ++smem_pipe_read; - ++smem_pipe_release; - } - - accumulation0.promote_residue_if_needed(); - accumulation1.promote_residue_if_needed(); - - warpgroup_fence_operand(accumulation0()); - warpgroup_fence_operand(accumulation1()); - } - - /// Perform a Consumer Epilogue to release all buffers - CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) - { - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - k_tile_count -= prologue_mma_count; - - smem_pipe_release.advance(k_tile_count); - - // Wait on all GMMAs to complete - warpgroup_wait<0>(); - - for (int count = 0; count < prologue_mma_count; ++count) - { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - ++smem_pipe_release; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - 
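// For reference, the FP8 mainloop above accumulates each pair of GMMAs through
// CUTLASS's GmmaFP8Accumulation helper (fp8_accumulation.hpp): partial results
// are collected for mma_promotion_interval K iterations and then "promoted"
// (added) into the full-precision accumulator, with a final residue promotion
// after the loop. The standalone sketch below illustrates only that promotion
// pattern; `k_tiles`, `promotion_interval`, `partial`, and `promoted` are
// illustrative names, not the CUTLASS API.
#include <cstdio>
#include <vector>

int main() {
    const int k_tiles = 10;            // number of K iterations (assumed)
    const int promotion_interval = 4;  // promote every 4 iterations (assumed)

    std::vector<float> contributions(k_tiles, 1.0f);  // stand-in for per-tile MMA results

    float promoted = 0.0f;  // full-precision running total (plays the role of accum0/accum1)
    float partial = 0.0f;   // scratch accumulator, reset after every promotion

    for (int k = 0; k < k_tiles; ++k) {
        partial += contributions[k];              // accumulate this K tile
        if ((k + 1) % promotion_interval == 0) {  // promote_if_needed()
            promoted += partial;
            partial = 0.0f;                       // prepare_if_needed() re-zeroes the scratch
        }
    }
    promoted += partial;  // promote_residue_if_needed()

    std::printf("promoted total = %f\n", promoted);  // expect 10.0
    return 0;
}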
-///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h deleted file mode 100644 index 2edd5a228b..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h +++ /dev/null @@ -1,438 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and - batched array variants. -*/ - -#pragma once - -// #include - -#include "cutlass/arch/arch.h" -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_universal.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" - -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "cutlass/gemm/kernel/default_gemm_universal.h" - -#include "cutlass/trace.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace device -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/* - This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) - It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs - and SmoothQuant. 
The newer device layer is not compatible with these older kernel level APIs. - - Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support - that feature at the moment. - */ - -template -class GemmUniversalBaseCompat -{ -public: - using GemmKernel = GemmKernel_; - using ThreadblockShape = typename GemmKernel::Mma::Shape; - - using ElementA = typename GemmKernel::ElementA; - using LayoutA = typename GemmKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = GemmKernel::kTransformA; - - using ElementB = typename GemmKernel::ElementB; - using LayoutB = typename GemmKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = GemmKernel::kTransformB; - - using ElementC = typename GemmKernel::ElementC; - using LayoutC = typename GemmKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - - using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; - - using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; - using Operator = typename GemmKernel::Operator; - - /// Argument structure - using Arguments = typename GemmKernel::Arguments; - -protected: - /// Kernel parameters object - typename GemmKernel::Params params_; - -protected: - /// Private helper to obtain the grid dimensions with fix-up for split-K - static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) - { - - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); - - gemm_k_size = args.problem_size.k(); - - if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - int const kAlignK - = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); - - gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - - if (gemm_k_size) - { - grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); - } - } - } - -public: - /// Constructs the GEMM. - GemmUniversalBaseCompat() {} - - /// Determines whether the GEMM can execute the given problem. 
- static Status can_implement(Arguments const& args) - { - - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - - ThreadblockSwizzle threadblock_swizzle; - dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - - uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); - - if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) - { - - return Status::kErrorInvalidProblem; - } - - return GemmKernel::can_implement(args); - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); - - size_t workspace_bytes = 0; - - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - - if (args.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - // Split-K parallel always requires a temporary workspace - workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); - } - else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) - { - - // Serial split-K only requires a temporary workspace if the number of partitions along the - // GEMM K dimension is greater than one. - workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); - } - - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); - - return workspace_bytes; - } - - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); - - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - - CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" - << " result = {" << result << "}"); - - return result; - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); - - int max_active_blocks = -1; - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - if (smem_size <= (48 << 10)) - { - - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); - - if (result == cudaSuccess) - { - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - } - else - { - - // Query assuming zero shared memory then compute occupancy limit based on SMEM - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); - - if (result != cudaSuccess) - { - - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - - return -1; - } - - if (smem_capacity < 0) - { - int device_idx = 0; - result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) - { - return -1; - } - - cudaDeviceProp properties; - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) - { - return -1; - } - - 
smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); - } - - int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); - - CUTLASS_TRACE_HOST(" occupancy: " << occupancy); - - return occupancy; - } - - CUTLASS_TRACE_HOST(" returning internal error"); - - return -1; - } - - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); - - size_t workspace_bytes = get_workspace_size(args); - - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - if (workspace_bytes) - { - - if (!workspace) - { - CUTLASS_TRACE_HOST(" error: device workspace must not be null"); - - return Status::kErrorWorkspaceNull; - } - - if (args.mode == GemmUniversalMode::kGemm) - { - CUTLASS_TRACE_HOST(" clearing device workspace"); - cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); - - return Status::kErrorInternal; - } - } - } - - // Get CUDA grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - - // Initialize the Params structure - params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); - - // Specify shared memory capacity for kernel. - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) - { - cudaError_t result - = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); - - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - params_.update(args, workspace); - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); - - // - // Configure grid and block dimensions - // - - ThreadblockSwizzle threadblock_swizzle; - - dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); - dim3 block(GemmKernel::kThreadCount, 1, 1); - - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - // - // Launch kernel - // - - CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); - - // Launch - cutlass::Kernel<<>>(params_); - - // - // Query for errors - // - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) - { - return run(stream); - } - - /// Runs the kernel using initialized state. 
- Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { - - Status status = initialize(args, workspace, stream); - - if (status == Status::kSuccess) - { - status = run(stream); - } - - return status; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h deleted file mode 100644 index bfd3666b9c..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h +++ /dev/null @@ -1,542 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! 
- \file - \brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h -*/ - -#pragma once - -#include -#include -#include - -#include "cutlass/arch/arch.h" -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_universal.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" - -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "cutlass/gemm/kernel/default_gemm_universal.h" - -#include "cutlass/trace.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace device -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, - int64_t* splitk_buffer_offsets) -{ - // in_tensor: [problem_idx, k_partition, hidden_size] - // Note that different requests of in_tensor might have different hidden_size (=m*n) - // so, we need to use splitk_buffer_offsets. - // out_tensor: problem_idx * [hidden_size] - - int const problem_idx = blockIdx.y; - GemmCoord problem = problem_sizes[problem_idx]; - int const hidden_size = problem.m() * problem.n(); - const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; - T_OUT* out_tensor_ = out_tensor[problem_idx]; - - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) - { - float sum = 0.0f; - for (int k_idx = 0; k_idx < splitk; k_idx++) - { - sum += (float) in_tensor_[k_idx * hidden_size + i]; - } - out_tensor_[i] = (T_OUT) (sum); - } -} - -/// GEMM Grouped -template -class BaseSplitkGrouped -{ -public: - using BaseKernel = BaseKernel_; - - using ElementA = typename BaseKernel::ElementA; - using LayoutA = typename BaseKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = BaseKernel::kTransformA; - static int const kAlignmentA = BaseKernel::kAlignmentA; - - using ElementB = typename BaseKernel::ElementB; - using LayoutB = typename BaseKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = BaseKernel::kTransformB; - static int const kAlignmentB = BaseKernel::kAlignmentB; - - using ElementC = typename BaseKernel::ElementC; - using LayoutC = typename BaseKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - static int const kAlignmentC = BaseKernel::kAlignmentC; - - using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; - - using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; - - using Operator = typename BaseKernel::Operator; - using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename WarpMmaOperator::MathOperator; - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - using ThreadblockShape = typename BaseKernel::Mma::Shape; - using WarpShape = typename BaseKernel::WarpShape; - using InstructionShape = typename BaseKernel::InstructionShape; - static int const kStages = BaseKernel::Mma::kStages; - - /// Argument structure - using Arguments = typename BaseKernel::Arguments; - - using 
ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; - -protected: - /// Kernel parameters object - typename BaseKernel::Params gemm_params_; - -private: - /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) - { - int32_t tiles = 0; - for (int32_t i = 0; i < problem_count; ++i) - { - cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; - BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); - tiles += problem_tile_count(problem); - } - return tiles; - } - - /// Copy from `data` to `workspace` - Status copy_to_workspace(void* workspace, void* data, size_t bytes) - { - cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); - if (cuda_error != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - cuda_error = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); - return Status::kErrorInternal; - } - - return Status::kSuccess; - } - - /// Precomputes scheduling information for the grouped GEMM - Status precompute(Arguments const& args, int32_t tile_count, void* workspace) - { - size_t workspace_bytes = get_workspace_size(args); - std::vector host_workspace(workspace_bytes); - BaseKernel::ProblemVisitor::host_precompute( - args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data()); - return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); - } - - /// Reorder `data` according to `indices` - template - static void reorder_array(T* data, std::vector const& indices) - { - // For now, simply create a copy of the data and then copy over to the original. - std::vector copy(indices.size()); - for (size_t i = 0; i < indices.size(); ++i) - { - copy.at(i) = data[indices[i]]; - } - - memcpy(data, copy.data(), indices.size() * sizeof(T)); - } - -public: - /// Constructs the GEMM. - BaseSplitkGrouped() {} - - /// Determines whether the GEMM can execute the given problem. 
- static Status can_implement(Arguments const& args) - { - - return BaseKernel::can_implement(args); - } - - /// Get the number of tiles in a problem - static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) - { - auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); - return BaseKernel::ProblemVisitor::tile_count(grid); - } - - /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(Arguments const& args) - { - if (args.host_problem_sizes == nullptr) - { - CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); - return -1; - } - - return group_tile_count(args.host_problem_sizes, args.problem_count); - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) - { - size_t total_mn = 0; - for (int i = 0; i < args.problem_count; i++) - { - total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); - } - size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( - args.host_problem_sizes, args.problem_count, args.threadblock_count); - } - return workSpaceSize; - } - - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) - { - - return dim3(args.threadblock_count, 1, 1); - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { - - CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); - - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - cudaError_t result; - if (smem_size > (48 << 10)) - { - result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); - return -1; - } - } - - int max_active_blocks = -1; - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, BaseKernel::kThreadCount, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - return -1; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - - /// Sorts each pointer passed in according to the indices that sort - /// `problem_sizes_ptr` in descending order of problem-K dimension. 
- static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, - int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, - int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) - { - std::vector indices(problem_count); - std::iota(indices.begin(), indices.end(), 0); - std::stable_sort(indices.begin(), indices.end(), - [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); }); - - reorder_array(problem_sizes_ptr, indices); - reorder_array(lda_host_ptr, indices); - reorder_array(ldb_host_ptr, indices); - reorder_array(ldc_host_ptr, indices); - reorder_array(ldd_host_ptr, indices); - reorder_array(offset_A_ptr, indices); - reorder_array(offset_B_ptr, indices); - reorder_array(offset_C_ptr, indices); - reorder_array(offset_D_ptr, indices); - } - - /// Computes the number of threadblocks to launch for the grouped kernel - static int sufficient( - cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) - { - // Determine the number of blocks that would be launched to fill up a single - // wave on the GPU with each SM having maximum occupancy. - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); - return 0; - } - - int multiprocessor_count; - result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); - return 0; - } - - bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); - if (override_sm_count) - { - available_sm_count = multiprocessor_count; - } - - int max_active_blocks = maximum_active_blocks(); - if (max_active_blocks <= 0) - { - return 0; - } - - int occupancy_based_block_count = available_sm_count * max_active_blocks; - - if (problem_sizes_ptr == nullptr || problem_count == 0) - { - return occupancy_based_block_count; - } - - int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); - - // If the group contains a single problem, launching the exact number of - // threadblocks needed to cover the problem minimizes the work performed - // per threadblock in finding the next tile to compute. We return total_tiles - // unless the user has provided the SM count. - if (problem_count == 1 && override_sm_count) - { - return total_tiles; - } - - // Choose between the full wave of threadblocks and the tile count. If there - // are fewer tiles in the group than threadblocks in the full wave, only - // some threadblocks will be assigned tiles. Those threadblocks - // which are not assigned tiles still need to perform the work of iterating through - // problem sizes to determine that they have no work to do. This competes for cycles - // with those threadblocks that are assigned tiles to compute. - return std::min(total_tiles, occupancy_based_block_count); - } - - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { - - CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " - << workspace << ", stream: " << (stream ? 
"non-null" : "null")); - - // Workspace - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - int32_t tile_count = group_tile_count(args); - Status status = precompute(args, tile_count, workspace); - if (status != Status::kSuccess) - { - return status; - } - - gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); - } - else - { - gemm_params_ = typename BaseKernel::Params(args, workspace); - } - - // Specify shared memory capacity for kernel. - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) - { - cudaError_t result - = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) - { - - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - int32_t tile_count = group_tile_count(args); - Status status = precompute(args, tile_count, workspace); - if (status != Status::kSuccess) - { - return status; - } - - gemm_params_.update(args, workspace, tile_count); - } - else - { - gemm_params_.update(args, workspace); - } - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { - if (!gemm_params_.problem_visitor.problem_count) - { - return Status::kSuccess; - } - - // - // Launch kernel - // - - // Launch splitk grouped gemm - { - dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); - dim3 block(BaseKernel::kThreadCount, 1, 1); - - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - cutlass::Kernel<<>>(gemm_params_); - - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - // Launch splitkReduction - { - dim3 grid(32, gemm_params_.problem_visitor.problem_count); - dim3 block(256); - splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, - gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices, - gemm_params_.splitk_buffer_offsets); - - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) - { - return run(stream); - } - - /// Initializes and runs the kernel. 
- Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) - { - - Status status = initialize(args, workspace, stream); - - if (status == Status::kSuccess) - { - status = run(stream); - } - - return status; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GEMM Grouped -template -class SplitkGemmGrouped : public BaseSplitkGrouped -{ -public: - using GemmKernel = GemmKernel_; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h deleted file mode 100644 index 100a1161a8..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ /dev/null @@ -1,162 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/bfloat16.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/half.h" -#include "cutlass/layout/matrix.h" - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -template -struct MixedGemmArchTraits -{ - static_assert(dependent_false, "Unrecognised parameterization"); -}; - -template -struct MixedGemmArchTraits -{ - static constexpr int Stages = 2; - using OperatorClass = cutlass::arch::OpClassSimt; - using AccType = float; - using LayoutB = cutlass::layout::ColumnMajor; - - static constexpr int ElementsPerAccessA = 1; - static constexpr int ElementsPerAccessB = 1; - static constexpr int ElementsPerAccessC = 1; - static constexpr int ThreadblockK = 8; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -// ======================= Turing Traits ============================== -// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 -// and compute will happen in fp16 then will be converted for bf16 output. 
-template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; - - using Operator = typename LayoutDetails::Operator; -}; - -// ======================= Ampere Traits ============================== -template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; - - using Operator = typename LayoutDetails::Operator; -}; - -// ======================= Ada Traits ============================== -template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; - - using Operator = typename LayoutDetails::Operator; -}; - -// FP8 A/B = fp8, C/D = fp32 -template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - // be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t - using TypeC = __nv_bfloat16; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; - - using Operator = typename LayoutDetails::Operator; -}; - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h deleted file mode 
100644 index 3fd722994e..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/layout/matrix.h" - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -template -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassSimt; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -}; - -// ======================= Turing Traits ============================== -template <> -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -}; - -// ======================= Ampere Traits ============================== -template <> -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -}; - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h deleted file mode 100644 index 1dbd0b1765..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h +++ /dev/null @@ -1,207 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with - the appropriate threadblock-scoped epilogue. - - Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are - accommodated by exchanging A and B operands and assuming transposed layouts. Partial - specializations here choose 'device::GemmTransposed' to implement this functionality. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/complex.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "cutlass/gemm/kernel/default_gemm.h" -#include "cutlass/gemm/kernel/default_gemm_complex.h" -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" - -#include "cutlass/layout/permute.h" - -#include "splitk_gemm_grouped.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Complex elementwise transformation on A operand - ComplexTransform TransformA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Complex elementwise transformation on B operand - ComplexTransform TransformB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Warp-level tile size (concept: GemmShape) - typename InstructionShape, - /// Epilogue output operator - typename EpilogueOutputOp, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Whether the schedule of problems to visit has been precomputed - GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, - /// Operation performed by GEMM - typename Operator = typename device::DefaultGemmConfiguration::Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Permute result D - 
typename PermuteDLayout = layout::NoPermute, - /// - typename Enable = void> -struct DefaultSplitkGemmGrouped; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Real-valued GEMM kernels -// - -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Warp-level tile size (concept: GemmShape) - typename InstructionShape, - /// Epilogue output operator - typename EpilogueOutputOp, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Whether the schedule of problems to visit has been precomputed - GroupScheduleMode GroupScheduleMode_, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Permute result D - typename PermuteDLayout> -struct DefaultSplitkGemmGrouped::value>::type> -{ - - // If true, we must construct a 'transposed-and-exchanged' Mma operator. - static bool const kInternalTranspose = platform::is_same::value; - - using MapArguments = kernel::detail::MapArguments; - - // Define the default GEMM kernel - using DefaultGemmKernel = typename kernel::DefaultGemm::GemmKernel; - - /// Define the kernel in terms of the default kernel - using GemmKernel = kernel::SplitkGemmGrouped; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h deleted file mode 100644 index 0baec58ea9..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ /dev/null @@ -1,566 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. 
Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail -{ -template -inline constexpr bool dependent_false_v = false; -} - -template -struct GemmFpAIntB -{ - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static bool const kSplitKSerial = SplitKSerial; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Element; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Mma::LayoutC; - using ElementScale = ElementC; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformA; - - // Type definitions about the mainloop. 
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - - /// Parameters structure - struct Arguments - { - GemmUniversalMode mode = GemmUniversalMode::kGemm; - - cutlass::gemm::GemmCoord problem_size; - int group_size; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Mma::IteratorScale::TensorRef ref_zero; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - - // Control serial split-k - int batch_count; - - typename EpilogueOutputOp::Params output_op; - - // For gather+scatter operations - int const* gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // Included so we can use Gemm Universal - int batch_stride_D = 0; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Arguments() {} - - CUTLASS_HOST_DEVICE - Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, - typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, - typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), - int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, - int const* scatter_D_indices = nullptr) - : problem_size(problem_size) - , group_size(group_size) - , ref_A(ref_A) - , ref_B(ref_B) - , ref_scale(ref_scale) - , ref_zero(ref_zero) - , ref_C(ref_C) - , ref_D(ref_D) - , batch_count(serial_split_k_factor) - , output_op(output_op) - , gather_A_indices(gather_A_indices) - , gather_B_indices(gather_B_indices) - , scatter_D_indices(scatter_D_indices) - { - } - }; - - /// Parameters structure - struct Params - { - cutlass::gemm::GemmCoord problem_size; - int group_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::Params params_B; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::Params params_scale; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Mma::IteratorScale::TensorRef ref_zero; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::Params params_D; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - typename EpilogueOutputOp::Params output_op; - int* semaphore; - int gemm_k_size; - // For gather+scatter operations - int const* 
gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : swizzle_log_tile(0) - , semaphore(0) - , gemm_k_size(0) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, - void* workspace = nullptr) - : problem_size(args.problem_size) - , group_size(args.group_size) - , grid_tiled_shape(grid_tiled_shape) - , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)) - , params_A(args.ref_A.layout()) - , ref_A(args.ref_A) - , params_B(args.ref_B.layout()) - , ref_B(args.ref_B) - , params_scale(args.ref_scale.layout()) - , ref_scale(args.ref_scale) - , ref_zero(args.ref_zero) - , params_C(args.ref_C.layout()) - , ref_C(args.ref_C) - , params_D(args.ref_D.layout()) - , ref_D(args.ref_D) - , output_op(args.output_op) - , semaphore(static_cast(workspace)) - , gemm_k_size(gemm_k_size) - , gather_A_indices(args.gather_A_indices) - , gather_B_indices(args.gather_B_indices) - , scatter_D_indices(args.scatter_D_indices) - { - } - }; - - /// Shared memory storage structure - union SharedStorage - { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - GemmFpAIntB() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(Arguments const& args) - { - static int const kAlignmentA - = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) - ? 64 - : Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB - = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) - ? 64 - : Mma::IteratorB::AccessType::kElements; - - static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; - - static int const kAlignmentC = (platform::is_same>::value) - ? 32 - : (platform::is_same>::value) - ? 64 - : Epilogue::OutputTileIterator::kElementsPerAccess; - - if (!TensorRef_aligned(args.ref_A, kAlignmentA)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_B, kAlignmentB)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_C, kAlignmentC)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_D, kAlignmentC)) - { - return Status::kErrorMisalignedOperand; - } - - if (!args.ref_scale.good()) - { - return Status::kErrorNotSupported; - } - - if constexpr (hasZero(Mma::QuantOp)) - { - if (!args.ref_zero.good()) - { - return Status::kErrorNotSupported; - } - } - else - { - if (args.ref_zero.good()) - { - return Status::kErrorNotSupported; - } - } - - if constexpr (isFinegrained(Mma::QuantOp)) - { - if (args.group_size != 64 && args.group_size != 128) - { - return Status::kErrorNotSupported; - } - } - - return Status::kSuccess; - } - - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; - } - - // Initializes the fine grained scale+bias iterator. 
Needed since the fine grained iterator - // has a different constructor signature than a regular cutlass iterator - template = true> - CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, - typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, - typename IteratorScale::TensorCoord extent, int thread_id, - typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) - { - - return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); - } - - template = true> - CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, - typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, - typename IteratorScale::TensorCoord extent, int thread_id, - typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) - { - - return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); - } - - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { - using LayoutB = typename Mma::IteratorB::Layout; - static_assert(platform::is_same::value && kInterleave == 1 - || platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() - || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) - { - - return; - } - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size, - }; - - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, - threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; - - typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; - typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; - cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), - {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); - - typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), - {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, - params.gather_B_indices); - - typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? 
problem_size_k / 64 : 1; - typename Mma::IteratorScale iterator_scale = initialize_scale( - params.params_scale, params.ref_scale.data(), params.ref_zero.data(), - {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - if (!kSplitKSerial || gemm_k_iterations > 0) - { - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); - } - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) - { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_offset.k()); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // - // Release the semaphore - // - - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) - { - - // The final threadblock resets the semaphore for subsequent grids. 
- lock = 0; - } - else - { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - semaphore.release(lock); - } - } - - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } - } - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { -#if defined(__CUDA_ARCH__) -#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ == 890) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 900) - CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. -#else - static_assert( - false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); -#endif -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh deleted file mode 100644 index 1bd0a3f11a..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh +++ /dev/null @@ -1,218 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include -#include - -namespace fused_moe -{ -template -struct Fused_Moe_Kernel_sm80 -{ - static constexpr int kMaxTileM = MaxTileM_; - static constexpr int kTileN = isGateActivation(activation_type_) ? 
TileN_ / 2 : TileN_; - static constexpr int kTileK = TileK_; - static constexpr int kStages = Stages_; - static constexpr Activation_Type activation_type = activation_type_; - - using ElementInput = ElementInput_; - using ElementWeight = ElementWeight_; - using ElementOutput = ElementOutput_; - using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80; - using Routine_Arguments = Routine_Arguments; - using Routine_Params = Routine_Params; - using ProblemVisitor - = cutlass::gemm::kernel::MoeProblemVisitor, false>, - cutlass::gemm::GemmShape, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, - BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>; - - struct Arguments - { - Routine_Arguments routine_args; - int problem_count{}; - int threadblock_count{}; - }; - - struct Params - { - Routine_Params routine_params; - int threadblock_count{}; - typename ProblemVisitor::Params problem_visitor_param; - }; - - using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80; - static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64 - - static constexpr int kSmemSize = use_m16 - ? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize - : BaseKernelTraits_m16::kSmemSize) - : BaseKernelTraits::kSmemSize; - static constexpr int kThreadCount = BaseKernelTraits::kThreadCount; - - static constexpr bool can_implement(int const avaliable_smem_size) - { - return BaseKernelTraits::can_implement(avaliable_smem_size); - } - - static Params to_underlying_arguments(Arguments const& args) - { - return { - {args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias, - args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, - args.routine_args.gemm_k, args.routine_args.num_expert, args.routine_args.bias_is_broadcast}, - args.threadblock_count, - {args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k, - args.problem_count, nullptr, 0}}; - } - - CUTE_DEVICE - void run_device(Params const& params) - { -#define ROUTINE_PATH(kTileM_size) \ - { \ - constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 
32 : (kTileM_size)); \ - using RoutineTraits = Fused_Moe_Kernel_routine_sm80; \ - RoutineTraits routine{}; \ - int const block_m_idx = (block_m_idx_temp) *kMaxTileM / kTileM; \ - routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \ - } - typename ProblemVisitor::SharedStorage dummy_storage{}; - ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x); - while (problem_visitor.next_tile()) - { - auto problem_size = problem_visitor.problem_size(); - auto grid_size = problem_visitor.grid_shape(problem_size); - auto problem_index = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - int const gemm_m = problem_size.m(); - const int32_t block_m_idx_temp = cta_idx / grid_size.n(); - const int32_t block_n_idx = cta_idx % grid_size.n(); - - int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp; - if (residue_m > kMaxTileM / 2) - { - using RoutineTraits = Fused_Moe_Kernel_routine_sm80; - RoutineTraits routine{}; - routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m); - } - else - { - - if constexpr (kMaxTileM >= 128) - { - if (residue_m > 32) - { - ROUTINE_PATH(64); - } - else if (residue_m > 16) - { - ROUTINE_PATH(32); - } - else - { - // TODO: use cuda core gemm here - ROUTINE_PATH(16); - } - } - else if (kMaxTileM == 64) - { - if (residue_m > 16) - { - ROUTINE_PATH(32); - } - else - { - // TODO: use cuda core gemm here - ROUTINE_PATH(16); - } - } - else if (kMaxTileM == 32) - { - // TODO: use cuda core gemm here - ROUTINE_PATH(16); - } - else - { - // TODO: use cuda core gemm here - ROUTINE_PATH(16); - } - } - problem_visitor.advance(gridDim.x); - } -#undef ROUTINE_PATH - } -}; - -template -__global__ void run_global(__grid_constant__ typename GemmType::Params const params) -{ - GemmType gemm; - gemm.run_device(params); -} - -/// Computes the maximum number of active blocks per multiprocessor -template -static int fused_gemm_maximum_active_blocks(int smem_capacity = -1) -{ - - CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); - - constexpr int smem_size = GemmType::kSmemSize; - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - cudaError_t result; - if (smem_size > (48 << 10)) - { - result = cudaFuncSetAttribute(run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); - return -1; - } - } - - int max_active_blocks = -1; - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, run_global, GemmType::kThreadCount, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - return -1; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; -} -} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh deleted file mode 100644 index 4c46a541ef..0000000000 --- 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh +++ /dev/null @@ -1,799 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include - -namespace fused_moe -{ - -template -struct Fused_Moe_Kernel_routine_sm80; - -template -struct Fused_Moe_Kernel_routine_sm80> -{ - using KT = Fused_Moe_Kernel_traits_sm80; - using Params = Routine_Params; - - CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) - { - using X = cute::Underscore; - - int const M = gemm_m; - int const N1 = params.gemm_n; - int const K1 = params.gemm_k; - bool const bias_is_broadcast = params.bias_is_broadcast; - - int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]); - typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; - typename KT::ElementWeight const* ptr_fc1_gate_ - = params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we only focus on gated activation.. - typename KT::ElementWeight const* ptr_fc1_ - = params.ptr_fc1 + 2 * problem_index * N1 * K1; // TODO: we only focus on gated activation.. - typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr) - ? nullptr - : (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1); - typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr) - ? nullptr - : (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1 - : params.ptr_bias + (2 * row_jump + 1) * N1); - typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; - - cute::Tensor mInput_mk - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), - cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); - - cute::Tensor mfc1_gate_nk - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_gate_)), - cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); - - cute::Tensor mfc1_nk - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), - cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); - - cute::Tensor mBias_mn = cute::make_tensor( - cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), - cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2, - cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. - - cute::Tensor mBias_gate_mn = cute::make_tensor( - cute::make_gmem_ptr(static_cast(ptr_bias_gate_)), cute::make_shape(M, N1), - cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2, - cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. 
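The routine above derives each expert's operand pointers from the inclusive prefix sum `total_tokens_including_expert` and an interleaved [up, gate] FC1 weight layout. A minimal host-side sketch of that indexing follows; the helper and field names are hypothetical, and it only restates the offset arithmetic visible in `gmem_tensor_init`.

// Sketch: per-expert operand offsets for a grouped (MoE) GEMM with an
// interleaved [up, gate] FC1 weight layout. Hypothetical, simplified names.
#include <cstdint>
#include <vector>

struct ExpertOffsets {
    int64_t input_row;   // first token row owned by this expert
    int64_t fc1_up;      // element offset of the expert's "up" weights
    int64_t fc1_gate;    // element offset of the expert's "gate" weights
    int64_t output_row;  // first output row owned by this expert
};

// total_tokens_including_expert[i] is an inclusive prefix sum of tokens per
// expert, as in the deleted routine's Routine_Params.
ExpertOffsets expert_offsets(const std::vector<int64_t>& total_tokens_including_expert,
                             int expert, int gemm_n, int gemm_k) {
    const int64_t row_jump = (expert == 0) ? 0 : total_tokens_including_expert[expert - 1];
    ExpertOffsets off;
    off.input_row  = row_jump;                                   // * gemm_k elements into ptr_input
    off.fc1_up     = int64_t(2 * expert) * gemm_n * gemm_k;      // even slot: up projection
    off.fc1_gate   = int64_t(2 * expert + 1) * gemm_n * gemm_k;  // odd slot: gate projection
    off.output_row = row_jump;                                   // * gemm_n elements into ptr_output
    return off;
}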
- - cute::Tensor mOutput_mn - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), - cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); - - cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) - cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) - cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) - - cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) - - cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) - - cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) - - return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn); - } - - // be careful, m_idx will change when use another tile shape.. - CUTE_DEVICE void run_routine( - Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) - { - extern __shared__ char smem_[]; - typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); - int const thread_idx = threadIdx.x; - bool const bias_is_broadcast = params.bias_is_broadcast; - // gmem tensor partition .. - auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn] - = gmem_tensor_init(problem_index, gemm_m, params); - int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); - auto const n_tile_count = cute::size<2>(gfc1_gate_nk); - - // smem tensor .. 
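cute::local_tile splits each global tensor into a grid of (BLK_M, BLK_K) tiles indexed by (m, k), and run_routine then fixes one tile coordinate per CTA. A plain index-math sketch of what that selection means for a row-major matrix, with hypothetical helpers and no CuTe machinery, is:

// Sketch: the element of a row-major [M, K] matrix addressed by local tile
// (bm, bk) and in-tile coordinate (i, j), i.e. what a (BLK_M, BLK_K, m, k)
// tiling exposes. Illustrative only, not the CuTe implementation.
#include <cstdint>

template <int BLK_M, int BLK_K>
__host__ __device__ int64_t tiled_index(int bm, int bk, int i, int j, int K) {
    const int64_t row = int64_t(bm) * BLK_M + i;  // global row
    const int64_t col = int64_t(bk) * BLK_K + j;  // global column
    return row * K + col;                         // row-major element offset
}

// Rows still valid in the last M tile of this expert (the residue_m guard):
__host__ __device__ int residue_rows(int gemm_m, int block_m_idx, int blk_m) {
    return gemm_m - block_m_idx * blk_m;          // may be smaller than BLK_M for the tail tile
}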
- cute::Tensor sInput = cute::make_tensor( - cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) - cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), - typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) - cute::Tensor sfc1_gate_weight - = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()), - typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) - cute::Tensor sO = cute::make_tensor( - cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) - - // (1) first step, get the fc1_res and fc1_gate - - // (1.1) get partition for gmem -> smem - cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) - cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) - cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) - - typename KT::GmemTiledCopyA gmem_tiled_copy_A; - typename KT::GmemTiledCopyB gmem_tiled_copy_B; - auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); - auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); - - cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) - cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) - cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) - cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) - cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k) - cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage) - - // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) - cute::Tensor tInputpInput - = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), - cute::Stride{}); - // Construct identity layout for sInput - cute::Tensor cInput = make_identity_tensor( - make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - - // Repeat the partitioning with identity layouts - cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - - // Set predicates for m bounds - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < cute::size<0>(tInputpInput); ++m) - { - tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m - } - - // (1.2) prefetch gmem -> smem - cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. 
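The predicate tensor built above masks rows of the input tile that fall beyond residue_m, so the copy_if neither reads out of bounds nor pollutes the cleared shared-memory tile. A simplified standalone CUDA sketch of the same row-masked global-to-shared copy, with one thread per element and hypothetical tile parameters, is:

// Sketch: row-predicated global->shared copy for a partial (tail) M tile,
// analogous to the copy_if + identity-tensor predicate above. Simplified to
// a half-precision payload and hypothetical tile sizes.
#include <cuda_fp16.h>

template <int BLK_M, int BLK_K>
__device__ void load_tile_masked(const __half* __restrict__ gmem,  // tile base in global memory
                                 __half* __restrict__ smem,        // BLK_M x BLK_K shared tile
                                 int ld, int residue_m) {
    for (int idx = threadIdx.x; idx < BLK_M * BLK_K; idx += blockDim.x) {
        const int row = idx / BLK_K;
        const int col = idx % BLK_K;
        // Rows at or past residue_m belong to no valid token: write zeros
        // instead of reading out of bounds, matching the cleared smem default.
        smem[idx] = (row < residue_m) ? gmem[row * ld + col] : __float2half(0.f);
    }
}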
- auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0 - int k_tile_count = cute::size<2>(gInput); - CUTLASS_PRAGMA_UNROLL - for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) - { - if (k_tile_count <= 0) - { - cute::clear(tInputpInput); - } - // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - // tInputsInput(cute::_, cute::_, cute::_, k_pipe)); - // use copy_if - cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - tInputsInput(cute::_, cute::_, cute::_, k_pipe)); - cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1sfc1(cute::_, cute::_, cute::_, k_pipe)); - cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe)); - cute::cp_async_fence(); - k_tile_count--; - if (k_tile_count > 0) - { - ++k_tile_iter; - } - } - - // (1.3) get partition for rf - typename KT::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K) - cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) - cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) - - cute::Tensor accum - = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) - cute::Tensor accum_gate - = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) - cute::clear(accum); - cute::clear(accum_gate); - // checkout the shape - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate)); // MMA_N - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum)); // MMA_N - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate)); // MMA_N - CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K - CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g)); // MMA_K - CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma)); - CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma)); - - // (1.4)retiling the smem and rf for copy.. 
- auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); - auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); - cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) - cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) - CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M - CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K - - auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); - auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); - cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) - cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) - cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage) - cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K) - CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N - CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K - CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N - CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K - - // (1.5) mainloop - // Current pipe index in smem to read from - int smem_pipe_read = 0; - // Current pipe index in smem to write to - int smem_pipe_write = KT::Stages - 1; - - cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); - - constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); - // prefetch register pipeline - if constexpr (K_BLOCK_MAX > 1) - { - cute::cp_async_wait(); - __syncthreads(); - - // Prefetch the first rmem from the first k-tile - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), - tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), - tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); - cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}), - tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{})); - } - // k loop for mainloop - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - if (k_block == K_BLOCK_MAX - 1) - { - tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); - cute::cp_async_wait(); - __syncthreads(); - } - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), - tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); - // Copy gmem to smem before computing gemm on each k-pipe - if (k_block == 0) - { - // cute::copy(gmem_tiled_copy_A, 
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::copy_if(gmem_tiled_copy_A, tInputpInput, - tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::cp_async_fence(); - if (k_tile_count - 1 > 0) - { - ++k_tile_iter; - } - - // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) - smem_pipe_write = smem_pipe_read; - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; - } - // Thread-level register gemm for k_block - cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), - accum); - cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), - tOrfc1g(cute::_, cute::_, k_block), accum_gate); - }); - } - - // load tail - cute::for_each(cute::make_int_sequence{}, - [&](auto WaitIndex) - { - k_tile_count--; - using WaitIndex_t = decltype(WaitIndex); - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - if (k_block == K_BLOCK_MAX - 1) - { - tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); - cute::cp_async_wait(); - __syncthreads(); - } - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), - tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); - if (k_block == 0) - { - // only update smem_pipe_read - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == KT::Stages) ? 
0 : smem_pipe_read; - } - // Thread-level register gemm for k_block - cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), - tOrfc1(cute::_, cute::_, k_block), accum); - cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), - tOrfc1g(cute::_, cute::_, k_block), accum_gate); - }); - }); - // mma tail - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), - tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); - // Thread-level register gemm for k_block - cute::gemm( - tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum); - cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), - tOrfc1g(cute::_, cute::_, k_block), accum_gate); - }); - // if (cute::thread0()) { - // cute::print(accum_gate(0, 0, 0)); - // printf("\n"); - // } - // (2) add bias if it has.. - if (params.ptr_bias != nullptr) - { - cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); - cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); - cute::Tensor tOgBias = thr_mma.partition_C(gBias); - cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate); - for (int i = 0; i < cute::size(accum); i++) - { - accum(i) += tOgBias(i); - accum_gate(i) += tOgBiasg(i); - } - } - - // (3) calculate swiglu - using ActivationFn = typename KT::ActivationFn; - ActivationFn fn{}; - CUTLASS_PRAGMA_UNROLL - for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) - { - accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter); - } - - // (4) push all the result to smem - // (4.1) convert result from ElementAccum to ElementInput - cute::Tensor temp_accum = util_convert_type(accum); - // if (cute::thread0()) { - // cute::print(temp_accum(0, 0, 0)); - // printf("\n"); - // } - // (4.2) retile rf and smem for copy back.. - auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); - // cute::clear(sO); - cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); - cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); - - // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) - cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); - __syncthreads(); - - // (4.4) sO -> rO -> gO - - typename KT::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // - // remember, for all the threads in the same col, they have the same idx for bias.. - cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); - // cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row.. 
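Steps (2) and (3) above add the optional bias terms and then combine the two accumulators with the gated activation, accum = fn(accum_gate) * accum, where fn is SiLU in the Swiglu case. A scalar reference of that epilogue, with simplified types and hypothetical names, is:

// Sketch: the gated (SwiGLU) epilogue applied elementwise to the two
// accumulators, written out as a scalar reference.
#include <cmath>

inline float silu(float x) { return x / (1.f + std::exp(-x)); }

// up   : accumulator of x @ W_up   (plus optional bias_up)
// gate : accumulator of x @ W_gate (plus optional bias_gate)
// returns the SwiGLU output element, silu(gate) * up
inline float swiglu_epilogue(float up, float gate,
                             float bias_up = 0.f, float bias_gate = 0.f) {
    return silu(gate + bias_gate) * (up + bias_up);
}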
- auto tOsO = gmem_thr_copy_O.partition_S(sO); - auto tOgO = gmem_thr_copy_O.partition_D(gO); - // auto tOgBias = gmem_thr_copy_O.partition_D(gBias); - cute::Tensor cOutput = cute::make_identity_tensor( - cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); - cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); - cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); - cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < cute::size<1>(tOgO); ++m) - { - if (cute::get<0>(tOcO(0, m, 0)) < residue_m) - { - cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); - } - } - } -}; - -template -struct Fused_Moe_Kernel_routine_sm80> -{ - - using KT = Fused_Moe_Kernel_traits_sm80; - using Params = Routine_Params; - - CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) - { - using X = cute::Underscore; - - int const M = gemm_m; - int const N1 = params.gemm_n; - int const K1 = params.gemm_k; - bool const bias_is_broadcast = params.bias_is_broadcast; - - int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]); - typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; - typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1; - typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr) - ? nullptr - : (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1); - typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; - - cute::Tensor mInput_mk - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), - cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); - - cute::Tensor mfc1_nk - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), - cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); - - cute::Tensor mBias_mn = cute::make_tensor( - cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), - cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1, - cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. - - cute::Tensor mOutput_mn - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), - cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); - - cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) - cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) - - cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) - - cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) - - return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn); - } - - // be careful, m_idx will change when use another tile shape.. 
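Both run_routine variants drive the same Stages-deep cp.async pipeline: the prologue prefetches Stages - 1 k-tiles, and each mainloop iteration waits with cp_async_wait<Stages - 2>, computes on the buffer that has landed, and refetches into the slot it just freed. A reduced sketch of that schedule for a single operand, using plain CUDA pipeline intrinsics (CUDA 11+) and placeholder tile sizes and compute hook, is:

// Sketch: Stages-deep cp.async software pipeline for one operand.
// TILE_ELEMS halves per k-tile; smem holds STAGES such buffers.
#include <cuda_pipeline.h>
#include <cuda_fp16.h>

template <int STAGES, int TILE_ELEMS>
__device__ void pipelined_mainloop(const __half* __restrict__ gmem,
                                   __half* __restrict__ smem,  // STAGES * TILE_ELEMS halves
                                   int k_tiles) {
    int write_stage = 0, read_stage = 0, fetched = 0;

    // Prologue: prefetch STAGES - 1 k-tiles into distinct shared-memory buffers.
    for (; fetched < STAGES - 1 && fetched < k_tiles; ++fetched) {
        for (int i = threadIdx.x * 8; i < TILE_ELEMS; i += blockDim.x * 8)
            __pipeline_memcpy_async(smem + write_stage * TILE_ELEMS + i,
                                    gmem + size_t(fetched) * TILE_ELEMS + i, 16);
        __pipeline_commit();
        write_stage = (write_stage + 1) % STAGES;
    }

    for (int k = 0; k < k_tiles; ++k) {
        // Allow at most STAGES - 2 commit groups in flight, i.e. the buffer
        // needed by iteration k has arrived (the cp_async_wait<Stages - 2> above).
        __pipeline_wait_prior(STAGES - 2);
        __syncthreads();

        // compute(smem + read_stage * TILE_ELEMS);  // MMA work on the ready tile
        read_stage = (read_stage + 1) % STAGES;

        // Refill the oldest buffer with the next k-tile, if one is left.
        if (fetched < k_tiles) {
            for (int i = threadIdx.x * 8; i < TILE_ELEMS; i += blockDim.x * 8)
                __pipeline_memcpy_async(smem + write_stage * TILE_ELEMS + i,
                                        gmem + size_t(fetched) * TILE_ELEMS + i, 16);
            ++fetched;
            write_stage = (write_stage + 1) % STAGES;
        }
        __pipeline_commit();  // keep the group count in step even when nothing was issued
        __syncthreads();      // all threads finish this tile before its buffer is reused
    }
}

Waiting on Stages - 2 outstanding groups rather than zero is what lets the next prefetch overlap with the current tile's math.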
- CUTE_DEVICE void run_routine( - Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) - { - extern __shared__ char smem_[]; - typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); - int const thread_idx = threadIdx.x; - bool const bias_is_broadcast = params.bias_is_broadcast; - // gmem tensor partition .. - auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params); - int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); - auto const n_tile_count = cute::size<2>(gfc1_nk); - - // smem tensor .. - cute::Tensor sInput = cute::make_tensor( - cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) - cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), - typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) - cute::Tensor sO = cute::make_tensor( - cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) - - // (1) first step, get the fc1_res and fc1_gate - - // (1.1) get partition for gmem -> smem - cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) - cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) - - typename KT::GmemTiledCopyA gmem_tiled_copy_A; - typename KT::GmemTiledCopyB gmem_tiled_copy_B; - auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); - auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); - - cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) - cute::Tensor tInputsInput = gmem_thr_copy_A.partition_S(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) - cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) - cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) - - // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) - cute::Tensor tInputpInput - = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), - cute::Stride{}); - // Construct identity layout for sInput - cute::Tensor cInput = make_identity_tensor( - make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - - // Repeat the partitioning with identity layouts - cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - - // Set predicates for m bounds - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < cute::size<0>(tInputpInput); ++m) - { - tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m - } - - // (1.2) prefetch gmem -> smem - cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. 
- auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0 - int k_tile_count = cute::size<2>(gInput); - CUTLASS_PRAGMA_UNROLL - for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) - { - if (k_tile_count <= 0) - { - cute::clear(tInputpInput); - } - // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - // tInputsInput(cute::_, cute::_, cute::_, k_pipe)); - // use copy_if - cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - tInputsInput(cute::_, cute::_, cute::_, k_pipe)); - cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1sfc1(cute::_, cute::_, cute::_, k_pipe)); - cute::cp_async_fence(); - k_tile_count--; - if (k_tile_count > 0) - { - ++k_tile_iter; - } - } - - // (1.3) get partition for rf - typename KT::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K) - cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) - - cute::Tensor accum - = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) - cute::clear(accum); - // checkout the shape - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N - CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K - CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma)); - CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma)); - - // (1.4)retiling the smem and rf for copy.. 
- auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); - auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); - cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) - cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) - CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M - CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K - - auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); - auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); - cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) - cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) - CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N - CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K - - // (1.5) mainloop - // Current pipe index in smem to read from - int smem_pipe_read = 0; - // Current pipe index in smem to write to - int smem_pipe_write = KT::Stages - 1; - - cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - - constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); - // prefetch register pipeline - if constexpr (K_BLOCK_MAX > 1) - { - cute::cp_async_wait(); - __syncthreads(); - - // Prefetch the first rmem from the first k-tile - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), - tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), - tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); - } - // k loop for mainloop - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - if (k_block == K_BLOCK_MAX - 1) - { - tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - cute::cp_async_wait(); - __syncthreads(); - } - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - // Copy gmem to smem before computing gemm on each k-pipe - if (k_block == 0) - { - // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::copy_if(gmem_tiled_copy_A, tInputpInput, - tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::cp_async_fence(); - if (k_tile_count - 1 > 0) - { - ++k_tile_iter; - } - - // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) - smem_pipe_write = smem_pipe_read; - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == KT::Stages) ? 
0 : smem_pipe_read; - } - // Thread-level register gemm for k_block - cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), - accum); - }); - } - // load tail - cute::for_each(cute::make_int_sequence{}, - [&](auto WaitIndex) - { - k_tile_count--; - using WaitIndex_t = decltype(WaitIndex); - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - if (k_block == K_BLOCK_MAX - 1) - { - tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - cute::cp_async_wait(); - __syncthreads(); - } - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - if (k_block == 0) - { - // only update smem_pipe_read - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; - } - // Thread-level register gemm for k_block - cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), - tOrfc1(cute::_, cute::_, k_block), accum); - }); - }); - // mma tail - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - // Thread-level register gemm for k_block - cute::gemm( - tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum); - }); - // if (cute::thread0()) { - // cute::print(accum_gate(0, 0, 0)); - // printf("\n"); - // } - // (2) add bias if it has.. - if (params.ptr_bias != nullptr) - { - cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); - cute::Tensor tOgBias = thr_mma.partition_C(gBias); - for (int i = 0; i < cute::size(accum); i++) - { - accum(i) += tOgBias(i); - } - } - // (3) calculate swiglu - using ActivationFn = typename KT::ActivationFn; - ActivationFn fn{}; - CUTLASS_PRAGMA_UNROLL - for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) - { - accum(temp_iter) = fn(accum(temp_iter)); - } - - // (4) push all the result to smem - // (4.1) convert result from ElementAccum to ElementInput - cute::Tensor temp_accum = util_convert_type(accum); - // if (cute::thread0()) { - // cute::print(temp_accum(0, 0, 0)); - // printf("\n"); - // } - // (4.2) retile rf and smem for copy back.. - auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); - // cute::clear(sO); - cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); - cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); - - // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) 
- cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); - __syncthreads(); - - // (4.4) sO -> rO -> gO - - typename KT::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // - cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); - auto tOsO = gmem_thr_copy_O.partition_S(sO); - auto tOgO = gmem_thr_copy_O.partition_D(gO); - cute::Tensor cOutput = cute::make_identity_tensor( - cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); - cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); - cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); - cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < cute::size<1>(tOgO); ++m) - { - if (cute::get<0>(tOcO(0, m, 0)) < residue_m) - { - cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); - } - } - } -}; - -} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh deleted file mode 100644 index b4c90085db..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh +++ /dev/null @@ -1,215 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include - -namespace fused_moe -{ -template -struct Routine_Arguments -{ - ElementInput* ptr_input{}; - ElementWeight* ptr_fc1{}; - ElementInput* ptr_bias{}; - ElementOutput* ptr_output{}; - int64_t const* total_tokens_including_expert{}; - int gemm_n{}; - int gemm_k{}; - int num_expert{}; - bool bias_is_broadcast{}; -}; - -template -struct Routine_Params -{ - ElementInput* ptr_input{}; - ElementWeight* ptr_fc1{}; - ElementInput* ptr_bias{}; - ElementOutput* ptr_output{}; - int64_t const* total_tokens_including_expert{}; - int gemm_n{}; - int gemm_k{}; - int num_expert{}; - bool bias_is_broadcast{}; -}; - -enum class Activation_Type -{ - Gelu = 0, - Relu, - Silu, - Swiglu, - Geglu, - Identity, - InvalidType -}; - -constexpr bool isGateActivation(Activation_Type const& activation_type) -{ - return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu; -} - -template -constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) -{ - return Activation_Type::InvalidType; -} - -template <> -constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) -{ - return Activation_Type::Identity; -} - -template <> -constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) -{ - return Activation_Type::Relu; -} - -template <> -constexpr Activation_Type EpilogueRouting(bool is_gate) -{ - return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu; -} - -template <> -constexpr Activation_Type EpilogueRouting(bool is_gate) -{ - return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu; -} - -/* fusing all three kernels has many limitations. This is the simpler version. Just fuse first two kernels..*/ -template -struct Fused_Moe_Kernel_traits_sm80 -{ - using ElementInput = ElementInput_; - using ElementWeight = ElementWeight_; - using ElementAccum = float; - using ElementOutput = ElementOutput_; - - using index_t = uint32_t; - static_assert(TileM_ % 16 == 0); - static_assert(TileN_ % 32 == 0); - static_assert(TileK_ % 32 == 0); - static constexpr int Stages = Stages_; - static constexpr int kTileM = TileM_; - static constexpr int kTileN = TileN_; - static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64); - - // tile shape - using TileShape = cute::Shape, cute::Int, cute::Int>; - static constexpr int kWarpsCount = 4; - static constexpr int kThreadCount = kWarpsCount * 32; - - // MMA atom arch and layout - using MMA_Atom_Arch = std::conditional_t, - cute::MMA_Atom, cute::MMA_Atom>; - // using ValLayoutMNK = cute::Layout>; - using ThreadLayoutMNK - = std::conditional_t, cute::_1>>, - cute::Layout, cute::_1>>>; - using ValLayoutMNK = std::conditional_t, - cute::Tile>; - using TiledMma = cute::TiledMMA; // 32x32x16 or 16x64x16 MMA for LDSM if kWarp = 4 - static constexpr int kAlignment = 8; - static constexpr int kBlcokKSmem = (kTileM == 16) ? 
64 : 32; - // A memory copy operand - using DefaultOperandA - = DefaultGemm_TensorOpSm80_OperandA; - using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; - using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; - using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; - - // B memory copy operand - using DefaultOperandB - = DefaultGemm_TensorOpSm80_OperandB; - using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; - using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; - using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; - - // Output memory copy operand - using SmemLayoutAtomO = SmemLayoutAtomA; - using SmemCopyAtomO = cute::Copy_Atom; - static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput); - static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad; - using GmemLayoutAtomO - = cute::Layout, cute::Int>, - cute::Stride, cute::_1>>; - using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom{}, - GmemLayoutAtomO{}, cute::Layout>{})); - - static_assert(cute::rank(SmemLayoutAtomA{}) == 2); - static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M - static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K - static_assert(cute::rank(SmemLayoutAtomB{}) == 2); - static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N - static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K - - using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{}, - cute::make_shape( - cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_M, BLK_K, Stages - using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{}, - cute::make_shape( - cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_N, BLK_K, Stages - using SmemLayoutO = decltype(cute::tile_to_shape( - SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N - - // we need at least 2 stages.. - static_assert(Stages >= 2); - - struct SharedStorageNormal : cute::aligned_struct<128> - { - cute::array_aligned> smem_input; - cute::array_aligned> smem_fc1_weight; - cute::array_aligned> smem_o; - }; - - struct SharedStorageGate : cute::aligned_struct<128> - { - cute::array_aligned> smem_input; - cute::array_aligned> smem_fc1_gate_weight; - cute::array_aligned> smem_fc1_weight; - cute::array_aligned> smem_o; - }; - - using SharedStorage = std::conditional_t; - - using ActivationFn = std::conditional_t, - std::conditional_t, - std::conditional_t, cutlass::epilogue::thread::Identity>>>; - - static constexpr int kSmemSize = static_cast(sizeof(SharedStorage)); - - static constexpr bool can_implement(int const avaliable_smem_size) - { - return avaliable_smem_size > kSmemSize; - } - - // #endif -}; -} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h deleted file mode 100644 index 80a4d85608..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Scheduler for grouped GEMM -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -#include "cutlass/matrix_coord.h" - -#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" -#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -/// Visitor class to abstract away the algorithm for iterating over tiles -template -struct GemmMoeProblemVisitor - : public MoeProblemVisitor, ThreadblockShape, - GroupScheduleMode_, PrefetchTileCount, ThreadCount> -{ - - static bool const kTransposed = Transposed; - - using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; - using Base - = MoeProblemVisitor; - using Params = typename Base::Params; - using SharedStorage = typename Base::SharedStorage; - - // - // Methods - // - CUTLASS_DEVICE - GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) - : Base(params_, shared_storage_, block_idx) - { - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp deleted file mode 100644 index 3a084ee04f..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp +++ /dev/null @@ -1,70 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel -{ - -//////////////////////////////////////////////////////////////////////////////// - -/* - * Stateless universal device GEMM kernel type that treats GEMM as - * a composition of a collective mainloop and a collective epilogue. - * - * Supports both the 2.x and 3.x APIs based on whether the first type is - * a cute::tuple<> or not. - * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h - * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp - * - * In the following declaration, the name preceding the 'Or' refers to - * 3.x API type argument order, and the name succeeding the 'Or' refers to - * 2.x API type argument order. Template arguments without two names - * belong to the 3.x API only. - **/ -template -class GemmUniversalGated; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel - -//////////////////////////////////////////////////////////////////////////////// - -#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp" -#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp" -//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h deleted file mode 100644 index 0650ca8ded..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h +++ /dev/null @@ -1,585 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. 
Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief GEMM kernel to support the epilogue visitor model - for customized softmax partial reduction epilogue fusion. - - This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once - its usage has been stabilized. For now, it is included in this example to demonstrate - some basic output fusion options. - - original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h -*/ - -#pragma once - -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" -#include "cutlass/trace.h" - -#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" - -namespace tk = tensorrt_llm::common; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct GemmWithEpilogueVisitor -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueVisitor = typename Epilogue::Visitor; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using TensorRefA = TensorRef; - - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using TensorRefB = TensorRef; - - using ElementCompute = typename EpilogueVisitor::ElementCompute; - using LayoutAlphaCol = cutlass::layout::RowMajor; - using LayoutAlphaRow = cutlass::layout::ColumnMajor; - using TensorRefAlphaCol = TensorRef; - using TensorRefAlphaRow = TensorRef; - - using ElementC = typename EpilogueVisitor::ElementOutput; - using LayoutC = typename Epilogue::Layout; - using TensorRefC = TensorRef; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformB; - using Operator = typename Mma::Operator; - - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename 
Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - using EpilogueOutputOp = - typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - /// Split-K preserves splits that are 128b aligned - static int const kSplitKAlignment - = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); - - // - // Structures - // - - /// Argument structure - struct Arguments - { - - // - // Data members - // - - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; - - TensorRefA ref_A; - TensorRefB ref_B; - tk::QuantMode quant_option; - TensorRefAlphaCol ref_alpha_col; - TensorRefAlphaRow ref_alpha_row; - TensorRefC ref_C; - TensorRefC ref_D; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_D; - - typename EpilogueVisitor::Arguments epilogue_visitor; - - // - // Methods - // - - Arguments() - : mode(GemmUniversalMode::kGemm) - , batch_count(1) - { - } - - /// constructs an arguments structure - Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, - TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, - TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, - int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_) - : mode(mode_) - , problem_size(problem_size_) - , batch_count(batch_count_) - , ref_A(ref_A_) - , ref_B(ref_B_) - , quant_option(quant_option_) - , ref_alpha_col(ref_alpha_col_) - , ref_alpha_row(ref_alpha_row_) - , ref_C(ref_C_) - , ref_D(ref_D_) - , batch_stride_A(batch_stride_A_) - , batch_stride_B(batch_stride_B_) - , batch_stride_D(0) - , epilogue_visitor(epilogue_visitor_) - { - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params - { - - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorB::Params params_B; - typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; - typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; - typename EpilogueVisitor::OutputTileIterator::Params params_C; - typename EpilogueVisitor::OutputTileIterator::Params params_D; - - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - - void* ptr_A; - void* ptr_B; - tk::QuantMode quant_option; - typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; - typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; - ElementC* ptr_C; - ElementC* ptr_D; - - int64_t batch_stride_A; - int64_t batch_stride_B; - - typename EpilogueVisitor::Params epilogue_visitor; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : swizzle_log_tile(0) - , params_A(0) - , params_B(0) - , params_alpha_col(0) - , params_C(0) - , params_D(0) - , batch_count(0) - , gemm_k_size(0) - , 
mode(cutlass::gemm::GemmUniversalMode::kGemm) - , ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_alpha_col(nullptr) - , ptr_alpha_row(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , batch_stride_A(0) - , batch_stride_B(0) - { - } - - Params( - Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) - : problem_size(args.problem_size) - , swizzle_log_tile(0) - , params_A(args.ref_A.layout()) - , params_B(args.ref_B.layout()) - , params_alpha_col(args.ref_alpha_col.layout()) - , params_alpha_row(args.ref_alpha_col.layout()) - , params_C(args.ref_C.layout()) - , params_D(args.ref_D.layout()) - , mode(args.mode) - , batch_count(args.batch_count) - , gemm_k_size(args.problem_size.k()) - , ptr_A(args.ref_A.data()) - , ptr_B(args.ref_B.data()) - , quant_option(args.quant_option) - , ptr_alpha_col(args.ref_alpha_col.data()) - , ptr_alpha_row(args.ref_alpha_row.data()) - , ptr_C(args.ref_C.data()) - , ptr_D(args.ref_D.data()) - , batch_stride_A(args.batch_stride_A) - , batch_stride_B(args.batch_stride_B) - , epilogue_visitor(args.epilogue_visitor) - { - - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); - - if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - int const kAlignK - = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); - - gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - - if (gemm_k_size) - { - grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); - } - } - - swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); - } - }; - - /// Shared memory storage structure - union SharedStorage - { - - typename Mma::SharedStorage main_loop; - - struct - { - typename Epilogue::SharedStorage epilogue; - typename EpilogueVisitor::SharedStorage visitor; - } epilogue; - }; - -public: - // - // Methods - // - - CUTLASS_DEVICE - GemmWithEpilogueVisitor() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - - CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); - - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; - - bool isAMisaligned = false; - bool isBMisaligned = false; - bool isCMisaligned = false; - - if (platform::is_same::value) - { - isAMisaligned = problem_size.k() % kAlignmentA; - } - else if (platform::is_same::value) - { - isAMisaligned = problem_size.m() % kAlignmentA; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - isAMisaligned = problem_size.k() % kAlignmentA; - } - - if (platform::is_same::value) - { - isBMisaligned = problem_size.n() % kAlignmentB; - } - else if (platform::is_same::value) - { - isBMisaligned = problem_size.k() % kAlignmentB; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - isBMisaligned = problem_size.k() % kAlignmentB; - } - - if (platform::is_same::value) - { - isCMisaligned = problem_size.n() % kAlignmentC; - } - else if (platform::is_same::value) - { - isCMisaligned = problem_size.m() % kAlignmentC; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - 
isCMisaligned = problem_size.n() % kAlignmentC; - } - - if (isAMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); - return Status::kErrorMisalignedOperand; - } - - if (isBMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); - return Status::kErrorMisalignedOperand; - } - - if (isCMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); - return Status::kErrorMisalignedOperand; - } - - CUTLASS_TRACE_HOST(" returning kSuccess"); - - return Status::kSuccess; - } - - static Status can_implement(Arguments const& args) - { - return can_implement(args.problem_size); - } - - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; - } - -#define SPLIT_K_ENABLED 1 - - /// Executes one GEMM - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() - || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) - { - - return; - } - - int offset_k = 0; - int problem_size_k = params.problem_size.k(); - - ElementA* ptr_A = static_cast(params.ptr_A); - ElementB* ptr_B = static_cast(params.ptr_B); - -#if SPLIT_K_ENABLED - // - // Fetch pointers based on mode. - // - if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) - { - - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } - else if (params.mode == GemmUniversalMode::kBatched) - { - ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; - ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; - } - else if (params.mode == GemmUniversalMode::kArray) - { - ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; - ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; - } -#endif - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - offset_k, - }; - - cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); - - typename Mma::IteratorB iterator_B( - params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. 
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - // - // Construct the epilogue visitor - // - - EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, - params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, - params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, - params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); - - if (params.mode == GemmUniversalMode::kGemm) - { - // Indicate which position in a serial reduction the output operator is currently updating - epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) - { - epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); - } - - // Construct the epilogue - Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - epilogue(epilogue_visitor, accumulators); - } - - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } - } - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { -#if defined(__CUDA_ARCH__) -#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 900) - // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. - run_kernel(params, shared_storage); -#else - static_assert( - false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); -#endif -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h deleted file mode 100644 index 6dc6ffc1a9..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ /dev/null @@ -1,143 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/* - This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is - quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices - to be consumed by CUTLASS. - - Note that for int4, ThreadBlockK MUST be 64. - - */ - -#pragma once - -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/platform/platform.h" - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -template -struct LayoutDetailsB -{ -}; - -// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. -// TODO - Switch this to column major for weights since gemms should be more performant. 
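The interleave factors in the LayoutDetailsB specializations below fall out of simple cache-line arithmetic. A minimal, self-contained sketch of that arithmetic, assuming 16-bit activations (e.g. half) and the 128-byte cache line constant used in the trait bodies; it reproduces the "for int4, ThreadBlockK MUST be 64" note from the file header:

    // Standalone check of the quantized-B layout constants (illustrative only).
    #include <cstdio>

    constexpr int kCacheLineBits = 128 * 8;

    // ThreadblockK = cache line (in bits) / activation width, as in the traits below.
    constexpr int threadblock_k(int type_a_bits) { return kCacheLineBits / type_a_bits; }

    // ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK, as in the traits below.
    constexpr int columns_interleaved(int type_a_bits, int type_b_bits) {
        return (kCacheLineBits / type_b_bits) / threadblock_k(type_a_bits);
    }

    int main() {
        static_assert(threadblock_k(16) == 64, "fp16 activations give ThreadblockK = 64");
        static_assert(columns_interleaved(16, 8) == 2, "int8 weights interleave 2 columns");
        static_assert(columns_interleaved(16, 4) == 4, "int4 weights interleave 4 columns");
        std::printf("ThreadblockK=%d int8_interleave=%d int4_interleave=%d\n",
                    threadblock_k(16), columns_interleaved(16, 8), columns_interleaved(16, 4));
        return 0;
    }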
-template -struct LayoutDetailsB= 75>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -template -struct LayoutDetailsB= 75>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -template -struct LayoutDetailsB -{ - static constexpr int ThreadblockK = 64; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; - // for fast accumulation - // using Operator = cutlass::arch::OpMultiplyAddFastAccum; -}; - -// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, -// which signals that we want to dequantize after loading from smem. -template - struct LayoutDetailsB < TypeA, - uint8_t, Arch, - typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; -}; - -template - struct LayoutDetailsB < TypeA, - uint4b_t, Arch, - typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; -}; - -template -struct LayoutDetailsB= 90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -template -struct LayoutDetailsB= 90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh deleted file mode 100644 index aac2cb3579..0000000000 --- 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh +++ /dev/null @@ -1,185 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include -#include - -template -struct DefaultGemm_TensorOpSm80_OperandA; - -template -struct DefaultGemm_TensorOpSm80_OperandB; - -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::half_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::bfloat16_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -/// Operand A - Column-major (M-major) -template -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::half_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -template -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::bfloat16_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands - -// Operand B - Column-Major (K-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{ -}; - -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{ -}; - -// Operand B - Row-Major (N-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{ -}; - -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{ -}; - -// -// F16: 128-by-128-by-32 (small k-block) -// - -/// Operand A - Row-major (K-Major) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = 
decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::half_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::bfloat16_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -template -CUTE_DEVICE auto util_convert_type(cute::Tensor const& tensor) -{ - using From_type = typename Engine::value_type; - constexpr int numel = decltype(cute::size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - // HACK: this requires tensor to be "contiguous" - auto frag = convert_op(*reinterpret_cast const*>(tensor.data())); - return cute::make_tensor(cute::make_rmem_ptr(&frag), tensor.layout()); -} - -template -CUTE_DEVICE void util_copy( - TiledCopy const& tiled_copy, cute::Tensor const& S, cute::Tensor& D) -{ - CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{}); - CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{}); - CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D)); - CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D)); - CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D)); - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < cute::size<1>(S); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < cute::size<2>(S); ++k) - { - cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k)); - } - } -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h deleted file mode 100644 index b708f7c28b..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h +++ /dev/null @@ -1,553 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! 
\file - \brief -*/ - -#pragma once - -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/trace.h" - -#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// -// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. -// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. -template -using void_t = void; - -template -struct use_dq_gemm : platform::false_type -{ -}; - -template -struct use_dq_gemm> : platform::true_type -{ -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MoeFCGemm -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = false; - - // Optional transpose - using MapArguments = kernel::detail::MapArguments; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - static_assert(!kTransposed, "Transpose problem not supported"); - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; - using ElementScale = ElementC; - - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - - // Type definitions about the mainloop. 
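The use_dq_gemm dispatch trait above is the classic void_t detection idiom: the trait flips to true_type only when the mainloop type exposes a scale iterator, which lets one kernel body serve both regular and dequantizing GEMMs. A simplified sketch with the standard library in place of cutlass::platform, assuming the detected member is the mainloop's nested IteratorScale type (as the scale-iterator construction further down suggests):

    #include <cstdio>
    #include <type_traits>

    template <typename...>
    using void_t = void;

    // Primary template: no scale iterator, so the regular GEMM path is taken.
    template <typename Mma, typename = void>
    struct use_dq_gemm : std::false_type {};

    // Chosen only when Mma::IteratorScale exists: the dequantizing path.
    template <typename Mma>
    struct use_dq_gemm<Mma, void_t<typename Mma::IteratorScale>> : std::true_type {};

    struct PlainMma {};
    struct DequantMma { struct IteratorScale {}; };

    template <typename Mma>
    void run_mainloop() {
        if constexpr (use_dq_gemm<Mma>::value)
            std::puts("dequantizing mainloop (loads weight scales)");
        else
            std::puts("regular mainloop");
    }

    int main() {
        run_mainloop<PlainMma>();    // regular mainloop
        run_mainloop<DequantMma>();  // dequantizing mainloop (loads weight scales)
        return 0;
    }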
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ProblemVisitor - = GemmMoeProblemVisitor; - - // - // Structures - // - - /// Argument structure - struct Arguments - { - - // - // Data members - // - - int problem_count; - int threadblock_count; - int group_size; - - typename EpilogueOutputOp::Params output_op; - - ElementA* ptr_A; - ElementB* ptr_B; - ElementScale* weight_scales; - ElementC* ptr_C; - ElementC* ptr_D; - bool C_is_broadcast; - - int64_t const* total_tokens_including_expert; - int64_t gemm_n; - int64_t gemm_k; - - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : problem_count(0) - , threadblock_count(0) - , ptr_A(nullptr) - , ptr_B(nullptr) - , weight_scales(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , total_tokens_including_expert(nullptr) - , gemm_n(0) - , gemm_k(0) - , host_problem_sizes(nullptr) - , C_is_broadcast{true} - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, - ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C, - bool C_is_broadcast, ElementC* ptr_D, int64_t const* total_tokens_including_expert, int64_t gemm_n, - int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr) - : problem_count(problem_count) - , threadblock_count(threadblock_count) - , group_size(group_size) - , output_op(output_op) - , ptr_A(const_cast(ptr_A)) - , ptr_B(const_cast(ptr_B)) - , weight_scales(const_cast(weight_scales)) - , ptr_C(const_cast(ptr_C)) - , C_is_broadcast{C_is_broadcast} - , ptr_D(ptr_D) - , total_tokens_including_expert(total_tokens_including_expert) - , gemm_n(gemm_n) - , gemm_k(gemm_k) - , host_problem_sizes(nullptr) - { - if (platform::is_same::value || platform::is_same::value) - { - assert(weight_scales); - } - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params - { - - typename ProblemVisitor::Params problem_visitor; - int threadblock_count; - int group_size; - bool C_is_broadcast; - - typename EpilogueOutputOp::Params output_op; - - ElementA* ptr_A; - ElementB* ptr_B; - ElementScale* weight_scales; - ElementC* ptr_C; - ElementC* ptr_D; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : ptr_A(nullptr) - , ptr_B(nullptr) - , weight_scales(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , C_is_broadcast(true) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_visitor( - args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count) - , threadblock_count(args.threadblock_count) - , group_size(args.group_size) - , 
output_op(args.output_op) - , ptr_A(args.ptr_A) - , ptr_B(args.ptr_B) - , weight_scales(args.weight_scales) - , ptr_C(args.ptr_C) - , ptr_D(args.ptr_D) - , C_is_broadcast(args.C_is_broadcast) - { - } - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - { - - problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n, - args.gemm_k, args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - weight_scales = args.weight_scales; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - C_is_broadcast = args.C_is_broadcast; - } - }; - - /// Shared memory storage structure - union SharedStorage - { - typename ProblemVisitor::SharedStorage problem_visitor; - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - -public: - // - // Methods - // - - CUTLASS_DEVICE - MoeFCGemm() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - return Status::kSuccess; - } - - static Status can_implement(Arguments const& args) - { - if (platform::is_same::value || platform::is_same::value) - { - if (args.weight_scales == nullptr) - { - CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t"); - return Status::kInvalid; - } - } - else if (args.weight_scales != nullptr) - { - CUTLASS_TRACE_HOST( - "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); - return Status::kInvalid; - } - else if (args.group_size != args.gemm_k) - { - CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)"); - return Status::kInvalid; - } - // Handle the case the input is too short - else if (args.gemm_n < Mma::IteratorB::AccessType::kElements) - { - CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment"); - return Status::kInvalid; - } - return Status::kSuccess; - } - - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; - } - - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - static_assert(platform::is_same::value && kInterleave == 1 - || platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // - // Problem visitor. 
- // - ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - - const int64_t gemm_k = params.problem_visitor.gemm_k; - const int64_t gemm_n = params.problem_visitor.gemm_n; - int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; - - // Outer 'persistent' loop to iterate over tiles - int loop = 0; - while (problem_visitor.next_tile()) - { - loop++; - - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - - cutlass::gemm::GemmCoord threadblock_offset( - int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); - - // Load element pointers. Exchange pointers and strides if working on the transpose - const int64_t rows_to_jump - = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; - ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; - typename LayoutA::LongIndex ldm_A = gemm_k; - - char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix; - ElementB* ptr_B = reinterpret_cast(byte_ptr_B); - typename LayoutB::LongIndex ldm_B - = platform::is_same::value ? gemm_n : gemm_k * kInterleave; - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - 0, - }; - - cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; - - cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); - - typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, - {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - auto CreateMMA = [&]() - { - if constexpr (use_dq_gemm::value) - return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); - else - return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - }; - Mma mma = CreateMMA(); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Wait for all threads to finish their epilogue phases from the previous tile. 
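The per-expert addressing a few lines above is driven entirely by the running row totals in total_tokens_including_expert (surfaced to the kernel as last_row_for_problem): expert i reads activations starting at the previous expert's cumulative row count and owns the i-th densely packed weight matrix. A host-side recap with made-up sizes (3 experts, gemm_n = 4096, gemm_k = 1024, 4-bit weights assumed):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // Cumulative rows (tokens) routed to experts 0..2: 128, 64 and 128 rows respectively.
        const std::vector<int64_t> last_row_for_problem = {128, 192, 320};
        const int64_t gemm_n = 4096, gemm_k = 1024, weight_bits = 4;
        // Same formula as above: (elements / 8) * bits-per-element = bytes per expert matrix.
        const int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * weight_bits;

        for (size_t expert = 0; expert < last_row_for_problem.size(); ++expert) {
            const int64_t rows_to_jump = expert == 0 ? 0 : last_row_for_problem[expert - 1];
            const int64_t gemm_m = last_row_for_problem[expert] - rows_to_jump;
            const int64_t a_offset_elems = rows_to_jump * gemm_k;                     // into activations
            const int64_t b_offset_bytes = int64_t(expert) * bytes_per_expert_matrix; // into packed weights
            std::printf("expert %zu: m=%lld A+=%lld elems B+=%lld bytes\n", expert,
                        (long long) gemm_m, (long long) a_offset_elems, (long long) b_offset_bytes);
        }
        return 0;
    }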
- __syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); - - if constexpr (use_dq_gemm::value) - { - const MatrixCoord scale_extent = {1, problem_size.n()}; - typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), - weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale); - - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); - } - else - { - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - } - - // - // Epilogue - // - - ElementC* ptr_C = reinterpret_cast(params.ptr_C) - + (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n; - ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - - // lora need to set as layout_C(gemm_n) - LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n); - LayoutC layout_D(gemm_n); - - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn()); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn()); - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - if constexpr (platform::is_same>::value) - { - EpilogueOutputOp output_op(params.output_op, problem_idx); - epilogue(output_op, iterator_D, accumulators, iterator_C); - } - else - { - EpilogueOutputOp output_op(params.output_op); - epilogue(output_op, iterator_D, accumulators, iterator_C); - } - - // Next tile - problem_visitor.advance(gridDim.x); - } - } - - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } - } - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { -#if defined(__CUDA_ARCH__) -#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900) - constexpr bool isFp8 = platform::is_same::value - || platform::is_same::value; - if constexpr (isFp8) - { - run_kernel(params, shared_storage); - } - else - { // reuse sm80 kernel for other types, align with dispatchToArch - run_kernel(params, shared_storage); - } -#elif (__CUDA_ARCH__ >= 900) - run_kernel(params, shared_storage); -#else - static_assert( - false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); -#endif -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h deleted file mode 100644 index 796dc2fe78..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h +++ /dev/null @@ -1,344 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Base scheduler for grouped problems, using MoE -*/ - -#pragma once - -#include "cutlass/gemm/kernel/grouped_problem_visitor.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Visitor class to abstract away the algorithm for iterating over tiles -template -struct BaseMoeProblemVisitor -{ - using ThreadblockShape = ThreadblockShape_; - - struct ProblemInfo - { - static int32_t const kNoPrefetchEntry = -1; - int32_t problem_idx; - int32_t problem_start; - - CUTLASS_DEVICE - ProblemInfo() - : problem_idx(kNoPrefetchEntry) - , problem_start(kNoPrefetchEntry) - { - } - - CUTLASS_DEVICE - ProblemInfo(int32_t problem_idx_, int32_t problem_start_) - : problem_idx(problem_idx_) - , problem_start(problem_start_) - { - } - }; - - struct Params - { - int64_t const* last_row_for_problem; - int64_t gemm_n; - int64_t gemm_k; - int32_t problem_count; - void const* workspace; - int32_t tile_count; - - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - Params() - : last_row_for_problem(nullptr) - , gemm_n(0) - , gemm_k(0) - , problem_count(0) - , workspace(nullptr) - , tile_count(0) - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, - void const* workspace = nullptr, int32_t tile_count = 0) - : last_row_for_problem(last_row_for_problem) - , gemm_n(gemm_n) - , gemm_k(gemm_k) - , problem_count(problem_count) - , workspace(workspace) - , tile_count(tile_count) - { - } - }; - - Params const& params; - int32_t tile_idx; - int32_t problem_tile_start; - int32_t problem_idx; - - // - // Methods - // - CUTLASS_DEVICE - BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) - : params(params_) - , tile_idx(block_idx) - , 
problem_tile_start(0) - , problem_idx(0) - { - } - - /// Get the grid shape - CUTLASS_HOST_DEVICE - static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem) - { - - return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), - ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); - } - - /// Gets the global tile index - CUTLASS_HOST_DEVICE - int32_t tile_index() const - { - return tile_idx; - } - - /// Gets the index of the problem - CUTLASS_HOST_DEVICE - int32_t problem_index() const - { - return problem_idx; - } - - CUTLASS_HOST_DEVICE - int32_t threadblock_idx() const - { - return tile_idx - problem_tile_start; - } - - CUTLASS_DEVICE - void advance(int32_t grid_size) - { - tile_idx += grid_size; - } - - CUTLASS_HOST_DEVICE - static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) - { - ProblemSizeHelper::possibly_transpose_problem(problem); - } - - /// Returns the problem size for the current problem - CUTLASS_HOST_DEVICE - cutlass::gemm::GemmCoord problem_size() const - { - return problem_size(problem_idx); - } - - CUTLASS_HOST_DEVICE - cutlass::gemm::GemmCoord problem_size(int idx) const - { - const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; - const int64_t current_problem_row = params.last_row_for_problem[idx]; - const int64_t gemm_m = current_problem_row - prev_problem_row; - GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); - ProblemSizeHelper::possibly_transpose_problem(problem); - return problem; - } - - CUTLASS_HOST_DEVICE - static int32_t tile_count(cutlass::gemm::GemmCoord const& grid) - { - return ProblemSizeHelper::tile_count(grid); - } - - static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count) - { - int32_t total_tiles = 0; - for (int32_t i = 0; i < problem_count; ++i) - { - auto problem = host_problem_sizes_ptr[i]; - possibly_transpose_problem(problem); - auto grid = grid_shape(problem); - total_tiles += tile_count(grid); - } - - return total_tiles; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MoeProblemVisitor; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// ProblemVisitor that performs all scheduling on device -// -template -struct MoeProblemVisitor : public BaseMoeProblemVisitor -{ - using Base = BaseMoeProblemVisitor; - using Params = typename Base::Params; - static int const kThreadCount = ThreadCount; - static bool const kRequiresPrecomputation = false; - static int const kThreadsPerWarp = 32; - - struct SharedStorage - { - }; - - // Final tile of the problem loaded by this thread. Each thread will hold - // a separate value. - int32_t problem_ending_tile; - - SharedStorage& shared_storage; - - // - // Methods - // - CUTLASS_DEVICE - MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) - : Base(params_, block_idx) - , problem_ending_tile(0) - , shared_storage(shared_storage_) - { - this->problem_idx = -1 * kThreadsPerWarp; - this->problem_tile_start = 0; - } - - CUTLASS_DEVICE - bool next_tile() - { - // Check whether the tile to compute is within the range of the current problem. 
- int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); - if (this->tile_idx < problem_tile_end) - { - return true; - } - - // Check whether the tile to compute is within the current group of problems fetched by the warp. - // The last tile for this group is the final tile of the problem held by the final thread in the warp. - int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); - - // Keep the starting problem for this group in `problem_idx`. This is done to reduce - // register pressure. The starting problem for this group is simply the first problem - // in the group most recently fetched by the warp. - int32_t& group_problem_start = this->problem_idx; - group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; - - // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce - // register pressure. - int32_t& group_tile_start = this->problem_tile_start; - - // Each thread in the warp processes a separate problem to advance until - // reaching a problem whose starting tile is less less than tile_idx. - while (group_tile_end <= this->tile_idx) - { - group_problem_start += kThreadsPerWarp; - if (group_problem_start > this->params.problem_count) - { - return false; - } - - // Since `group_tile_start` is a reference to `this->problem_tile_start`, this - // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` - // is also set here is used later in `next_tile`. - group_tile_start = group_tile_end; - - int lane_idx = threadIdx.x % kThreadsPerWarp; - int32_t lane_problem = group_problem_start + lane_idx; - - // Compute the number of tiles in the problem assigned to each thread. - problem_ending_tile = 0; - if (lane_problem < this->params.problem_count) - { - cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); - cutlass::gemm::GemmCoord grid = this->grid_shape(problem); - problem_ending_tile = this->tile_count(grid); - } - - // Compute a warp-wide inclusive prefix sum to compute the ending tile index of - // each thread's problem. - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < kThreadsPerWarp; i <<= 1) - { - int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); - if (lane_idx >= i) - { - problem_ending_tile += val; - } - } - - // The total tile count for this group is now in the final position of the prefix sum - int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); - - problem_ending_tile += group_tile_start; - group_tile_end += tiles_in_group; - } - - // The next problem to process is the first one that does not have ending tile position - // that is greater than or equal to tile index. - int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); - - this->problem_idx = group_problem_start + problem_idx_in_group; - - // The starting tile for this problem is the ending tile of the previous problem. In cases - // where `problem_idx_in_group` is the first problem in the group, we do not need to reset - // `problem_tile_start`, because it is set to the previous group's ending tile in the while - // loop above. 
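The tile-to-problem search above amounts to an inclusive prefix sum of per-problem tile counts across the warp, followed by counting how many problems end at or before tile_idx. A sequential host model of one warp-sized group, with the shuffles replaced by plain loops, group_tile_start taken as 0 and the tile counts made up:

    #include <array>
    #include <cstdio>

    constexpr int kThreadsPerWarp = 32;

    // Returns the index (within the group) of the problem that owns tile_idx.
    int problem_in_group(const std::array<int, kThreadsPerWarp>& tiles_per_problem, int tile_idx) {
        std::array<int, kThreadsPerWarp> ending_tile{};
        int running = 0;
        for (int lane = 0; lane < kThreadsPerWarp; ++lane) {  // stands in for the __shfl_up_sync scan
            running += tiles_per_problem[lane];
            ending_tile[lane] = running;
        }
        int idx_in_group = 0;
        for (int lane = 0; lane < kThreadsPerWarp; ++lane)    // stands in for __popc(__ballot_sync(...))
            if (ending_tile[lane] <= tile_idx) ++idx_in_group;
        return idx_in_group;
    }

    int main() {
        std::array<int, kThreadsPerWarp> tiles{};  // problems 0..3 have 4, 0, 2 and 6 tiles
        tiles[0] = 4; tiles[1] = 0; tiles[2] = 2; tiles[3] = 6;
        std::printf("tile 3 -> problem %d\n", problem_in_group(tiles, 3)); // prefix {4,4,6,12}: 0 entries <= 3 -> problem 0
        std::printf("tile 5 -> problem %d\n", problem_in_group(tiles, 5)); // 2 entries <= 5 -> problem 2 (problem 1 is empty)
        return 0;
    }

Zero-tile problems contribute nothing to the prefix sum, which is why empty experts are skipped at no extra cost.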
- if (problem_idx_in_group > 0) - { - this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); - } - - return true; - } - - static size_t get_workspace_size( - cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count) - { - return 0; - } - - static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, - int32_t block_count, void* host_workspace_ptr) - { - } -}; - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp deleted file mode 100644 index e3d31a2c5b..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp +++ /dev/null @@ -1,646 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ -#pragma once - -#include "cute/arch/cluster_sm90.hpp" -#include "cute/tensor.hpp" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/trace.h" -#include "cutlass/workspace.h" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel -{ - -/////////////////////////////////////////////////////////////////////////////// - -template -class GemmUniversalGated - && CollectiveMainloop_::isGated>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - using Activation = typename CollectiveMainloop::Activation; - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(ArchTag::kMinComputeCapability >= 90); - - using TileSchedulerTag = TileScheduler_; - using TileScheduler = - typename detail::TileSchedulerSelector::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock - = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = 40; - static constexpr uint32_t MmaRegisterRequirement = 232; - - // 1 stage ordered sequence between mainloop and epilogue producer load threads - using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; - - // Kernel level shared memory storage - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> - { - 
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16> - { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; - } pipelines; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments - { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params - { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - void* workspace{nullptr}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static Params to_underlying_arguments(Arguments const& args, void* workspace) - { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - auto problem_shape = args.problem_shape; - // if constexpr (detail::IF_SWAP_AB::value) { - // // swap M/N - // get<0>(problem_shape) = get<1>(args.problem_shape); - // get<1>(problem_shape) = get<0>(args.problem_shape); - // } - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) - { - CUTLASS_TRACE_HOST( - " WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; - - // Calculate workspace pointers - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* mainloop_workspace = nullptr; - // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used - // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means - // subtile will not be used, therefore separate reduction will not be enabled. 
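The workspace carving in to_underlying_arguments() above hands each component a sub-range of one allocation, rounding every offset up to a common alignment. A small sketch of that bookkeeping; the byte sizes and the 256-byte alignment are placeholders, not the kernel's real values:

    #include <cstddef>
    #include <cstdio>

    // Round offset up to the next multiple of alignment (the round_nearest pattern above).
    constexpr size_t round_nearest(size_t offset, size_t alignment) {
        return (offset + alignment - 1) / alignment * alignment;
    }

    int main() {
        const size_t kAlignment = 256;       // placeholder for MinWorkspaceAlignment
        const size_t scheduler_bytes = 1000; // placeholder tile-scheduler workspace size
        const size_t epilogue_bytes = 4096;  // placeholder epilogue workspace size

        size_t offset = 0;
        const size_t scheduler_offset = offset;
        offset = round_nearest(offset + scheduler_bytes, kAlignment);
        const size_t epilogue_offset = offset;
        offset = round_nearest(offset + epilogue_bytes, kAlignment);

        std::printf("scheduler @ %zu, epilogue @ %zu, total %zu bytes\n",
                    scheduler_offset, epilogue_offset, offset);
        return 0;
    }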
- constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, - ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); - - return {args.mode, problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - scheduler, workspace}; - } - - static bool can_implement(Arguments const& args) - { - bool implementable = (args.mode == GemmUniversalMode::kGemm) - or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) - { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - return implementable; - } - - static size_t get_workspace_size(Arguments const& args) - { - size_t workspace_size = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) - { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - - status = TileScheduler::template initialize_workspace(args.scheduler, - workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, - NumEpilogueSubTiles); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) - { - return status; - } - - status = CollectiveEpilogue::initialize_workspace( - args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) - { - return status; - } - - return status; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 get_grid_shape(Params const& params) - { - // Given device SM count, set grid size s.t. 
we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) - { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN - ? TileScheduler::RasterOrderOptions::AlongN - : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - - static dim3 get_block_shape() - { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void operator()(Params const& params, char* smem_buf) - { - using namespace cute; - using X = Underscore; - -// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); -#else - - // Preconditions - static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); - static_assert(size<0>(TileShape{}) >= 128, - "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); - - static_assert(cute::rank(StrideA{}) == 3, - "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, - "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, - "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, - "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ - enum class WarpGroupRole - { - Producer = 0, - Consumer0 = 1, - Consumer1 = 2 - }; - enum class ProducerWarpRole - { - Mainloop = 0, - Warp1 = 1, - Epilogue = 2, - Warp3 = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - int mma_thread_idx = thread_idx % size(TiledMma{}); - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) - { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) - { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = 
warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = size(TiledMma{}); - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) - { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; - epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); - epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - typename LoadWarpOrderBarrier::Params params_load_order_barrier; - params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; - params_load_order_barrier.group_size = NumThreadsPerWarp; - LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = []() - { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) - { - cute::cluster_arrive_relaxed(); - return []() { cute::cluster_wait(); }; - } - else - { - __syncthreads(); - return []() {}; // do nothing - } - }(); - - // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - - TileScheduler scheduler{params.scheduler}; - auto work_tile_info = scheduler.get_current_work(); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, 
shared_storage.tensors.epilogue); - - // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); - static_assert(cute::tuple_size_v >= 3, - "Output of load_init must have at least three elements (A, B, Aux)"); - - // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - Tensor gAux_xkl = get<2>(load_inputs); - - // Get pipeline stage increments from tensor shapes - auto k_tile_count = size<3>(gA_mkl); - - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - - if (warp_group_role == WarpGroupRole::Producer) - { - cutlass::arch::warpgroup_reg_dealloc(); - - // Mainloop Producer Warp - if (producer_warp_role == ProducerWarpRole::Mainloop) - { - bool do_load_order_arrive = true; - while (work_tile_info.is_valid()) - { - if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) - { - work_tile_info = fetch_next_work(work_tile_info, scheduler); - continue; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the - // work. - auto work_k_tile_count - = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter - = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster, - shared_storage.tensors.mainloop); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(work_k_tile_count); - - // Signal for the epilogue load warp to begin - if (do_load_order_arrive) - { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - - // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - } // Mainloop Producer Warp End - - // Epilogue Producer Warp - else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) - { - while (work_tile_info.is_valid()) - { - if (!TileScheduler::requires_separate_reduction(params.scheduler)) - { - load_order_barrier.wait(); - } - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) - { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - epi_load_pipe_producer_state 
= collective_epilogue.load(epi_load_pipeline, - epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, - shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx()); - } - - // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End - - else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - cutlass::arch::warpgroup_reg_alloc(); - - // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it - bool do_store_tail = false; - float scale_d0 = params.mainloop.scale_d0; - float scale_d1 = params.mainloop.scale_d1; - while (work_tile_info.is_valid()) - { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - auto work_k_tile_count - = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - - // Allocate the accumulators for the (M,N) blk_shape - // - // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. - auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) - auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) - if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) - { - collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, - accumulators1, work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop, - params.mainloop); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count); - - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(work_k_tile_count); - } - // Index of warp group within consumer warp groups - int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; - - // Perform reduction across splits, if needed - TileScheduler::fixup( - params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx); - TileScheduler::fixup( - params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx); - - Activation elt_op; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators0); i++) - { - accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); - } - - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) - { - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] - = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, - epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, - tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue, - work_tile_info.reduction_subtile_idx()); - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; - epi_store_pipe_producer_state = 
epi_store_pipe_producer_state_next; - do_store_tail = true; - } - - // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop - - if (do_store_tail) - { - collective_epilogue.store_tail( - epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state); - } - } // Consumer Warp Groups End -#endif - } - -private: - // Kernel helper function to get next work unit - CUTLASS_DEVICE - typename TileScheduler::WorkTileInfo fetch_next_work( - typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const - { - // Check whether we should continue on with the current work unit. If this is the case, - // the work unit will have been updated in continue_current_work to reflect the new - // tile to be computed. - if (scheduler.continue_current_work(work_tile_info)) - { - return work_tile_info; - } - - // Get next work tile - scheduler.advance_to_next_work(); - return scheduler.get_current_work(); - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp deleted file mode 100644 index 39886f2431..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp +++ /dev/null @@ -1,621 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ -#pragma once - -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/trace.h" -#include "cutlass/workspace.h" - -#include "cute/tensor.hpp" - -#include "cute/util/debug.hpp" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel -{ - -/////////////////////////////////////////////////////////////////////////////// - -template -class GemmUniversalGated - && CollectiveMainloop_::isGated>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - using Activation = typename CollectiveMainloop::Activation; - static_assert(ArchTag::kMinComputeCapability >= 90); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(!cute::is_same_v, - "Ping-pong kernel does not currently support stream-K scheduler."); - using TileSchedulerTag = TileScheduler_; - using TileScheduler = - typename detail::TileSchedulerSelector::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = 2; - static constexpr uint32_t MaxThreadsPerBlock - = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = 40; - static constexpr uint32_t MmaRegisterRequirement = 232; - - // 1 stage ordered sequence between mainloop and epilogue producer load threads - using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; - - // Order Sequence barrier with two 
stages: one for Mainloop and one for Epilogue - static constexpr uint32_t StagesPerMathWarpGroup = 2; - using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier; - - // Kernel level shared memory storage - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> - { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16> - { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; - alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; - } pipelines; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments - { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params - { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static Params to_underlying_arguments(Arguments const& args, void* workspace) - { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - (void) workspace; - auto problem_shape = args.problem_shape; - // if constexpr (detail::IF_SWAP_AB::value) { - // // swap M/N - // get<0>(problem_shape) = get<1>(args.problem_shape); - // get<1>(problem_shape) = get<0>(args.problem_shape); - // } - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) - { - CUTLASS_TRACE_HOST( - " WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; - - // Calculate workspace pointers - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* mainloop_workspace = nullptr; - - return {args.mode, problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, 
args.mainloop, mainloop_workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)}; - } - - static bool can_implement(Arguments const& args) - { - bool implementable = (args.mode == GemmUniversalMode::kGemm) - or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) - { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - return implementable; - } - - static size_t get_workspace_size(Arguments const& args) - { - size_t workspace_size = 0; - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) - { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - status = TileScheduler::template initialize_workspace(args.scheduler, - workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) - { - return status; - } - - status = CollectiveEpilogue::initialize_workspace( - args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) - { - return status; - } - - return status; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 get_grid_shape(Params const& params) - { - // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) - { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN - ? TileScheduler::RasterOrderOptions::AlongN - : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - - static dim3 get_block_shape() - { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void operator()(Params const& params, char* smem_buf) - { - using namespace cute; - using X = Underscore; - -// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. 
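// The #if below keys off __CUDA_ARCH_FEAT_SM90_ALL, which is defined only when the
// translation unit targets the architecture-specific feature set (e.g.
// -gencode arch=compute_90a,code=sm_90a); a plain sm_90 build takes the error branch,
// since WGMMA and TMA are sm90a-only. A sketch of the matching host-side gate
// (variable names are illustrative, not part of this kernel):
//
//   cudaDeviceProp prop{};
//   cudaGetDeviceProperties(&prop, device_id);
//   bool can_run = (prop.major == 9 && prop.minor == 0);  // Hopper device, binary built for sm_90a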
-#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); -#else - - // Preconditions - static_assert(cute::rank(StrideA{}) == 3, - "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, - "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, - "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, - "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - enum class WarpGroupRole - { - Producer = 0, - Consumer0 = 1, - Consumer1 = 2 - }; - enum class ProducerWarpRole - { - Mainloop = 0, - Warp1 = 1, - Epilogue = 2, - Warp3 = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) - { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) - { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) - { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; - epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; - epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, 
epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - typename LoadWarpOrderBarrier::Params params_load_order_barrier; - params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; - params_load_order_barrier.group_size = NumThreadsPerWarp; - LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); - - typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; - // DMA Load WG will not participate in these Ordered Barrier syncs - params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); - params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group - MathWarpGroupOrderBarrier math_wg_order_barrier( - shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = [&]() - { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) - { - cute::cluster_arrive_relaxed(); - return []() { cute::cluster_wait(); }; - } - else - { - __syncthreads(); - return []() {}; // do nothing - } - }(); - - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); - static_assert(cute::tuple_size_v >= 3, - "Output of load_init must have at least three elements (A, B, Aux)"); - - // Extract out partitioned A and B. 
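// The (BLK_M,BLK_K,m,k,l) / (BLK_N,BLK_K,n,k,l) shapes described above are what
// cute::local_tile produces: tiling the global tensor by the CTA tile keeps the per-tile
// modes first and appends the tile-count modes (m,k,l), so size<3>(gA_mkl) is the number
// of K tiles. A rough sketch of how such a tensor arises (the projection step and names
// are illustrative; the actual partitioning lives in the collective's load_init):
//
//   Tensor mA = make_tensor(make_gmem_ptr(ptr_A), make_shape(M, K, L), stride_A); // (M,K,L)
//   Tensor gA = local_tile(mA, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{});
//   // gA : (BLK_M, BLK_K, m, k, l) -- one (m,k,l) coordinate selects one CTA's tile slice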
- Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - Tensor gAux_xkl = get<2>(load_inputs); - - // Get pipeline stage increments from tensor shapes - auto k_tile_count = size<3>(gA_mkl); - auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); - auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); - - TileScheduler scheduler{params.scheduler}; - - if (warp_group_role == WarpGroupRole::Consumer1) - { - // Advance 2nd Math WG to the next work tile for the startup - scheduler.advance_to_next_work(); - // Advance 2nd Math WG pipeline states to the end of 1st Math WG - mainloop_pipe_consumer_state.advance(k_tile_count); - epi_load_pipe_consumer_state.advance(c_tile_count); - epi_store_pipe_producer_state.advance(d_tile_count); - } - auto work_tile_info = scheduler.get_current_work(); - - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - - if (warp_group_role == WarpGroupRole::Producer) - { - cutlass::arch::warpgroup_reg_dealloc(); - - // Mainloop Producer Warp - if (producer_warp_role == ProducerWarpRole::Mainloop) - { - bool do_load_order_arrive = true; - while (work_tile_info.is_valid()) - { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); - - collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster, - shared_storage.tensors.mainloop); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(k_tile_count); - - // Signal for the epilogue load warp to begin - if (do_load_order_arrive) - { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - - // Get next work tile - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - } // Mainloop Producer Warp End - - // Epilogue Producer Warp - else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) - { - load_order_barrier.wait(); - while (work_tile_info.is_valid()) - { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - epi_load_pipe_producer_state - = collective_epilogue.load(epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL, - blk_shape, blk_coord, tiled_mma, lane_idx, shared_storage.tensors.epilogue); - - // Get next work tile - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End - - else if (warp_group_role == 
WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - cutlass::arch::warpgroup_reg_alloc(); - - float scale_d0 = params.mainloop.scale_d0; - float scale_d1 = params.mainloop.scale_d1; - while (work_tile_info.is_valid()) - { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Allocate the accumulators for the (M,N) blk_shape - Tensor accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) - Tensor accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) - - // Order two Math WG's MMA one after the other, helps hide Epilogue - math_wg_order_barrier.wait(); - - collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1, - k_tile_count, warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop); - - // Cue for next Math WG's MMA to start - math_wg_order_barrier.arrive(); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count); - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); - - Activation elt_op; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators0); i++) - { - accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); - } - - // Order two Math WG's Epilogue one after the other - math_wg_order_barrier.wait(); - - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] - = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, - epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, - tiled_mma, warp_group_thread_idx, shared_storage.tensors.epilogue); - - // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels - // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives - // to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer. 
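// In terms of the TMA store helpers this maps to (helper names assumed from the
// cutlass::arch TMA store path, shown only as a sketch):
//
//   cutlass::arch::tma_store_arrive();   // commit the TMA store group issued by this warp
//   cutlass::arch::tma_store_wait<0>();  // block until every committed TMA store has drained
//   math_wg_order_barrier.arrive();      // only then may the other math WG reuse the smem
//
// which is the ordering the store_tail() call below enforces before the pipeline states
// are advanced and the barrier arrive is issued.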
- auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] - = collective_epilogue.store_tail(epi_load_pipeline, epi_load_pipe_consumer_state_next, - epi_store_pipeline, epi_store_pipe_producer_state_next); - - // Update starting load/store pipeline states for the next tile - // state has already been incremented by 1 tile in collective calls, advance once again for ping pong - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; - epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; - epi_load_pipe_consumer_state.advance(c_tile_count); - epi_store_pipe_producer_state.advance(d_tile_count); - - // Cue for next Math WG's Epilogue to start - math_wg_order_barrier.arrive(); - - // Get next work tile - scheduler.advance_to_next_work(NumMmaWarpGroups); - work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - } // Consumer Warp Groups End -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h deleted file mode 100644 index 5e3531f093..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h +++ /dev/null @@ -1,494 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! 
\file - \brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h -*/ - -#pragma once - -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/trace.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SplitkGemmGrouped -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = Transposed; - - // Optional transpose - using MapArguments = kernel::detail::MapArguments; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; - - using ElementFinalOutput = typename MapArguments::ElementA; - - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - - // Type definitions about the mainloop. 
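// kTransposed exploits C = A * B  <=>  C^T = B^T * A^T: when the epilogue supports only one
// output layout, the kernel swaps the A/B operands and their leading dimensions and solves
// the transposed problem instead, which is exactly the pointer/stride exchange performed in
// operator() below. Sketch of that swap (the *_eff names are illustrative only):
//
//   auto* ptr_A_eff = kTransposed ? ptr_B[i] : ptr_A[i];
//   auto* ptr_B_eff = kTransposed ? ptr_A[i] : ptr_B[i];
//   auto  lda_eff   = kTransposed ? ldb[i]   : lda[i];
//   auto  ldb_eff   = kTransposed ? lda[i]   : ldb[i];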
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ProblemVisitor - = GemmGroupedProblemVisitor; - - // - // Structures - // - - /// Argument structure - struct Arguments - { - - // - // Data members - // - - GemmCoord* problem_sizes; - int problem_count; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA** ptr_A; - ElementB** ptr_B; - ElementFinalOutput** ptr_C; - ElementFinalOutput** ptr_D; - - typename LayoutA::Stride::LongIndex* lda; - typename LayoutB::Stride::LongIndex* ldb; - typename LayoutC::Stride::LongIndex* ldc; - typename LayoutC::Stride::LongIndex* ldd; - - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // splitK - int split_k_slices; - int64_t* splitk_buffer_offsets; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : problem_count(0) - , threadblock_count(0) - , ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , lda(nullptr) - , ldb(nullptr) - , ldc(nullptr) - , ldd(nullptr) - , host_problem_sizes(nullptr) - , split_k_slices(1) - , splitk_buffer_offsets(nullptr) - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, - typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C, - ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, - typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, - typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, - int64_t* splitk_buffer_offsets) - : problem_sizes(problem_sizes) - , problem_count(problem_count) - , threadblock_count(threadblock_count) - , output_op(output_op) - , ptr_A(ptr_A) - , ptr_B(ptr_B) - , ptr_C(ptr_C) - , ptr_D(ptr_D) - , lda(lda) - , ldb(ldb) - , ldc(ldc) - , ldd(ldd) - , host_problem_sizes(host_problem_sizes) - , split_k_slices(split_k_slices) - , splitk_buffer_offsets(splitk_buffer_offsets) - { - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params - { - - typename ProblemVisitor::Params problem_visitor; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA** ptr_A; - ElementB** ptr_B; - ElementFinalOutput** ptr_C; - ElementFinalOutput** ptr_D; - ElementC* ptr_C_split; - ElementC* ptr_D_split; - - typename LayoutA::Stride::LongIndex* lda; - typename LayoutB::Stride::LongIndex* ldb; - typename LayoutC::Stride::LongIndex* ldc; - typename LayoutC::Stride::LongIndex* ldd; - - // - // Methods - // - - // splitk - GemmCoord grid_tiled_shape; - int swizzle_log_tile; - int gemm_k_size; - GemmCoord* host_problem_sizes; - int split_k_slices; - int64_t* splitk_buffer_offsets; - - CUTLASS_HOST_DEVICE - Params() - : ptr_A(nullptr) - 
, ptr_B(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , ptr_C_split(nullptr) - , ptr_D_split(nullptr) - , lda(nullptr) - , ldb(nullptr) - , ldc(nullptr) - , ldd(nullptr) - , swizzle_log_tile(0) - , gemm_k_size(0) - , host_problem_sizes(nullptr) - , split_k_slices(1) - , splitk_buffer_offsets(nullptr) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count) - , host_problem_sizes(args.host_problem_sizes) - , threadblock_count(args.threadblock_count) - , output_op(args.output_op) - , ptr_A(args.ptr_A) - , ptr_B(args.ptr_B) - , ptr_C(args.ptr_C) - , ptr_D(args.ptr_D) - , ptr_C_split((ElementC*) workspace) - , ptr_D_split((ElementC*) workspace) - , lda(args.lda) - , ldb(args.ldb) - , ldc(args.ldc) - , ldd(args.ldd) - , split_k_slices(args.split_k_slices) - , splitk_buffer_offsets(args.splitk_buffer_offsets) - { - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0], - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices); - swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); - - // only support same k - int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK; - int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); - - gemm_k_size = gemm_k_iterations * Mma::Shape::kK; - } - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - { - - problem_visitor = - typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - ptr_C_split = workspace; - ptr_D_split = workspace; - - lda = args.lda; - ldb = args.ldb; - ldc = args.ldc; - ldd = args.ldd; - } - }; - - /// Shared memory storage structure - struct SharedStorage - { - union - { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - } kernel; - - // ProblemVisitor shared storage can't be overlapped with others - typename ProblemVisitor::SharedStorage problem_visitor; - }; - -public: - // - // Methods - // - - CUTLASS_DEVICE - SplitkGemmGrouped() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - return Status::kSuccess; - } - - static Status can_implement(Arguments const& args) - { - return Status::kSuccess; - } - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { - - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - - // - // Problem visitor. 
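// Split-K slices the K extent of each GEMM across grid_tiled_shape.k() thread blocks; every
// slice accumulates a partial result into the split-K workspace and a later reduction pass
// combines the slices. Under the "same K for all problems" assumption used above, the
// per-slice extent works out to:
//
//   full_gemm_k_iterations = K / ThreadblockShape::kK;                 // K tiles in the problem
//   gemm_k_iterations      = full_gemm_k_iterations / k_slices;        // K tiles per slice
//   gemm_k_size            = gemm_k_iterations * ThreadblockShape::kK; // K elements per slice
//
// so slice s covers K range [s * gemm_k_size, (s + 1) * gemm_k_size), with the last slice
// absorbing any remainder -- matching the problem_size_k computation in operator() below.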
- // - ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) - { - - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); - - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - - // Load element pointers. Exchange pointers and strides if working on the transpose - ElementA* ptr_A - = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); - typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); - - ElementB* ptr_B - = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); - typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, - int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0); - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - threadblock_tile_offset.k() * params.gemm_k_size, - }; - - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k; - if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) - { - problem_size_k = problem_size.k(); - } - else - { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); - - typename Mma::IteratorB iterator_B( - LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - - int lane_idx = threadIdx.x % 32; - - // - // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); - - // Wait for all threads to finish their epilogue phases from the previous tile. 
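// (SharedStorage unions the mainloop and epilogue buffers, so the barrier below keeps this
// tile's mainloop from clobbering shared memory still being read by the previous tile's
// epilogue in the persistent loop.)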
- __syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - ElementC* ptr_C = params.ptr_C_split; - ElementC* ptr_D = params.ptr_D_split; - - LayoutC layout_C(params.ldc[problem_idx]); - LayoutC layout_D(params.ldd[problem_idx]); - - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); - - // assume identity swizzle - MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); - - iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() - + gridDim.z * params.splitk_buffer_offsets[problem_idx]); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); - iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() - + gridDim.z * params.splitk_buffer_offsets[problem_idx]); - - Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // Next tile - problem_visitor.advance(gridDim.x); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h deleted file mode 100644 index ed5e3e4daf..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h +++ /dev/null @@ -1,125 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ -//////////////////////////////////////////////////////////////////////////////// - -// We need to distinguish here, since we want volta support. It is too much effort -// to write shared memory iterators that are probably needed for volta to function -// properly. 
As a result, we allow converters both after the LDG (for volta) and after -// the LDS for Turing+. -template < - /// Iterator for B matrix in global memory - typename IteratorB, - /// Warp level Mma - typename MmaOperator, - /// Math operation perform by warp level operator - typename MathOperator> -struct SetConverters -{ -}; - -// Dequantize after LDG, so set transforms accordingly -template < - /// Iterator for B matrix in global memory - typename IteratorB, - /// Mma Policy - typename MmaOperator> -struct SetConverters -{ - using TransformAfterLDG - = FastInterleavedAndBiasedNumericArrayConverter; - - using TransformAfterLDS = NumericArrayConverter; -}; - -// Dequantize after LDS, so set transforms accordingly - -template < - /// Iterator for B matrix in global memory - typename IteratorB, - /// Mma Policy - typename MmaOperator> -struct SetConverters -{ - using TransformAfterLDG = NumericArrayConverter; - - using TransformAfterLDS - = FastInterleavedAndBiasedNumericArrayConverter; -}; - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale_, - /// Layout for the scale operand - typename LayoutScale_, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// - typename Enable = void> -struct DqMma; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h deleted file mode 100644 index 17c6346553..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +++ /dev/null @@ -1,302 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass_extensions/arch/mma.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" -#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" -#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -template -struct DefaultScaleIteratorsMultistage; - -// Fine grained iterators -template -struct DefaultScaleIteratorsMultistage> -{ - using IteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, - Layout, 0, Alignment>; - - using SmemIteratorScale = IteratorScale; -}; - -// Per column iterators -template -struct DefaultScaleIteratorsMultistage> -{ - // ThreadMap for scale iterator - static_assert((MmaShape::kN % Alignment) == 0, ""); - -private: - using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, - MmaShape::kN / Alignment, Alignment>; - -public: - // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, - Element, Layout, 0, IteratorScaleThreadMap, Alignment>; - - using SmemIteratorScale = IteratorScale; -}; - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Stages in GEMM - int kStages, - /// Operator performed by GEMM - typename Operator_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear> -struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value - || platform::is_same::value, - "Element A must be fp16, fp8 or bf16"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename 
OperatorInfo::Operator; - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, - AccessTypeB>; - - using ScaleIterators = DefaultScaleIteratorsMultistage; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converter = FastInterleavedAndBiasedNumericArrayConverter; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; -}; - -// Specialization to handle column major interleave B -template < - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Stages in GEMM - int kStages, - /// Operator performed by GEMM - typename Operator_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear> -struct DqMma= 80 && layout::IsColumnMajorTileInterleave::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value - || platform::is_same::value, - "Element A must be fp16, fp8 or bf16"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - static_assert(platform::is_same::value, - "Mma multistage must dequantize 
after ldsm"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, - AccessTypeA>; - -private: - static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; - static constexpr int RowsPerTile = LayoutB::kRowsPerTile; - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); - - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - - using GmemIteratorShape - = MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; - -public: - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; - - using ScaleIterators = DefaultScaleIteratorsMultistage; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converter = FastInterleavedAndBiasedNumericArrayConverter; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h deleted file mode 100644 index 345cd2eec9..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +++ /dev/null @@ -1,284 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass_extensions/arch/mma.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" -#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -template -struct DefaultScaleIteratorsPipelined; - -// Fine grained iterators -template -struct DefaultScaleIteratorsPipelined> -{ -private: - using SmemScaleType = half_t; - -public: - using IteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, - Layout, 0, Alignment>; - - using SmemIteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, - SmemScaleType, Layout, 0, Alignment>; -}; - -// Per column iterators -template -struct DefaultScaleIteratorsPipelined> -{ - static_assert((MmaShape::kN % Alignment) == 0, ""); - -private: - // ThreadMap for scale iterator - using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, - MmaShape::kN / Alignment, Alignment>; - using SmemScaleType = half_t; - -public: - // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, - Element, Layout, 0, IteratorScaleThreadMap, Alignment>; - - using SmemIteratorScale - = cutlass::transform::threadblock::PredicatedTileIterator, SmemScaleType, - Layout, 0, IteratorScaleThreadMap, Alignment>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator_> -struct DqMma::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); - - 
static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); - - static constexpr bool DqAfterLDG = platform::is_same::value; - using MmaCoreElementA = half_t; - using MmaCoreElementB = typename platform::conditional::type; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementB, LayoutB, 0, - typename MmaCore::IteratorThreadMapB, kAlignmentB>; - - using ScaleIterators = DefaultScaleIteratorsPipelined; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converters = SetConverters; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; -}; - -// Specialization to handle column major interleave B -template < - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator_> -struct DqMma::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - - static constexpr bool DqAfterLDG = platform::is_same::value; - using MmaCoreElementA = half_t; - using MmaCoreElementB = typename platform::conditional::type; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA>; - -private: - static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; - static constexpr int RowsPerTile = 
LayoutB::kRowsPerTile; - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); - - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - - using GmemIteratorShape - = MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; - -public: - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; - - // ThreadMap for scale iterator - static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); - using IteratorScaleThreadMap - = transform::PitchLinearStripminedThreadMap, - MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; - - using ScaleIterators = DefaultScaleIteratorsPipelined; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converters = SetConverters; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h deleted file mode 100644 index ad6c7496e1..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h +++ /dev/null @@ -1,351 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage -/// (stage>=3) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of 
elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage -/// (stage>=3) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -#ifdef ENABLE_FP8 -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage -/// (stage>=3) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile 
size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -#endif - -// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on -// large tile when not enough shared mem is present to do 3+ stage -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB> -struct DefaultMma -{ - - // Define the MmaCore components - // 3 is used on purpose here to trigger components for mma multistage - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, - GatherA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, - GatherB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h deleted file mode 100644 index 77af81005a..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ /dev/null @@ -1,353 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB> -struct DefaultMma -{ - -private: - // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. - static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; - using MmaElementA = typename platform::conditional::type; - using MmaElementB = typename platform::conditional::type; - -public: - // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; - - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, - typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; -}; - -// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on -// large tile when not enough shared mem is present to do 3+ stage -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB> -struct DefaultMma -{ - - // Define the MmaCore components - // 3 is used on purpose here to trigger components for mma multistage - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA, GatherA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB, GatherB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & 
int4 weight -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory 
clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h deleted file mode 100644 index 1fb7f7eb28..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h +++ /dev/null @@ -1,257 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
-*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/mma_base.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass_extensions/weight_only_quant_op.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// -// SFINAE trick so I can keep the same loop code for Volta and dispatch to the -// correct warp level mma. On volta, all data is stored to shared memory as FP16. -template -CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, - typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C, - int const warp_tileB_k_offset) -{ - warp_mma(D, A, B, C); -} - -template -CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, - typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B, - typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset) -{ - warp_mma(D, A, B, C, warp_tileB_k_offset); -} - -//////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// The type of the scales - typename ElementScale_, - /// Number of stages, - int Stages, - /// The dequantizing op to be performed. - WeightOnlyQuantOp DequantOp, - /// Used for partial specialization, - typename Enable = bool> -class DqMmaBase -{ -public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Policy describing tuning details - using Policy = Policy_; - - ///< Type of the scale to be loaded - using ElementScale = ElementScale_; - - static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); - - // Finegrained scales get streamed in via cp.async - static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; - // We always have scales. - static constexpr int ScaleElementsPerStage = Shape::kN; - // We sometimes have a bias - static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. 
- using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; - - /// Number of warp-level GEMM operations - static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - static constexpr int kNumKIterationsPerWarpBLoad - = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; - - static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); - static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage - { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA - = MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB - = MatrixShape; - - /// Shape of the shared memory buffer for the scales for the B matrix. - using ShapeScale = MatrixShape; - /// Shape of the shared memory buffer for the biases of the B matrix. - using ShapeZero = MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_scale; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_zero; - - public: - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() - { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() - { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() - { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() - { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - -protected: - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx) - , warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) - { - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h deleted file mode 100644 index 3c4036dd8c..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +++ /dev/null @@ -1,110 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. 
-template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Data type for the scales - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = void> -class DqMmaMultistage; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" -#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h" diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h deleted file mode 100644 index f81961dee3..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h +++ /dev/null @@ -1,708 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. 
-template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Iterators over scales in global memory - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Layout of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Converter for B matrix applied immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear> -class DqMmaMultistage> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); - static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); - - /// Internal structure exposed for introspection. 
- struct Detail - { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA - = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB - = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -private: - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - /// The group size for quantization - int const group_size, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), - shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in 
units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) - { - static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); - - typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); - typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); - - typename IteratorScale::AccessType* smem_scale_ptr - = reinterpret_cast(this->smem_iterator_scale_.get_scale()); - typename IteratorScale::AccessType* smem_zero_ptr - = reinterpret_cast(this->smem_iterator_scale_.get_zero()); - - int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; - - cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); - - if (gmem_zero_ptr != nullptr) - { - cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); - } - - if (iterator_scale.group_size_ == 64) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if (iterator_scale.group_size_ == 128) - { - if constexpr (Shape::kK == 128) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if constexpr (Shape::kK == 64) - { - if (iterator_scale.row_groupsize64_ & 0x1) - { - iterator_scale.add_tile_offset({1, 0}); - } - } - else - { - static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); - } - } - - iterator_scale.row_groupsize64_++; - - this->smem_iterator_scale_.add_tile_offset({1, 0}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance( - IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) - { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) - { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - 
cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) - { - - // - // Prologue - // - - TransformBAfterLDS lds_converter; - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) - { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. 
- // - - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) - { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; - typename Dequantizer::FragmentZero warp_frag_zeros; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - warp_dequantizer_.add_pointer_offset(Shape::kN); - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) - { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. 
- - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; - - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); - run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, - warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) - { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // This is the first group of a given stage, so we issue the loads for the B scales immediately. - if (group_start_iteration_B == 0) - { - copy_scales_and_advance(iterator_scale); - } - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) - { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - - // #committed) - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else - { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); - smem_read_stage_idx = 0; - } - else - { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); - } - } - - // Load the scale needed for the next tile iteration. - warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); - // Update internal pointer to set of scales in shared memory. - warp_dequantizer_.add_pointer_offset(Shape::kN); - } - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h deleted file mode 100644 index 83efdc5cb0..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h +++ /dev/null @@ -1,647 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. 
-template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Iterators over scales in global memory - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Layout of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear> -class DqMmaMultistage> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - // - // Dependent types - // - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. 
- struct Detail - { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA - = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB - = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -private: - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - ///< Group size for quantization. 
Not used by this main loop since it assumes per-column - int const group_size, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance( - IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) - { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) - { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, 
iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) - { - - // - // Prologue - // - - TransformBAfterLDS lds_converter; - - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group as - // the first load of A. - FragmentScale tb_frag_scales; - tb_frag_scales.clear(); - iterator_scale.load(tb_frag_scales); - this->smem_iterator_scale_.store(tb_frag_scales); - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) - { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. 
- // - - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) - { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - warp_dequantizer_.load(warp_frag_scales); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) - { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. 
- - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; - - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); - run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, - warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) - { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) - { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - - // #committed) - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else - { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - smem_read_stage_idx = 0; - } - else - { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } - } - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h deleted file mode 100644 index bd3e38971b..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +++ /dev/null @@ -1,106 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -#include "cutlass_extensions/gemm_configs.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Data type for the scales - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Converter for B matrix applied immediately after the LDG (before STS) - typename TransformBAfterLDG_, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_, - /// Used for partial specialization - typename Enable = void> -class DqMmaPipelined; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h" -#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h" diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h deleted file mode 100644 index 50bdd0d85b..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h +++ /dev/null @@ -1,486 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -#include "cutlass_extensions/gemm_configs.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
-template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Iterators over scales in global memory - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Layout of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Converter for B matrix applied immediately after the LDG (before STS) - typename TransformBAfterLDG_, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_> -class DqMmaPipelined> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - // - // Dependent types - // - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); - - static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); - static_assert(Base::SharedStorage::ShapeScale::kColumn == 
Shape::kN, ""); - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using WarpFragmentScale = typename Dequantizer::FragmentScale; - using WarpFragmentZero = typename Dequantizer::FragmentZero; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -protected: - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int const group_size, ///< The group size for quantization - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), - shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) - { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_scales_and_advance(IteratorScale& iterator_scale) - { - using TransformScale = NumericArrayConverter; - - FragmentScale tb_frag_scales; - FragmentScale tb_frag_zeros; - tb_frag_scales.clear(); - tb_frag_zeros.clear(); - - TransformScale transformScale; - - using FragmentElement = typename FragmentScale::Element; - - auto gmem_scale_ptr = iterator_scale.get_scale(); - auto gmem_zero_ptr = iterator_scale.get_zero(); - - arch::global_load(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); - - if (gmem_zero_ptr != nullptr) - { - arch::global_load( - tb_frag_zeros, 
gmem_zero_ptr, iterator_scale.valid()); - } - - typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales); - typename TransformScale::result_type tb_frag_zeros_fp16; - if (gmem_zero_ptr != nullptr) - tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); - - auto frag_scale_ptr_fp16 = reinterpret_cast(&tb_frag_scales_fp16); - auto frag_zero_ptr_fp16 = reinterpret_cast(&tb_frag_zeros_fp16); - auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); - auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); - - if (iterator_scale.valid()) - { - auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); - arch::shared_store(smem_offset, frag_scale_ptr_fp16); - - if (gmem_zero_ptr != nullptr) - { - smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); - arch::shared_store(smem_offset, frag_zero_ptr_fp16); - } - } - - if (iterator_scale.group_size_ == 64) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if (iterator_scale.group_size_ == 128) - { - if constexpr (Shape::kK == 128) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if constexpr (Shape::kK == 64) - { - if (iterator_scale.row_groupsize64_ & 0x1) - { - iterator_scale.add_tile_offset({1, 0}); - } - } - else - { - static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); - } - } - - iterator_scale.row_groupsize64_++; - - this->smem_iterator_scale_.add_tile_offset({1, 0}); - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - IteratorScale iterator_scale, ///< iterator over scale operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile - - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; - - using TransformA - = NumericArrayConverter; - - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
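
The group-size bookkeeping in copy_scales_and_advance above decides how often the global scale/zero iterator moves to the next row of scales: with a quantization group of 64 it advances on every threadblock k-tile, while with a group of 128 it advances every tile only when the k-tile itself is 128 wide, and every second tile when the k-tile is 64 wide. A minimal host-side sketch of that decision, with hypothetical names (tiles_seen stands in for row_groupsize64_); this is an illustration, not the deleted implementation:

    #include <cassert>

    // Returns true when the scale row should advance after one threadblock k-tile.
    bool advance_scale_row(int group_size, int tile_k, int tiles_seen) {
        if (group_size == 64) {
            return true;                                 // one scale row per 64-wide k-tile
        }
        if (group_size == 128) {
            if (tile_k == 128) return true;              // the k-tile covers a whole group
            if (tile_k == 64)  return tiles_seen & 0x1;  // every second 64-wide k-tile
        }
        assert(false && "only group sizes 64 and 128 are handled here");
        return false;
    }

    int main() {
        // With group_size=128 and 64-wide k-tiles the row advances on tiles 1, 3, 5, ...
        for (int t = 0; t < 4; ++t) {
            bool adv = advance_scale_row(/*group_size=*/128, /*tile_k=*/64, /*tiles_seen=*/t);
            (void)adv;
        }
        return 0;
    }
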
- TransformA transformA; - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - - tb_frag_A.clear(); - tb_frag_B.clear(); - - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - copy_scales_and_advance(iterator_scale); - - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - WarpFragmentScale warp_frag_scales; - WarpFragmentZero warp_frag_zero; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - warp_dequantizer_.add_pointer_offset(Shape::kN); - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - iterator_scale.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) - { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) - { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); - - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
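
The mainloop above is a two-stage (kStages == 2) pipeline: global loads for the next k-tile land in one shared-memory buffer while the warps compute out of the other, and smem_write_stage_idx is flipped each k-group, rewinding whichever iterators just ran off the end of the circular buffer. A small sketch of that ping-pong indexing, with hypothetical stage variables; it only models the toggle, not the iterator rewinds:

    #include <cstdio>

    int main() {
        const int kStages = 2;      // double buffering, as asserted by DqMmaPipelined
        int write_stage = 1;        // matches smem_write_stage_idx after the prologue
        int read_stage  = 0;        // warps consume the stage written in the prologue
        (void)kStages;

        for (int k_tile = 0; k_tile < 6; ++k_tile) {
            // ... global loads for k_tile would land in buffer[write_stage],
            //     while the MMAs read from buffer[read_stage] ...
            std::printf("k_tile %d: write stage %d, read stage %d\n",
                        k_tile, write_stage, read_stage);
            write_stage ^= 1;       // same toggle as smem_write_stage_idx ^= 1
            read_stage  ^= 1;
        }
        return 0;
    }
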
- if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) - { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - copy_scales_and_advance(iterator_scale); - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - iterator_scale.clear_mask(gemm_k_iterations <= 2); - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero); - run_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } - - // Load the scales needed for the next tile iteration - warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); - // Update internal pointer to the set of scales in shared memory - warp_dequantizer_.add_pointer_offset(Shape::kN); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h deleted file mode 100644 index 316ea9f80a..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h +++ /dev/null @@ -1,399 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -#include "cutlass_extensions/gemm_configs.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Iterators over scales in global memory - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Layout of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Converter for B matrix applied immediately after the LDG (before STS) - typename TransformBAfterLDG_, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_> -class DqMmaPipelined> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - 
using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - // - // Dependent types - // - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -protected: - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation - ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this - ///< argument is not added, it does not affect compilation for sm>=80. 
- int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) - { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - IteratorScale iterator_scale, ///< iterator over scale operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile - - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; - - using TransformA - = NumericArrayConverter; - - using TransformScale = NumericArrayConverter; - - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
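
Both constructors map the linear warp_idx onto (m, n, k) coordinates inside the threadblock tile with the modulo/divide pattern shown above. A worked example of that arithmetic, assuming a hypothetical 2x2 warp grid with two k-partitions:

    #include <cstdio>

    int main() {
        const int kM = 2, kN = 2;                 // WarpCount::kM, WarpCount::kN (assumed)
        for (int warp_idx = 0; warp_idx < 8; ++warp_idx) {
            int warp_idx_mn = warp_idx % (kM * kN);
            int warp_idx_k  = warp_idx / (kM * kN);
            int warp_idx_m  = warp_idx_mn % kM;
            int warp_idx_n  = warp_idx_mn / kM;
            // e.g. warp_idx 5 -> mn=1, k=1, m=1, n=0
            std::printf("warp %d -> m=%d n=%d k=%d\n",
                        warp_idx, warp_idx_m, warp_idx_n, warp_idx_k);
        }
        return 0;
    }
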
- TransformA transformA; - TransformScale transformScale; - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - FragmentScale tb_frag_scales; - - using WarpFragmentScale = typename Dequantizer::FragmentScale; - WarpFragmentScale warp_frag_scales; - - tb_frag_A.clear(); - tb_frag_B.clear(); - tb_frag_scales.clear(); - - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - iterator_scale.load(tb_frag_scales); - - ++iterator_A; - ++iterator_B; - - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - __syncthreads(); - - warp_dequantizer_.load(warp_frag_scales); - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) - { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) - { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); - - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
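
In this per-column variant the scales are loaded once in the prologue, stored to shared memory once, and every k-iteration reuses the same warp_frag_scales; dequantization then reduces to a per-column multiply with no zero point. A scalar sketch of that step, with hypothetical sizes standing in for the Array fragments used above:

    #include <array>
    #include <cstdio>

    int main() {
        // One scale per column iteration of the warp-level MMA (per-column quantization).
        std::array<float, 4> scale = {0.5f, 1.0f, 2.0f, 4.0f};
        // Weights already converted from their packed integer form to float.
        std::array<std::array<float, 8>, 4> frag_B{};
        frag_B[2][0] = 3.0f;

        for (size_t n = 0; n < scale.size(); ++n)   // mma_n_iter loop
            for (float& w : frag_B[n])
                w *= scale[n];                      // dequantize: B_fp = scale[n] * B_q

        std::printf("%f\n", frag_B[2][0]);          // 6.0
        return 0;
    }
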
- if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) - { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - run_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h deleted file mode 100644 index 350b247de2..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ /dev/null @@ -1,107 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! 
\file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/warp/default_mma_tensor_op.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" - -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for m-by-n-by-kgroup -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements, - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp -{ - -private: - // Shape for computing the FP16s - using ComputeInstructionShape = InstructionShape_; - - // Chosen so we get K=16 for int8 and K=32 for int4. - static constexpr int LoadInstructionK = 128 / sizeof_bits::value; - - // Shape for loading the narrow data type from shared memory - using LoadInstructionShape = GemmShape; - -public: - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma, - cutlass::MatrixShape<1, 1>>; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h deleted file mode 100644 index 7c5088894b..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ /dev/null @@ -1,306 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" -#include "cutlass/arch/mma_sm89.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Instruction shape to override shared memory iterators with - typename SharedMemoryInstructionShape_, - /// Number of partitions along K dimension - int PartitionsK_ = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. 
- bool AccumulatorsInRowMajor = false, - /// Used for partial specialization - typename Enable = bool> -class MmaTensorOpComputeBWithF16 -{ -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 80) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 89), - "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); - - static_assert(platform::is_same::value - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80) - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 89), - "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada"); - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - - static_assert( - SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); - static_assert( - SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); - - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; - - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - -public: - /// Iterates over the A operand in memory - using IteratorA - = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, - MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = Array; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, - LayoutB, MatrixShape, Policy::OpDelta::kRow, - kThreadCount, kPartitionsK>; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using 
TransformedFragmentB = Array; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; - -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, - int const warp_tileB_k_offset) const - { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " - "B"); - - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } - } -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { - - int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } - } -#else - assert(0); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h deleted file mode 100644 index 1d5cd5d898..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ /dev/null @@ -1,463 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
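
The MMA loops deleted above visit the (m, n) tile grid in a serpentine order so that consecutive instructions reuse one operand register (Ra on SM80 and newer, Rb on older architectures). A minimal sketch of the SM80-style column serpentine, with assumed iteration counts:

    #include <cstdio>

    int main() {
        const int kRow = 2, kColumn = 4;   // MmaIterations (assumed values)
        for (int m = 0; m < kRow; ++m) {
            for (int n = 0; n < kColumn; ++n) {
                // Odd rows walk the columns backwards, so the A operand for row m
                // stays hot across consecutive mma issues.
                int n_serpentine = (m % 2) ? (kColumn - 1 - n) : n;
                std::printf("mma(m=%d, n=%d)\n", m, n_serpentine);
            }
        }
        return 0;
    }
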
-*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" - -#include "cutlass/functional.h" -#include "cutlass/platform/platform.h" - -#include "cutlass_extensions/weight_only_quant_op.h" -#include "tensorrt_llm/common/cudaBf16Wrapper.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Matrix multiply operator - typename MmaOperator_, - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand, - /// Data type of Scale elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Number of threads participating in one matrix operation - int Threads, - /// - WeightOnlyQuantOp QuantOp_, - /// - typename Enable = void> -class MmaTensorOpDequantizer; - -//////////////////////////////////////////////////////////////////////////////// -// Bfloat specialization for Ampere -template < - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_, - /// - WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer= 80 - && platform::is_same::value>::type> -{ - -public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. 
- static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementScale = bfloat16_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentScale = Array; - using FragmentZero = Array; - - /// Warp mma shape - using Shape = Shape_; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_scale_ = smem_scales.data() + thread_offset; - if constexpr (hasZero(QuantOp)) - { - pointer_zero_ = smem_zeros.data() + thread_offset; - } - } - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) - { - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); - __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); - } - } -#else - // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should - // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid - // numerous conversion instructions in GEMM main loop. 
- arch::device_breakpoint(); -#endif - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag, FragmentScale& zero_frag) - { - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - } - - CUTLASS_DEVICE - void dequantize( - FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) - { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); - __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); - - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); - __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); - __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); - } - } - } -#else - // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should - // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid - // numerous conversion instructions in GEMM main loop. - arch::device_breakpoint(); -#endif - } - - // Adds a pointer offset in units of elements. 
- CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_scale_ += offset; - pointer_zero_ += offset; - } - -private: - ElementScale const* pointer_scale_; - ElementScale const* pointer_zero_; -}; - -//////////////////////////////////////////////////////////////////////////////// - -// Specialization for Turing & Ampere -template < - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_, - /// - WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer= 75 - && platform::is_same::value>::type> -{ - -public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. - static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementScale = half_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentScale = Array; - using FragmentZero = Array; - - /// Warp mma shape - using Shape = Shape_; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_scale_ = smem_scales.data() + thread_offset; - if constexpr (hasZero(QuantOp)) - { - pointer_zero_ = smem_zeros.data() + thread_offset; - } - } - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) - { - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB - = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - multiplies mul_op; - - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag, FragmentScale& zero_frag) - { - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < 
MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - } - - CUTLASS_DEVICE - void dequantize( - FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) - { - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB - = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - multiplies mul_op; - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - - if constexpr (hasZero(QuantOp)) - { - plus plus_op; - - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] - = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } - } - } - - // Adds a pointer offset in units of elements. - CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_scale_ += offset; - pointer_zero_ += offset; - } - -private: - ElementScale const* pointer_scale_; - ElementScale const* pointer_zero_; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h deleted file mode 100644 index 4acef2d180..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace cutlass_extensions -{ -// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape -// in the kernel layout details when doing weight only quantization. 
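
// A host-side sketch of the scale indexing set up by the MmaTensorOpDequantizer
// constructor and load() above: a lane starts at column warp_idx_n * Shape::kN +
// lane_idx / 4 and then strides by InstructionShape::kN per MMA iteration. The
// concrete shape values below are illustrative, not from a real instantiation.
#include <cstdio>

int main() {
    const int ShapeN = 64, InstructionShapeN = 8;           // warp tile N / instruction tile N
    const int MmaIterationsN = ShapeN / InstructionShapeN;  // assumed iteration count for this sketch
    const int warp_idx_n = 1, lane_idx = 13;

    const int warp_offset = warp_idx_n * ShapeN;            // start of this warp's N range
    const int quad = lane_idx / 4;                          // lane's column within the instruction tile
    const int thread_offset = warp_offset + quad;           // first scale column this lane reads

    for (int mma_n_iter = 0; mma_n_iter < MmaIterationsN; ++mma_n_iter)
        std::printf("lane %d loads scale column %d\n", lane_idx, thread_offset + mma_n_iter * InstructionShapeN);
    return 0;
}
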
-enum class CutlassTileConfig -{ - // Signals that we should run heuristics do choose a config - Undefined, - - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - // SiMT config - CtaShape128x128x8_WarpShape64x64x8, - - // TensorCore configs CTA_N = 128, CTA_K = 64 - // Warp configs for M=16 - CtaShape16x128x64_WarpShape16x32x64, - // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, - - // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, - CtaShape64x64x128_WarpShape32x64x64, - CtaShape64x128x64_WarpShape64x32x64, - - // Warp configs for M=128 - CtaShape128x64x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape64x64x64, - CtaShape128x128x64_WarpShape128x32x64, - CtaShape128x256x64_WarpShape64x64x64, - - // Warp configs for M=256 - CtaShape256x128x64_WarpShape64x64x64, - - // TensorCore config CTA_N = 64, CTA_K = 128 - CtaShape128x64x128_WarpShape64x32x128, - - // TensorCore config CTA_N = 256, CTA_K = 64 - CtaShape16x256x64_WarpShape16x64x64, - - // TensorCore config CTA_N = 256, CTA_K = 128 - CtaShape16x256x128_WarpShape16x64x128 - -}; - -enum class SplitKStyle -{ - NO_SPLIT_K, - SPLIT_K_SERIAL, - STREAM_K, // Sm80+ - // SPLIT_K_PARALLEL // Not supported yet -}; - -enum class CutlassTileConfigSM90 -{ - // Signals that we should run heuristics do choose a config - Undefined, - - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - // CTA configs for M=64 - CtaShape64x16x128B, - CtaShape64x32x128B, - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, - - // CTA configs for M=128 - CtaShape128x16x128B, - CtaShape128x32x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, - - // CTA configs for M=128 - CtaShape256x128x128B, -}; - -enum class MainloopScheduleType -{ - AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this - // defaults to the "legacy" main loop schedule. -}; - -enum class EpilogueScheduleType -{ - AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For - // architectures older than hopper, the epilogue is always performed by the same thread block as the main loop. 
-}; - -enum class ClusterShape -{ - ClusterShape_1x1x1, - ClusterShape_2x1x1, - ClusterShape_1x2x1, - ClusterShape_2x2x1, - ClusterShape_1x8x1, - ClusterShape_8x1x1 -}; - -struct CutlassGemmConfig -{ - enum CandidateConfigTypeParam : int - { - NONE = 0, - WEIGHT_ONLY = 1u << 0, - SIMT_ONLY = 1u << 1, - INT8_ONLY = 1u << 2, - HOPPER = 1u << 3, - GROUPED_GEMM = 1u << 4, - FP8_ONLY = 1u << 5, - }; - - CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; - SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; - int split_k_factor = -1; - int stages = -1; - - // config options for sm90 - CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; - MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; - EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; - ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; - bool is_sm90 = false; - - CutlassGemmConfig() {} - - CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) - : tile_config(tile_config) - , split_k_style(split_k_style) - , split_k_factor(split_k_factor) - , stages(stages) - , is_sm90(false) - { - } - - CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, - EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) - : tile_config_sm90(tile_config_sm90) - , mainloop_schedule(mainloop_schedule) - , epilogue_schedule(epilogue_schedule) - , cluster_shape(cluster_shape) - , is_sm90(true) - { - } - - std::string toString() const - { - std::stringstream tactic; - tactic << "Cutlass GEMM Tactic"; - if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic) - { - assert(is_sm90 && "Invalid cutlass GEMM config"); - tactic << "\n\tstyle=TMA" - << "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape - << "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule; - } - else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) - { - assert(!is_sm90 && "Invalid cutlass GEMM config"); - tactic << "\n\tstyle=compatible" - << "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages - << "\n\tsplit k: " << (int) split_k_factor; - } - else - { - tactic << "\n\tundefined"; - } - tactic << "\n"; - return tactic.str(); - } -}; - -inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) -{ - // clang-format off - if (config.is_sm90) - { - out << "tile_config_sm90_enum: " << int(config.tile_config_sm90) - << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) - << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) - << ", cluster_shape_enum: " << int(config.cluster_shape); - } - else - { - out << "tile_config_enum: " << int(config.tile_config) - << ", split_k_style_enum: " << int(config.split_k_style) - << ", split_k_factor: " << config.split_k_factor - << ", stages: " << config.stages; - } - // clang-format on - return out; -} - -} // namespace cutlass_extensions -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h deleted file mode 100644 index 44ba79680e..0000000000 --- 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +++ /dev/null @@ -1,447 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register -*/ - -#pragma once - -#include "cutlass/arch/arch.h" -#include "cutlass/array.h" -#include "cutlass/half.h" -#include "cutlass/numeric_types.h" - -namespace cutlass -{ - -// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low -// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally -// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. -// This converter will uninterleave the data and subtract the bias while converting to the result type. 
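
// A host-side sketch of the register layout described in the comment above for the
// 4 x int8 case: logical elements {e0,e1,e2,e3} are stored as bytes {e0,e2,e1,e3}
// (even elements in the low half, odd elements in the high half), each already biased
// by 2^(b-1). Logical element k then sits at stored byte (k & 1) * 2 + (k >> 1).
#include <cassert>
#include <cstdint>

int main() {
    const uint8_t logical[4] = {10, 20, 30, 40};   // e0..e3, after the +128 bias was applied
    uint32_t reg = 0;
    reg |= uint32_t(logical[0]) << 0;              // byte 0: e0
    reg |= uint32_t(logical[2]) << 8;              // byte 1: e2
    reg |= uint32_t(logical[1]) << 16;             // byte 2: e1
    reg |= uint32_t(logical[3]) << 24;             // byte 3: e3
    for (int k = 0; k < 4; ++k) {
        int stored_byte = (k & 1) * 2 + (k >> 1);  // uninterleave: where element k lives
        assert(uint8_t(reg >> (8 * stored_byte)) == logical[k]);
    }
    return 0;
}
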
-template -struct FastInterleavedAndBiasedNumericArrayConverter -{ -}; - -template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); - - // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - - uint32_t* bf16_result_ptr = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - static constexpr uint32_t fp32_base = 0x4B000000; - float fp32_intermediates[4]; - - // Construct FP32s, bfloat does not have enough mantissa for IADD trick - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); - - // Subtract out fp32_base + 128 to make the unsigned integer signed. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 4; ++ii) - { - fp32_intermediates[ii] -= 8388736.f; - } - - // Truncate the fp32 representation and pack up as bfloat16s. 
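
// A host-side check of the fp16 magic-number trick used by the uint8_t -> half_t
// converter above: the biased byte u (= original int8 + 128) is placed in the mantissa
// of 0x6400 (fp16 1024.0), which reads as 1024 + u; subtracting 1152 (0x6480) recovers
// u - 128. The tiny decoder below handles only the normal fp16 range used here.
#include <cassert>
#include <cmath>
#include <cstdint>

static float half_bits_to_float(uint16_t h) {       // normal numbers only; enough for 0x64xx
    int sign = (h >> 15) & 0x1, exp = (h >> 10) & 0x1F, man = h & 0x3FF;
    float v = std::ldexp(1.0f + man / 1024.0f, exp - 15);
    return sign ? -v : v;
}

int main() {
    for (int u = 0; u < 256; ++u) {                  // u = stored (biased) byte
        uint16_t h = uint16_t(0x6400 | u);           // what prmt with start byte 0x64 builds
        float recovered = half_bits_to_float(h) - 1152.0f;   // sub.f16x2 with 0x6480
        assert(recovered == float(u - 128));         // original signed int8 value
    }
    return 0;
}
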
- CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 2; ++ii) - { - bf16_result_ptr[ii] - = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); - } -#else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - result.clear(); // Suppress compiler warning - arch::device_breakpoint(); -#endif - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. - - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. 
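
// A host-side check of the bf16 path above: the biased byte u is placed in the low
// mantissa byte of 0x4B000000 (fp32 8388608.0 = 2^23), which reads as 8388608 + u;
// subtracting 8388736 (= 2^23 + 128) recovers u - 128, and the top 16 bits of that
// float are the bf16 value the kernel packs with __byte_perm.
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
    for (int u = 0; u < 256; ++u) {
        uint32_t bits = 0x4B000000u | uint32_t(u);   // what __byte_perm(i8s, fp32_base, 0x765x) builds
        float f;
        std::memcpy(&f, &bits, sizeof(f));           // bit-cast: f == 8388608.0f + u
        f -= 8388736.f;                              // the loop above subtracts fp32_base + 128
        assert(f == float(u - 128));
        uint32_t out_bits;
        std::memcpy(&out_bits, &f, sizeof(out_bits));
        uint16_t bf16 = uint16_t(out_bits >> 16);    // truncation to bf16, as in __byte_perm(..., 0x7632)
        (void)bf16;                                  // exact here: |u - 128| <= 128 fits in 8 mantissa bits
    }
    return 0;
}
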
- const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. - - // This is the half2 {1032, 1032} represented as an integer. - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - static constexpr uint32_t NEG_72 = 0xd480d480; - - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - - uint32_t* h = reinterpret_cast(&result); - uint32_t const source_i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. 
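
// A host-side check of the fp16 constants used for int4 above. A nibble v (the
// original int4 biased by +8) ORed into 0x6400 reads as 1024 + v when it sits in
// bits 0..3, and as 1024 + 16*v when it sits in bits 4..7. The magic numbers
// 0x6408 (1032.0), 0x2c00 (1/16) and 0xd480 (-72.0) then undo placement and bias.
#include <cassert>

int main() {
    for (int s = -8; s <= 7; ++s) {
        int v = s + 8;                                                    // biased nibble as stored
        float low  = (1024.0f + v) - 1032.0f;                             // sub.f16x2 with FP16_TOP_MAGIC_NUM
        float high = (1024.0f + 16.0f * v) * (1.0f / 16.0f) + (-72.0f);   // fma.rn.f16x2 with ONE_SIXTEENTH, NEG_72
        assert(low == float(s) && high == float(s));
    }
    return 0;
}
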
- static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; - - // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. - // No shift needed for first item. - uint32_t i4s = source_i4s; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - CUTLASS_PRAGMA_UNROLL - for (int ii = 1; ii < result_type::kElements / 2; ++ii) - { - i4s >>= sizeof_bits::value; - // (i4s & 0x000f000f) | 0x43004300 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[ii]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - } - - // This is the BF16 {-136, -136} represented as an integer. - static constexpr uint32_t BF16_BIAS = 0xC308C308; - static constexpr uint32_t BF16_ONE = 0x3F803F80; - - // Finally, we construct the output numbers. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < result_type::kElements / 2; ++ii) - { - // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction - asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); - } -#else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - arch::device_breakpoint(); - result.clear(); // Suppress compiler warning. -#endif - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h deleted file mode 100644 index 5a0cd29570..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h +++ /dev/null @@ -1,66 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
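
// The analogous check for the bf16 int4 path above: 0x4300 is bf16 128.0 and a nibble
// in its low mantissa bits reads as 128 + v, so the fma with BF16_ONE (1.0, 0x3f80) and
// BF16_BIAS (-136.0, 0xc308) recovers the signed value v - 8.
#include <cassert>

int main() {
    for (int s = -8; s <= 7; ++s) {
        int v = s + 8;
        float h = (128.0f + v) * 1.0f + (-136.0f);   // fma.rn.bf16x2 with BF16_ONE, BF16_BIAS
        assert(h == float(s));
    }
    return 0;
}
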
- * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines new layouts needed for MoE -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/pitch_linear_coord.h" - -namespace cutlass -{ -namespace layout -{ - -template -struct ColumnMajorTileInterleave -{ - static constexpr int kRowsPerTile = RowsPerTile; - static constexpr int kColumnsInterleaved = ColumnsInterleaved; -}; - -template -struct IsColumnMajorTileInterleave -{ - static constexpr bool value = false; -}; - -template -struct IsColumnMajorTileInterleave> -{ - static constexpr bool value = true; -}; - -} // namespace layout -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h deleted file mode 100644 index 6095925e37..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h +++ /dev/null @@ -1,250 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. 
Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM - quantization. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" - -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace transform -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -template -class FineGrainedScaleZeroIterator; - -template -class FineGrainedScaleZeroIterator -{ -public: - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = 0; - static int const kAlignment = Alignment_; - - static int const kAccessesPerVector = 1; - - /// Row index of scales corresponding to the groupsize of 64 - int row_groupsize64_; - int group_size_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - using Pointer = Element*; - using NonConstPointer = typename platform::remove_const::type*; - - using AccessType = AlignedArray; - - using Fragment = cutlass::Array; - - // For compatibility with existing iterator interface - struct Params - { - LongIndex stride_ = 0; - - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - LongIndex inc_advance_ = 0; - - // Default ctor - CUTLASS_HOST_DEVICE - Params() {} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : stride_(layout.stride(0)) - { - inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; - } - }; - -private: - /// 
Internal pointer type permits fast address arithmetic - using BytePointer = char*; - -private: - // - // Data members - // - - /// Parameters object with precomputed internal state - Params const params_; - - /// Internal pointer to first access of tile - BytePointer pointer_scale_; - BytePointer pointer_zero_; - - bool is_valid_ = false; - -public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_DEVICE - FineGrainedScaleZeroIterator( - ///< Precomputed parameters object - Params const& params, - ///< Pointer to start of scale tensor - Pointer pointer_scale, - ///< Pointer to start of zero tensor - Pointer pointer_zero, - ///< Extent of the scale and bias - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const& threadblock_offset, - ///< Group size - int group_size) - : params_(params) - , pointer_scale_(reinterpret_cast(const_cast(pointer_scale))) - , pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) - { - row_groupsize64_ = threadblock_offset.row(); - group_size_ = group_size; - - const LongIndex tb_row_byte_offset - = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; - const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; - pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); - - if (pointer_zero_ != nullptr) - { - pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); - } - - static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; - - int const thread_row = thread_id / THREADS_PER_ROW; - int const thread_col = thread_id % THREADS_PER_ROW; - - const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; - const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; - pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); - if (pointer_zero_ != nullptr) - { - pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); - } - - // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on - // a given iteration. The same threads will be responsible for issues reads since the number of scales - // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ - // outside of the constructor. 
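
// A host-side sketch of the byte-offset arithmetic in the FineGrainedScaleZeroIterator
// constructor above, for 16-bit scales. All concrete numbers (Shape, stride, group size,
// thread id) are illustrative, not taken from a real instantiation.
#include <cstdio>

int main() {
    const int kShapeRow = 1, kShapeColumn = 128, kAlignment = 8, elem_bytes = 2;  // e.g. half_t scales
    const int group_size = 128, stride = 4096;      // row stride of the scale tensor, in elements
    const int tb_row = 256, tb_col = 512;           // threadblock_offset
    const int thread_id = 5;

    const long tb_row_bytes    = long(tb_row) / (group_size / 64) * stride * elem_bytes;
    const long tb_col_bytes    = long(tb_col) * elem_bytes;
    const int  threads_per_row = kShapeColumn / kAlignment;
    const int  thread_row      = thread_id / threads_per_row;
    const int  thread_col      = thread_id % threads_per_row;
    const long thread_bytes    = long(thread_row) * stride * elem_bytes
                               + long(thread_col) * kAlignment * elem_bytes;

    std::printf("pointer_scale_ byte offset = %ld\n", tb_row_bytes + tb_col_bytes + thread_bytes);
    std::printf("is_valid_ needs thread_row (%d) < Shape::kRow (%d) and the column in bounds\n",
                thread_row, kShapeRow);
    return 0;
}
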
- int const global_row = threadblock_offset.row() + thread_row; - int const global_col = threadblock_offset.column() + thread_col * kAlignment; - - bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; - bool const col_in_bounds = global_col < extent.column(); - - is_valid_ = row_in_bounds && col_in_bounds; - } - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object - Pointer pointer_scale, ///< Pointer to start of scale tensor - Pointer pointer_zero, ///< Pointer to start of zero tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - int group_size) - : FineGrainedScaleZeroIterator( - params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) - { - } - - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& tile_offset) - { - const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; - const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; - pointer_scale_ += row_byte_offset + col_byte_offset; - if (pointer_zero_ != nullptr) - { - pointer_zero_ += row_byte_offset + col_byte_offset; - } - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) - { - is_valid_ &= (!enable); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() const - { - return is_valid_; - } - - /// Returns a scale pointer - CUTLASS_HOST_DEVICE - AccessType* get_scale() const - { - return reinterpret_cast(pointer_scale_); - } - - /// Returns a zero pointer - CUTLASS_HOST_DEVICE - AccessType* get_zero() const - { - return reinterpret_cast(pointer_zero_); - } -}; - -} // namespace threadblock -} // namespace transform -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp deleted file mode 100644 index b430380b01..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp +++ /dev/null @@ -1,181 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cute/layout.hpp" -#include "cute/tensor.hpp" -#include "cute/util/print.hpp" - -using namespace cute; - -/// Function object that applies an index to its argument -template -struct IndexedGather -{ - CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {}) - : indices_(indices) - { - } - - template - CUTE_HOST_DEVICE constexpr auto operator()(I i) const - { - return indices_[i]; - } - - CUTE_HOST_DEVICE friend void print(IndexedGather const& s) - { - cute::print("Indexed{"); - print(s.indices_); - print("}"); - } - - Iter indices_; -}; - -/// Custom stride object that applies a function followed by a stride -template -struct CustomStride -{ - CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride) - : func_(func) - , stride_(stride) - { - } - - template - CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) - { - return s.func_(i) * s.stride_; - } - - template - CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i) - { - return s.func_(i) * s.stride_; - } - - CUTE_HOST_DEVICE friend void print(CustomStride const& s) - { - cute::print("Custom{"); - print(s.func_); - cute::print(","); - print(s.stride_); - cute::print("}"); - } - - template - CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) - { - return CustomStride(s.func_, safe_div(s.stride_, div)); - } - - // Circumvent the requirement on make_layout that shape and stride are integral - template - CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride) - { - return Layout(shape, stride); - } - - Func func_; - Stride stride_; -}; - -template -CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func) -{ - // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride - auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; }); - constexpr int I = decltype(idx)::value; - return make_layout( - repeat_like(stride, _1{}), replace(stride, CustomStride{static_cast(func), get(stride)})); -} - -/// Helper function to optionally create a gather tensor -template -CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func) -{ - Layout matrix_layout = make_identity_layout(shape); - auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); - Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); - return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); -} - -namespace cute -{ - -template -CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride) -{ - if constexpr (is_tuple::value) - { - return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s, d); }); - } - else if 
constexpr (is_scaled_basis::value) - { - if constexpr (Stride::mode() == I) - { - return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); - } - else - { - return make_layout(shape, stride); - } - } - else - { - return upcast(shape, stride); - } - - CUTE_GCC_UNREACHABLE; -} - -template -CUTE_HOST_DEVICE constexpr auto upcast( - ComposedLayout, Offset, Layout> const& layout) -{ - // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset - auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; }); - constexpr int I = decltype(idx)::value; - - // Upcast the outer layout (works as expected) - auto outer = upcast(layout.layout_a()); - - // Upcast the accumulated offset along stride-1 mode - auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); - - // Upcast the inner layout's shape along stride-1 mode - auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); - - return composition(outer, offset, inner); -} - -} // namespace cute diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h deleted file mode 100644 index 64774428e9..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h +++ /dev/null @@ -1,58 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
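
// A plain-C++ sketch of what the IndexedGather + CustomStride composition deleted above
// (gather_tensor.hpp) achieves for a row-gathered, row-major matrix: logical element
// (i, j) is read from base[indices[i] * row_stride + j]. Names here are illustrative only.
#include <cstdio>
#include <vector>

int main() {
    const int cols = 3, row_stride = 3;
    std::vector<float> base = {  0,  1,  2,
                                10, 11, 12,
                                20, 21, 22,
                                30, 31, 32 };
    std::vector<int> indices = {2, 0, 3};   // which source rows the gathered view exposes

    for (int i = 0; i < (int)indices.size(); ++i)
        for (int j = 0; j < cols; ++j)
            std::printf("gathered(%d,%d) = %g\n", i, j, base[indices[i] * row_stride + j]);
    return 0;
}
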
-*/ - -#pragma once - -namespace cutlass -{ - -enum class WeightOnlyQuantOp -{ - UNDEFINED, - PER_COLUMN_SCALE_ONLY, - FINEGRAINED_SCALE_ONLY, - FINEGRAINED_SCALE_AND_ZEROS -}; - -constexpr bool isFinegrained(WeightOnlyQuantOp op) -{ - return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; -} - -constexpr bool hasZero(WeightOnlyQuantOp op) -{ - return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; -} - -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h deleted file mode 100644 index f4eed277c1..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -namespace tensorrt_llm::kernels::cutlass_kernels -{ -template -void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, - ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert, - int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, - int* kernel_occupancy); -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl deleted file mode 100644 index 126e761ec9..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" - -#include "cute/tensor.hpp" -#include "cutlass/cutlass.h" - -#include -#include -#include - -namespace tensorrt_llm::kernels::cutlass_kernels -{ -template -void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, - ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert, - int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, - int* kernel_occupancy) -{ - constexpr auto activation_type = fused_moe::EpilogueRouting(true); - using GemmType = fused_moe::Fused_Moe_Kernel_sm80; - - // make sure GPU has enough resources.. - if (kernel_occupancy != nullptr) - { - constexpr int smem_size = GemmType::kSmemSize; - - if (smem_size > (48 << 10)) - { - cudaFuncAttributes attr{}; - int device = 0; - int max_smem_per_block = 0; - tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); - tensorrt_llm::common::check_cuda_error( - cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); - tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global)); - if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) - { - // This should mean that - // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - // smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the - // heuristic to ignore this configuration. - *kernel_occupancy = 0; - return; - } - } - - int max_active_blocks = -1; - tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, fused_moe::run_global, GemmType::kThreadCount, smem_size)); - *kernel_occupancy = max_active_blocks; - return; - } - int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks()); - int const threadblock_count = multi_processor_count * occupancy; - TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel"); - using Arguments = typename GemmType::Arguments; - Arguments args{{const_cast(A), const_cast(B), const_cast(biases), - reinterpret_cast(C), total_tokens_including_expert, static_cast(gemm_n), - static_cast(gemm_k), num_experts, bias_is_broadcast}, - num_experts, threadblock_count}; - auto params = GemmType::to_underlying_arguments(args); - if (GemmType::kSmemSize >= (48 << 10)) - { - cudaError_t result = cudaFuncSetAttribute( - fused_moe::run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize); - TLLM_CHECK_WITH_INFO(result == cudaSuccess, - "Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel"); - } - dim3 grid(params.threadblock_count, 1, 1); - dim3 block(GemmType::kThreadCount); - fused_moe::run_global<<>>(params); - auto result = cudaGetLastError(); - TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result)); -} -} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h deleted file mode 100644 index 91527fadb6..0000000000 --- 
a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h +++ /dev/null @@ -1,37 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" -#include - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -// Keep in sync with the signature generated by generate_kernels.py -template -void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts, - int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size); - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl deleted file mode 100644 index cca60a9816..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl +++ /dev/null @@ -1,348 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
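
// The sm80 fused-MoE launcher deleted earlier in this hunk guards its launch with
// shared-memory and occupancy queries. A standalone CUDA sketch of those runtime checks;
// dummy_kernel, SMEM_BYTES and BLOCK_THREADS are placeholders, not from the source.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel() {
    extern __shared__ char smem[];
    smem[threadIdx.x] = 0;
}

int main() {
    constexpr int SMEM_BYTES = 64 * 1024;   // hypothetical dynamic smem requirement (> default 48 KiB)
    constexpr int BLOCK_THREADS = 128;

    int device = 0, max_smem_optin = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&max_smem_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    cudaFuncAttributes attr{};
    cudaFuncGetAttributes(&attr, dummy_kernel);
    if (SMEM_BYTES + static_cast<int>(attr.sharedSizeBytes) >= max_smem_optin) {
        std::printf("config does not fit in shared memory: report occupancy 0 and skip it\n");
        return 0;
    }

    // Opt in to more than 48 KiB of dynamic shared memory, as the launcher does for large tiles.
    cudaFuncSetAttribute(dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SMEM_BYTES);

    int max_active_blocks = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, dummy_kernel, BLOCK_THREADS, SMEM_BYTES);
    std::printf("max active blocks per SM with %d bytes of dynamic smem: %d\n", SMEM_BYTES, max_active_blocks);

    dummy_kernel<<<1, BLOCK_THREADS, SMEM_BYTES>>>();
    return static_cast<int>(cudaDeviceSynchronize());
}
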
- */ -#pragma once -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/tensor_ref.h" - -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" - -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion; - -// Hopper helper class for defining all the cutlass helper types -template -struct HopperGroupedGemmInfo -{ - using Arch = cutlass::arch::Sm90; - - // TODO Update once mixed input support is added - static_assert(cutlass::platform::is_same::value, - "CUTLASS does not currently have specialised SM90 support for quantized operations"); - -#ifdef ENABLE_FP8 - constexpr static bool IsFP8 - = cutlass::platform::is_same::value || cutlass::platform::is_same::value; -#else - constexpr static bool IsFP8 = false; -#endif - -#ifdef ENABLE_BF16 - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value || IsFP8, - "Specialized for bfloat16, half, float, fp8"); -#else - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || IsFP8, - "Specialized for half, float, fp8"); -#endif - - static_assert(cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - "Unexpected quantization type"); - - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. 
- using ElementType = typename TllmToCutlassTypeAdapter::type; - - using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter::type; - // For legacy reasons we convert unsigned 8-bit to signed - using CutlassWeightTypeMaybeUint8 - = std::conditional_t, cutlass::int4b_t, - CutlassWeightTypeMaybeUint4>; - using CutlassWeightType - = std::conditional_t, int8_t, CutlassWeightTypeMaybeUint8>; - - using ElementA = ElementType; - using ElementB = CutlassWeightType; - - using ElementD = typename TllmToCutlassTypeAdapter>::type; - using ElementFinalOutput = typename TllmToCutlassTypeAdapter::type; - - // using ElementC = std::conditional_t; - // using ElementCNoVoid = std::conditional_t; - using ElementC = void; - using ElementCNoVoid = ElementD; - - using ElementAccumulator = float; - - using ElementBias = ElementFinalOutput; - using ElementRouterScales = float; - - // A matrix configuration - this is transposed and swapped with B - using LayoutA = HopperGroupedGemmInput::LayoutA; - constexpr static int AlignmentA - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units - // of elements (up to 16 bytes) - - // B matrix configuration - this is transposed and swapped with A - using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand - constexpr static int AlignmentB - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units - // of elements (up to 16 bytes) - - // C matrix configuration - using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand - using StrideC = HopperGroupedGemmInput::StrideC; - // Note we use ElementType here deliberately, so we don't break when BIAS is disabled - constexpr static int AlignmentC - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units - // of elements (up to 16 bytes) - - // D matrix configuration - using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD; - using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD; - constexpr static int AlignmentD - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix - // in units of elements (up to 16 bytes) - - static_assert(cutlass::platform::is_same::value, - "Hopper Grouped GEMM specialisation doesn't support fused activation"); - - using EpilogueOp - = cutlass::epilogue::fusion::LinearCombination; - - // TODO Add mode for fused activation once CUTLASS adds support - // using EpilogueSchedule = cutlass::platform::conditional_t< - // cutlass::platform::is_same::value, - // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, - // cutlass::epilogue::?????????????????? 
/// <<<<<< what supports activations - // >; - using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; - - // Epilogue For Default Finalize - using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< // - Arch, cutlass::arch::OpClassTensorOp, // - TileShape, ClusterShape, // - cutlass::epilogue::collective::EpilogueTileAuto, // - ElementAccumulator, ElementAccumulator, // - ElementC, LayoutC*, AlignmentC, // - ElementD, LayoutD*, AlignmentD, // - EpilogueSchedule>::CollectiveOp; - - // Epilogue For Fused Finalize - using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< // - TileShape, // - ElementCNoVoid, StrideC*, // - ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, // - ElementAccumulator, // - ElementAccumulator, // - ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, // - ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales // - >::CollectiveOp; - - using CollectiveEpilogue - = std::conditional_t; - - using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>; - - using KernelSchedule - = std::conditional_t; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< // - Arch, cutlass::arch::OpClassTensorOp, // - CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here - ElementType, LayoutA*, AlignmentA, // - ElementAccumulator, // - TileShape, ClusterShape, // - StageCountAutoCarveout, KernelSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal; - - using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter; -}; - -// Hopper specialised version -template -void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts, - int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size) -{ -#ifdef COMPILE_HOPPER_TMA_GEMMS - using namespace cute; - if constexpr (!should_filter_sm90_gemm_problem_shape_v) - { - using GemmInfo - = HopperGroupedGemmInfo; - - using ElementAccumulator = typename GemmInfo::ElementAccumulator; - using ElementA = typename GemmInfo::ElementA; - using ElementB = typename GemmInfo::ElementB; - using ElementC = typename GemmInfo::ElementC; - using ElementCNoVoid = typename GemmInfo::ElementCNoVoid; - using ElementD = typename GemmInfo::ElementD; - - using CollectiveMainloop = typename GemmInfo::CollectiveMainloop; - using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue; - using GemmKernel = typename GemmInfo::GemmKernel; - using GemmGrouped = typename GemmInfo::GemmGrouped; - - if (kernel_occupancy != nullptr) - { - *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); - return; - } - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - hw_info.sm_count = multi_processor_count; - - GemmGrouped gemm; - - if (workspace_size != nullptr) - { - // Make a mock problem shape with just the minimal information actually required to get the workspace size - // This makes some assumptions about CUTLASS's implementation which is suboptimal. We have a check later to - // catch future cutlass updates causing silent breakages, but that is not fool proof. 
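// Illustrative sketch of the two-phase flow the launcher relies on here: build minimal
// arguments, ask the operator for its workspace size, allocate, then initialise and run with
// the full arguments. `FakeGroupedGemm` below is a stand-in type for illustration only, not
// the CUTLASS adapter used above.
#include <cstddef>
#include <cstdio>
#include <vector>

struct FakeGroupedGemm
{
    struct Arguments { int num_groups = 0; };
    static std::size_t get_workspace_size(Arguments const& a) { return 256u * static_cast<std::size_t>(a.num_groups); }
    bool initialize(Arguments const&, void* workspace) { return workspace != nullptr; }
    bool run() { return true; }
};

int main()
{
    FakeGroupedGemm gemm;
    FakeGroupedGemm::Arguments size_only_args{/*num_groups=*/8};    // minimal "mock" arguments for the size query
    std::vector<unsigned char> workspace(FakeGroupedGemm::get_workspace_size(size_only_args));
    bool const ok = gemm.initialize(size_only_args, workspace.data()) && gemm.run();
    std::printf("workspace=%zu bytes, ok=%d\n", workspace.size(), ok ? 1 : 0);
    return ok ? 0 : 1;
}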
- // The alternative is to wait until we have data and then dynamically allocate the workspace - typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr}; - - typename GemmGrouped::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info}; - *workspace_size = gemm.get_workspace_size(args); - return; - } - - using MainloopArguments = typename CollectiveMainloop::Arguments; - TLLM_CHECK(hopper_input.stride_a); - TLLM_CHECK(hopper_input.stride_b); - TLLM_CHECK(hopper_input.ptr_a); - TLLM_CHECK(hopper_input.ptr_b); - - MainloopArguments const mainloop_params = {reinterpret_cast(hopper_input.ptr_b), - hopper_input.stride_b, reinterpret_cast(hopper_input.ptr_a), hopper_input.stride_a}; - - typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{ - ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)}; - epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - // TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS - auto make_epi_args = [&]() - { - if constexpr (FUSION == EpilogueFusion::NONE) - { - auto epi_params = hopper_input.default_epilogue; - return EpilogueArguments{epilogue_scalars, reinterpret_cast(hopper_input.ptr_c), - hopper_input.stride_c, reinterpret_cast(epi_params.ptr_d), epi_params.stride_d}; - } - else if constexpr (FUSION == EpilogueFusion::FINALIZE) - { - // Parameters for fused finalize - auto epi_params = hopper_input.fused_finalize_epilogue; - return EpilogueArguments{ - epilogue_scalars, // Parameters to underlying epilogue - reinterpret_cast(hopper_input.ptr_c), hopper_input.stride_c, // C params - reinterpret_cast(epi_params.ptr_final_output), - epi_params.stride_final_output, // D (output) params - reinterpret_cast(epi_params.ptr_bias), - epi_params.stride_bias, // Bias params - epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales - epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales - epi_params.ptr_source_token_index, // Index of the source token to sum into - epi_params.num_rows_in_final_output // Number of tokens in the output buffer - }; - } - else - { - static_assert( - sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher"); - } - }; - EpilogueArguments const epilogue_params = make_epi_args(); - - typename GemmKernel::TileScheduler::Arguments scheduler_args{ - 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; - - typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info, - mainloop_params, epilogue_params, hw_info, scheduler_args}; - - size_t calculated_ws_size = gemm.get_workspace_size(args); - TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size, - "Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size); - - auto can_implement = gemm.can_implement(args); - TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, - "Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); - - auto init_status = gemm.initialize(args, hopper_input.gemm_workspace); - TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, - "Failed to initialize cutlass SM90 grouped gemm. 
Error: " - + std::string(cutlassGetStatusString(init_status))); - - auto run_status = gemm.run(stream); - TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, - "Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status))); - sync_check_cuda_error(); - } - else - { - TLLM_THROW("Configuration was disabled by FAST_BUILD"); - } - -#else // COMPILE_HOPPER_TMA_GEMMS - TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py."); -#endif // COMPILE_HOPPER_TMA_GEMMS -} - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu deleted file mode 100644 index 9862460dd6..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu +++ /dev/null @@ -1,131 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" -#include "cutlass/conv/convolution.h" -// Order matters here, packed_stride.hpp is missing cute and convolution includes -#include "cutlass/util/packed_stride.hpp" - -#include "tensorrt_llm/common/logger.h" - -namespace tensorrt_llm -{ -std::array HopperGroupedGemmInput::workspaceBuffers(int num_experts) -{ - size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts; - size_t stride_a_size = sizeof(StrideA) * num_experts; - size_t stride_b_size = sizeof(StrideB) * num_experts; - size_t stride_c_size = sizeof(StrideC) * num_experts; - size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts; - - size_t ptr_buf_size = sizeof(void*) * num_experts; - size_t scale_buf_size = sizeof(float*) * num_experts; - - return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size, - ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size}; -} - -size_t HopperGroupedGemmInput::workspaceSize(int num_experts) -{ - auto buffers = workspaceBuffers(num_experts); - return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size()); -} - -void HopperGroupedGemmInput::configureWorkspace( - int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size) -{ - auto buffers = workspaceBuffers(num_experts); - std::array pointers{}; - TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers"); - for (int i = 0; i < buffers.size(); i++) - { - pointers[i] = start_ptr; - start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]); - } - - shape_info.num_groups = num_experts; 
- shape_info.problem_shapes = reinterpret_cast(pointers[0]); - shape_info.host_problem_shapes = nullptr; - stride_a = reinterpret_cast(pointers[1]); - stride_b = reinterpret_cast(pointers[2]); - stride_c = reinterpret_cast(pointers[3]); - default_epilogue.stride_d = reinterpret_cast(pointers[4]); - - ptr_a = reinterpret_cast(pointers[5]); - ptr_b = reinterpret_cast(pointers[6]); - ptr_c = reinterpret_cast(pointers[7]); - default_epilogue.ptr_d = reinterpret_cast(pointers[8]); - - alpha_scale_ptr_array = reinterpret_cast(pointers[9]); - - this->gemm_workspace = reinterpret_cast(gemm_workspace); - this->gemm_workspace_size = gemm_workspace_size; -} - -void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens) -{ - fused_finalize_epilogue.ptr_final_output = final_output; - fused_finalize_epilogue.ptr_router_scales = router_scales; - fused_finalize_epilogue.ptr_bias = bias; - fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset; - fused_finalize_epilogue.ptr_source_token_index = source_token_index; - - fused_finalize_epilogue.stride_final_output - = cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{}, - transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1))); - fused_finalize_epilogue.stride_bias - = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size)); - fused_finalize_epilogue.stride_router_scales = {}; - - fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens; -} - -std::string HopperGroupedGemmInput::toString() const -{ - std::stringstream ss; - ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n"; - if (isValid()) - { - ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n"; - ss << "Epilogue Fusion: " << (int) fusion; - if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE) - { - ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output; - ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; - ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias; - ss << " with Stride: " << fused_finalize_epilogue.stride_bias; - ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales; - ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; - ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset; - ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index; - } - else - { - ss << ", Ptr D: " << default_epilogue.ptr_d; - } - ss << '\n'; - ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n"; - } - return ss.str(); -} -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h deleted file mode 100644 index 0616c06365..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h +++ /dev/null @@ -1,230 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include "tensorrt_llm/common/workspace.h" -#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" -#include -#include -#include -#include - -#include "cute/tensor.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/layout/layout.h" - -namespace tensorrt_llm -{ -template -constexpr auto transpose_stride(T const& t) -{ - return cute::prepend(cute::prepend(cute::take<2, cute::rank_v>(t), cute::get<0>(t)), cute::get<1>(t)); -} - -struct HopperGroupedGemmInput -{ - template - using TransposeStride = decltype(transpose_stride(T{})); - template - using TransposeLayoutTag = std::conditional_t, - cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; - - static_assert(std::is_same_v>); - static_assert(std::is_same_v>); - - // Layout for A and B is transposed and then swapped in the implementation - // This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM - using LayoutA = TransposeLayoutTag; // Layout type for A matrix operand - using LayoutB = TransposeLayoutTag; // Layout type for B matrix operand - using LayoutC = TransposeLayoutTag; // Layout type for C matrix operand - - using StrideA - = std::remove_pointer_t>; // Use B because they will be swapped - using StrideB - = std::remove_pointer_t>; // Use A because they will be swapped - using StrideC = std::remove_pointer_t>; - - template - constexpr static bool IsFP8_v = std::is_same_v || std::is_same_v; - - // Currently this should always just be T - template - using OutputTypeAdaptor_t = std::conditional_t, nv_bfloat16, T>; - - using ProblemShape = cutlass::gemm::GroupProblemShape>; - - ProblemShape shape_info{}; - StrideA* stride_a = nullptr; - StrideB* stride_b = nullptr; - - void const** ptr_a = nullptr; - void const** ptr_b = nullptr; - - // C is currently the same in both epilogues - StrideC* stride_c = nullptr; - void const** ptr_c = nullptr; - - struct DefaultEpilogue - { - using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand - using StrideD = std::remove_pointer_t>; - - StrideD* stride_d = nullptr; - void** ptr_d = nullptr; - }; - - struct FusedFinalizeEpilogue - { - using StrideFinalOutput = DefaultEpilogue::StrideD; - using StrideBias = TransposeStride>; - using StrideRouterScales = TransposeStride>; - - void* ptr_final_output = nullptr; - StrideFinalOutput stride_final_output{}; - - void const* ptr_bias = nullptr; - StrideBias stride_bias{}; - - float const* ptr_router_scales = nullptr; - StrideRouterScales stride_router_scales{}; - - int64_t const* ptr_expert_first_token_offset = nullptr; - int const* ptr_source_token_index = nullptr; - - size_t num_rows_in_final_output = 0; - }; - - DefaultEpilogue default_epilogue; - FusedFinalizeEpilogue fused_finalize_epilogue; - - enum class EpilogueFusion - { - NONE, - ACTIVATION, - GATED_ACTIVATION, - FINALIZE - }; - EpilogueFusion fusion = EpilogueFusion::NONE; - - float const** alpha_scale_ptr_array = nullptr; - - uint8_t* gemm_workspace = nullptr; - size_t gemm_workspace_size = 0; - 
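// The layout comment above ("B^T * A^T = (A * B)^T") can be checked with a tiny host-side
// reference: a row-major buffer read as column-major is exactly the transpose, so running a
// column-major GEMM with the operands swapped produces C already in row-major form.
// Standalone illustration in plain C++ (not part of the CUDA path):
#include <cstdio>

// Column-major reference GEMM: C(MxN) = A(MxK) * B(KxN), leading dimension = number of rows.
void gemm_colmajor(int M, int N, int K, float const* A, float const* B, float* C)
{
    for (int n = 0; n < N; ++n)
        for (int m = 0; m < M; ++m)
        {
            float acc = 0.f;
            for (int k = 0; k < K; ++k) acc += A[m + k * M] * B[k + n * K];
            C[m + n * M] = acc;
        }
}

int main()
{
    float const A[6] = {1, 2, 3, 4, 5, 6};     // row-major 2x3
    float const B[6] = {7, 8, 9, 10, 11, 12};  // row-major 3x2
    float C[4];                                // receives the row-major 2x2 result of A * B
    // Reading the row-major buffers as column-major yields B^T (2x3) and A^T (3x2); their
    // column-major product is (A*B)^T, whose storage read back as row-major is A*B itself.
    gemm_colmajor(/*M=*/2, /*N=*/2, /*K=*/3, /*A=*/B, /*B=*/A, C);
    std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);  // expect: 58 64 / 139 154
}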
- static std::array workspaceBuffers(int num_experts); - - static size_t workspaceSize(int num_experts); - - void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size); - - bool isValid() const - { - return stride_a != nullptr && ptr_a != nullptr; - } - - void setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens); - - std::string toString() const; -}; - -// Note update moe.py to match -enum class ActivationType -{ - Gelu = 0, - Relu, - Silu, - Swiglu, - Geglu, - Identity, - InvalidType -}; - -constexpr bool isGatedActivation(ActivationType activation_type) -{ - return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; -} - -template -class MoeGemmRunner -{ -public: - MoeGemmRunner(); - -#if defined(ENABLE_FP8) - static constexpr bool use_fp8 = std::is_same_v || std::is_same_v; -#else - static constexpr bool use_fp8 = false; -#endif - - void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, - ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, - HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, - cutlass_extensions::CutlassGemmConfig chosen_conf); - - void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C, - int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, - cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf); - - std::vector getConfigs() const; - static std::vector getConfigs(int sm); - static std::vector getHopperConfigs(int sm); - static std::vector getAmpereConfigs(int sm); - - [[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const; - [[nodiscard]] bool supportsHopperSpecialisation() const; - [[nodiscard]] bool isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const; - [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const; - - size_t getMaxWorkspaceSize(int num_experts) const; - - [[nodiscard]] int getSM() const; - -private: - template - void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, - ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, - HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array, - cudaStream_t stream, int* occupancy = nullptr); - - template - void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, - bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, - HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, - cutlass_extensions::CutlassGemmConfig 
chosen_conf); - -private: - int sm_{}; - int multi_processor_count_{}; - mutable int num_experts_ = 0; - mutable size_t gemm_workspace_size_ = 0; - size_t calcMaxWorkspaceSize(int num_experts) const; -}; - -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu deleted file mode 100644 index 3aa96502d3..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>; -#endif -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu deleted file mode 100644 index fbb5270455..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>; -#endif -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu deleted file mode 100644 index 78f1a93a6a..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>; -#endif -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu deleted file mode 100644 index 69c4b6a15a..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu deleted file mode 100644 index 4ffa5485f0..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu deleted file mode 100644 index 424b817b87..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu deleted file mode 100644 index f317023565..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu deleted file mode 100644 index c6b8fe7872..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_FP8 -template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>; -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; -#endif -// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>; -#endif -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h deleted file mode 100644 index 2a337e6ca4..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h +++ /dev/null @@ -1,823 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Ignore CUTLASS warnings about type punning -#ifdef __GNUC__ // Check if the compiler is GCC or Clang -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif - -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" - -#include "cute/tensor.hpp" - -#include "cutlass/cutlass.h" - -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/tensor_ref.h" - -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" - -#ifdef __GNUC__ // Restore GCC-specific diagnostics -#pragma GCC diagnostic pop -#endif - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/common/logger.h" - -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" - -#include "moe_gemm_kernels_template_sm90.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" -#include - -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace kernels::cutlass_kernels -{ - -// ============================= Variable batched Gemm things =========================== -template -void 
genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr) -{ -#if defined(ENABLE_FP8) - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for fp8, bfloat16, half, float"); -#elif defined(ENABLE_BF16) - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - "Specialized for bfloat16, half, float"); -#else - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); -#endif - - static_assert(cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - ""); - - static_assert(!cutlass::platform::is_same::value, - "Sm90 architecture should use specialised kernels"); - - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. - using ElementType = typename TllmToCutlassTypeAdapter::type; - using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter::type; - using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; - if (!use_fused_moe) - { - // We need separate config for each architecture since we will target different tensorcore instructions. For - // float, we do not target TCs. - using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; - - using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; - - typename EpilogueOp::Params epilogue_op( - ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); - -#if defined(ENABLE_FP8) - if constexpr ((std::is_same_v - || std::is_same_v) &&std::is_same_v) - { - TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array, - "weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 " - "Ada"); - epilogue_op.alpha_ptr_array = alpha_scale_ptr_array; - } -#endif - - // Finally, set up the kernel. 
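// The epilogue parameters constructed above follow CUTLASS's LinearCombination convention,
// D = alpha * accumulator + beta * source, with beta set to 0 when no bias is supplied so the
// source operand is ignored (on the FP8 Ada path, alpha is instead supplied per expert through
// alpha_scale_ptr_array). A scalar host-side sketch of that per-element contract:
#include <cstdio>

float linear_combination(float accumulator, float alpha, float beta, float source)
{
    return alpha * accumulator + beta * source;  // what the epilogue applies element-wise
}

int main()
{
    float const acc = 3.5f, bias = 0.25f;
    std::printf("with bias:    %g\n", linear_combination(acc, 1.f, 1.f, bias));  // 3.75
    std::printf("without bias: %g\n", linear_combination(acc, 1.f, 0.f, bias));  // 3.5
}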
- using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; - - using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; - - using GemmGrouped = cutlass::gemm::device::GemmGrouped; - - if (kernel_occupancy != nullptr) - { - *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); - return; - } - int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); - TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); - int const threadblock_count = multi_processor_count * occupancy; - - int const group_size = gemm_k; - typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, - reinterpret_cast(A), reinterpret_cast(B), - reinterpret_cast(weight_scales), - reinterpret_cast(biases), bias_is_broadcast, - reinterpret_cast(C), total_tokens_including_expert, gemm_n, gemm_k); - - GemmGrouped gemm; - - auto can_implement = gemm.can_implement(args); - TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, - "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); - - auto init_status = gemm.initialize(args); - TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, - "Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status))); - - auto run_status = gemm.run(stream); - TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, - "Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status))); - } - else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2 - && (std::is_same_v - || std::is_same_v) ) // use fused moe gemm - // kernel.. (only support - // fp16 or bf16) - { - sm80_generic_fused_moe_gemm_kernelLauncher(reinterpret_cast(A), - reinterpret_cast(B), reinterpret_cast(biases), - bias_is_broadcast, reinterpret_cast(C), total_tokens_including_expert, num_rows, gemm_n, - gemm_k, num_experts, multi_processor_count, stream, kernel_occupancy); - } -} - -} // namespace kernels::cutlass_kernels - -template -static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales, GemmOutputType const* biases, - bool bias_is_broadcast, GemmOutputType* C, int64_t const* total_tokens_including_expert, int64_t num_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, - int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, - int* occupancy = nullptr) -{ - - static_assert(!std::is_same_v, "Use TMA specialised functions for arch SM90"); -#if defined(ENABLE_FP8) - constexpr bool isFp8 = std::is_same_v || std::is_same_v; -#else - constexpr bool isFp8 = false; -#endif - - if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80) - && (!isFp8 || std::is_same_v) ) - { - kernels::cutlass_kernels::genericMoeGemmKernelLauncher(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - else - { - TLLM_THROW( - "Cutlass gemm. 
Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages); - } -} - -template -void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) -{ - switch (gemm_config.stages) - { - case 2: - dispatch(A, B, weight_scales, - biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case 3: - dispatch(A, B, weight_scales, - biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case 4: - dispatch(A, B, weight_scales, - biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break; - } -} - -// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. -// This overload is only enabled when T == WeightType. -template ::value -#if defined(ENABLE_FP8) - && !std::is_same::value && !std::is_same::value -#endif - && std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: - TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); - if constexpr (arch::kMinComputeCapability >= 75) - { - dispatchGemmConfig, - cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - break; - case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: - TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); - if constexpr (arch::kMinComputeCapability >= 75) - { - dispatchGemmConfig, - cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - break; - case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, 
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; - } -} - -// Tensorop GEMM overload -// Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve -// compile time -template ::value && !std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: - TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); - if constexpr (arch::kMinComputeCapability >= 75) - { - dispatchGemmConfig, - cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - break; - case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: - TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); - if constexpr (arch::kMinComputeCapability >= 75) - { - dispatchGemmConfig, - cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - break; - case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, 
weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break; - } -} - -// This overload will handle tensorop gemms. -// This overload is only enabled when T == WeightType and T == __nv_fp8_e4m3 or __nv_fp8_e5m2 -#if defined(ENABLE_FP8) -template ::value || std::is_same::value) - && std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: - dispatchGemmConfig, - cutlass::gemm::GemmShape<16, 64, 128>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, 
alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; - } -} -#endif - -// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. -template ::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Unsupported config for float MoE gemm."); break; - } -} - -template -std::vector -MoeGemmRunner::getConfigs() const -{ - return getConfigs(sm_); -} - -template -std::vector MoeGemmRunner::getConfigs( - int sm) -{ - std::vector candidate_configs = getHopperConfigs(sm); - std::vector ampere_configs = getAmpereConfigs(sm); - std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs)); - - return candidate_configs; -} - -template -std::vector -MoeGemmRunner::getAmpereConfigs(int sm) -{ - using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; - static constexpr auto weight_only_flag - = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; - static constexpr auto simt_only_flag - = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; - static constexpr auto fp8_only_flag = use_fp8 ? 
CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; - int const max_split_k = 1; - int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; - int const enable_hopper = CutlassGemmConfig::NONE; - - auto config_type_param = static_cast( - weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); - - if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) - { - return {}; - } - - std::vector ampere_configs - = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); - return ampere_configs; -} - -template -std::vector -MoeGemmRunner::getHopperConfigs(int sm) -{ - using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; - static constexpr auto weight_only_flag - = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; - static constexpr auto simt_only_flag - = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; - int const max_split_k = 1; - int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; - int const enable_hopper = CutlassGemmConfig::HOPPER; - static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; - auto config_type_param = static_cast( - weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); - - if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation()) - { - return {}; - } - - std::vector hopper_configs - = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); - return hopper_configs; -} - -template -bool MoeGemmRunner::isHopperSpecialised( - cutlass_extensions::CutlassGemmConfig gemm_config) const -{ - bool config_is_sm90 = gemm_config.is_sm90; - return supportsHopperSpecialisation() && config_is_sm90; -} - -template -bool MoeGemmRunner::supportsHopperSpecialisation() const -{ - return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation(); -} - -template -int MoeGemmRunner::getSM() const -{ - return this->sm_; -} - -// currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction -template -bool MoeGemmRunner::supportsFusedGatedActivation( - bool is_gated_activation, int gemm_n, int gemm_k) const -{ - constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true; - return is_gated_activation && std::is_same_v && !std::is_same_v && !use_fp8 - && (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; -} - -template -bool MoeGemmRunner::isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const -{ - return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_sm90; -} - -template -MoeGemmRunner::MoeGemmRunner() -{ - int device{-1}; - tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); - sm_ = tensorrt_llm::common::getSMVersion(); - tensorrt_llm::common::check_cuda_error( - cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); -} - -template -template -void MoeGemmRunner::dispatchToArch(T const* A, - WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, - void* C_void, int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, - bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, 
int* occupancy) -{ - static_assert(std::is_same_v, - "Separate Scale/Bias type is not supported. This is assumed to be the gemm output type"); - - // For now we always cast this to output type. - // In the future this will vary based on what fusions are applied for FP8 - auto* C = reinterpret_cast(C_void); - - TLLM_CHECK_WITH_INFO( - sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation"); - TLLM_CHECK_WITH_INFO( - sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture"); - - if (sm_ >= 75 && sm_ < 80) - { - dispatchMoeGemmToCutlass(A, B, weight_scales, - biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - else if (sm_ >= 80 && sm_ < 90) - { - if constexpr (use_fp8) - { -#if defined(ENABLE_FP8) - static_assert(!std::is_same_v && !std::is_same_v, - "FP8 GEMM Output not supported"); -#endif - - TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); - dispatchMoeGemmToCutlass(A, B, - weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, - occupancy); - } - else - { - dispatchMoeGemmToCutlass(A, B, - weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, - occupancy); - } - } - else if (sm_ >= 90) - { - if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) - { - - // We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens - // SM80 is faster. 
We check here to see which is selected - if (gemm_config.is_sm90) - { - TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr, - "Input biases and hopper input disagree if bias is enabled"); - TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config"); - - // Select the appropriate fusion function - auto select_function = [&]() - { - switch (hopper_input.fusion) - { - case HopperGroupedGemmInput::EpilogueFusion::FINALIZE: - return &dispatchMoeGemmSelectTileShapeSM90; - case HopperGroupedGemmInput::EpilogueFusion::NONE: - return &dispatchMoeGemmSelectTileShapeSM90; - case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION: - case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION: - default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion); - }; - }; - auto selected_func = select_function(); - selected_func( - hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr); - return; - } - - // Fallthrough to SM80 impl below - } - - // Do Ampere case instead - if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) - { - TLLM_CHECK_WITH_INFO(!hopper_input.isValid(), - "Non-specialised Hopper implementation is being rerouted to fallback implementation so input " - "information is not required"); - TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90, - "GEMM config is for SM90 configuration, but this configuration is not valid for Hppper"); - dispatchMoeGemmToCutlass(A, B, - weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, - occupancy); - } - else - { - TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels"); - } - } - else - { - TLLM_THROW("Arch unsupported for MoE GEMM"); - } -} - -template -size_t MoeGemmRunner::getMaxWorkspaceSize(int num_experts) const -{ - if (num_experts != num_experts_) - { - TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_); - num_experts_ = num_experts; - gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts); - } - return gemm_workspace_size_; -} - -template -size_t MoeGemmRunner::calcMaxWorkspaceSize(int num_experts) const -{ - if (!supportsHopperSpecialisation()) - { - return 0; - } - if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) - { - auto configs = getHopperConfigs(sm_); - size_t max_size = 0; - bool has_config = false; - for (auto conf : configs) - { -#define CALC_SIZE_FUSION(FUSION) \ - do \ - { \ - try \ - { \ - size_t size = calcMaxWorkspaceSizeSM90( \ - num_experts, conf, multi_processor_count_); \ - max_size = std::max(max_size, size); \ - has_config = true; \ - } \ - catch (tensorrt_llm::common::TllmException const& e) \ - { \ - TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size"); \ - } \ - } while (0) - - CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::NONE); - CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::FINALIZE); - } - TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size"); - return max_size; - } - else - { - TLLM_THROW("Attempting to calculate Hopper GEMM workspace size with unsupported weight combination"); - return 0; - } -} - -template -template -void MoeGemmRunner::runGemm(T const* A, WeightType const* B, - ScaleBiasType const* weight_scales, ScaleBiasType 
const* biases, bool bias_is_broadcast, void* C, - int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, - cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) -{ - dispatchToArch(A, B, weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array, - stream, nullptr); -} - -template -void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* B, - ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C, - int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) -{ - switch (activation_type) - { - case ActivationType::Relu: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::Gelu: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::Silu: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::Identity: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::Swiglu: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::Geglu: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; - default: TLLM_THROW("Invalid activation type."); break; - } -} - -template -void MoeGemmRunner::moeGemm(T const* A, WeightType const* B, - ScaleBiasType const* weight_scales, void* C, int64_t const* total_tokens_including_expert, - HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, - cutlass_extensions::CutlassGemmConfig chosen_conf) -{ - runGemm(A, B, weight_scales, nullptr, true, C, total_tokens_including_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, - chosen_conf); -} - -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h 
b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h deleted file mode 100644 index 3efb42f41e..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Ignore CUTLASS warnings about type punning -#ifdef __GNUC__ // Check if the compiler is GCC or Clang -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // __GNUC__ - -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/tensor_ref.h" - -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" - -#ifdef __GNUC__ // Check if the compiler is GCC or Clang -#pragma GCC diagnostic pop -#endif // __GNUC__ - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" - -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" - -#include -#include -#include -#include - -namespace tensorrt_llm -{ -using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion; - -template -void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count, - cudaStream_t stream, int* occupancy, size_t* workspace_size) -{ - static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation(), - "Invalid hopper configuration invoked, fallback to Sm80"); - - TLLM_CHECK_WITH_INFO( - workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information"); - - // auto func = hopper_input.ptr_c ? 
- // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper - // : - // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper; - // TODO(dastokes) Re-enable bias when CUTLASS supports it - auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher; - func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); -} - -/* - 1x1x1 cluster shape is are supported for any tile shape. - - 2x1x1 cluster shape is only supported for when the M tile is at least 128. - - 1x2x1 cluster shape is only supported when the N tile is at least 128. - - 2x2x1 cluster shape is only supported when both the M and N tiles are at least 128. - - We make the above restrictions are to improve compilation speed in TRT-LLM by pruning kernels - that may not be very useful in practice. - */ -template -constexpr bool are_tile_shapes_supported() -{ - using namespace cute; - [[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{}); - [[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{}); - constexpr int cga_m = get<0>(ClusterShape{}); - constexpr int cga_n = get<1>(ClusterShape{}); - - if constexpr (cga_m == _1{} && cga_n == _1{}) - { - return true; - } - else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{}) - { - return true; - } - else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{}) - { - return true; - } - else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{}) - { - return true; - } - else - { - return false; - } -} - -template -void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, - size_t* workspace_size) -{ - using namespace cute; - switch (gemm_config.cluster_shape) - { -#define SHAPE_CASE(M, N, K) \ - case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \ - { \ - using ClusterShape = Shape<_##M, _##N, _##K>; \ - if constexpr (are_tile_shapes_supported()) \ - { \ - dispatchMoeGemmSelectBiasSM90( \ - hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \ - break; \ - } \ - else \ - { \ - TLLM_THROW("Unsupported tile and cluster shape combination"); \ - } \ - } - - SHAPE_CASE(1, 1, 1) - SHAPE_CASE(1, 2, 1) - - SHAPE_CASE(2, 1, 1) - SHAPE_CASE(2, 2, 1) - -#undef SHAPE_CASE - default: TLLM_THROW("Unsupported config for MoE gemm."); - } -} // namespace tensorrt_llm - -template -void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, - size_t* workspace_size) -{ - using namespace cute; - - switch (gemm_config.tile_config_sm90) - { -#define SHAPE_CASE(M, N, K) \ - case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \ - { \ - constexpr int KtileBytes = K / sizeof(T); \ - using KTileDim = Int; \ - using TileShape = Shape<_##M, _##N, KTileDim>; \ - dispatchMoeGemmSelectClusterShapeSM90( \ - hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \ - break; \ - } - - SHAPE_CASE(128, 16, 128) - SHAPE_CASE(128, 32, 128) - SHAPE_CASE(128, 64, 128) - SHAPE_CASE(128, 128, 128) - SHAPE_CASE(128, 256, 128) - SHAPE_CASE(256, 128, 128) - -#undef SHAPE_CASE - case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break; - case 
cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Unsupported config for MoE gemm."); break; - } -} - -template -size_t calcMaxWorkspaceSizeSM90( - int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count) -{ - size_t count; - // Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat - dispatchMoeGemmSelectTileShapeSM90( - HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count); - return count; -} - -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h deleted file mode 100644 index 959d0ea088..0000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h +++ /dev/null @@ -1,44 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/arch/mma_sm90.h" -#include "cutlass_extensions/epilogue_helpers.h" - -namespace tensorrt_llm::kernels::cutlass_kernels -{ - -// Hopper arch -template -constexpr bool isValidHopperMOESpecialisation() -{ -#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) - return cutlass::platform::is_same::value - && cutlass::platform::is_same::value; -#else - return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled -#endif -} - -// Hopper arch -template -constexpr bool isValidAmpereMOESpecialisation() -{ - return true; // Default to true -} - -} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 5029914031..90c3cbc1d3 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -39,8 +39,6 @@ def _get_version(): cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" turbomind = root / "3rdparty" / "turbomind" -tensorrt_llm_parent = root / "3rdparty" -tensorrt_llm = root / "3rdparty" / "tensorrt_llm" include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", @@ -53,8 +51,6 @@ def _get_version(): "cublasLt", turbomind.resolve(), turbomind.resolve() / "src", - tensorrt_llm_parent.resolve(), - tensorrt_llm.resolve() / "cutlass_extensions" / "include", ] nvcc_flags = [ From 734daedd8fd9155fa4854b88d3c36cb90831e441 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Fri, 31 Jan 2025 01:04:04 -0800 Subject: [PATCH 10/52] [fix] Clamp logprob with dtype min to prevent `-inf` (#3224) --- python/sglang/srt/layers/sampler.py | 7 +++++-- .../penaltylib/test_srt_endpoint_with_penalizers.py | 7 +++---- 2 files changed, 8 
insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index b24bfc8dac..73ef13c35f 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -72,9 +72,11 @@ def forward( # NOTE: the top_p_renorm_prob from flashinfer has numerical problems, # https://github.com/flashinfer-ai/flashinfer/issues/708 # so we use the torch implementation. + + # clamp to avoid -inf logprobs = torch.log( top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ) + ).clamp(min=torch.finfo(probs.dtype).min) max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( @@ -109,9 +111,10 @@ def forward( sampling_info.need_min_p_sampling, ) if return_logprob: + # clamp to avoid -inf logprobs = torch.log( top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ) + ).clamp(min=torch.finfo(probs.dtype).min) else: raise ValueError( f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index 34565c9ff6..d9d77a9ae2 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -36,7 +36,7 @@ def tearDownClass(cls): def run_decode( self, return_logprob=True, - top_logprobs_num=3, + top_logprobs_num=5, return_text=True, n=1, **sampling_params, @@ -58,8 +58,7 @@ def run_decode( "logprob_start_len": 0, }, ) - print(json.dumps(response.json())) - print("=" * 100) + assert response.status_code == 200, "Request failed: " + response.text def test_default_values(self): self.run_decode() @@ -112,4 +111,4 @@ def test_repetition_penalty(self): if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=3) From c02e3139149d5f0c318a3b292d389a58f172b6ba Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Fri, 31 Jan 2025 19:56:02 +0800 Subject: [PATCH 11/52] Fix block wise fp8 torch compile (#3232) --- python/sglang/srt/layers/quantization/fp8.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index b0b5b8952a..f5a0005a28 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -290,6 +290,13 @@ def process_weights_after_loading(self, layer: Module) -> None: weight_scale, requires_grad=False ) layer.input_scale = None + else: + layer.weight = torch.nn.Parameter( + layer.weight.data, requires_grad=False + ) + layer.weight_scale_inv = torch.nn.Parameter( + layer.weight_scale_inv.data, requires_grad=False + ) return layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) # If checkpoint not serialized fp8, quantize the weights. 
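For reference, the clamp fix in PATCH 10 above replaces `-inf` log-probabilities (which appear whenever top-p renormalization zeroes out a probability) with the smallest finite value of the tensor's dtype. A minimal standalone sketch of the idea, outside SGLang's sampler and with made-up tensor values:

```python
import torch

# Hypothetical renormalized top-p probabilities; the zero entries give -inf under log.
probs = torch.tensor([0.6, 0.4, 0.0, 0.0], dtype=torch.float32)

# Clamp to the dtype's minimum finite value so downstream code never sees -inf.
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

print(logprobs)  # the zero entries become roughly -3.4e38 (float32 min) instead of -inf
```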
From b49d6d0fee3cf83d72ed658bd9f514bd87fcaa56 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 31 Jan 2025 20:31:38 +0800 Subject: [PATCH 12/52] support 12.5 CUDA runtime (#3231) --- .github/workflows/release-docker.yml | 6 ++++-- docker/Dockerfile | 6 ++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index 99ffd7c49c..d5669886d1 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -14,7 +14,7 @@ jobs: environment: 'prod' strategy: matrix: - cuda_version: ['11.8.0', '12.1.1', '12.4.1'] + cuda_version: ['11.8.0', '12.1.1', '12.4.1', '12.5.1'] build_type: ['all', 'srt'] steps: - name: Delete huge unnecessary tools folder @@ -39,6 +39,8 @@ jobs: cuda_tag="cu121" elif [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then cuda_tag="cu124" + elif [ "${{ matrix.cuda_version }}" = "12.5.1" ]; then + cuda_tag="cu125" else echo "Unsupported CUDA version" exit 1 @@ -58,7 +60,7 @@ jobs: docker build . -f docker/Dockerfile --build-arg CUDA_VERSION=${{ matrix.cuda_version }} --build-arg BUILD_TYPE=${{ matrix.build_type }} -t lmsysorg/sglang:${tag}${tag_suffix} --no-cache docker push lmsysorg/sglang:${tag}${tag_suffix} - if [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then + if [ "${{ matrix.cuda_version }}" = "12.5.1" ]; then docker tag lmsysorg/sglang:${tag}${tag_suffix} lmsysorg/sglang:latest${tag_suffix} docker push lmsysorg/sglang:latest${tag_suffix} fi diff --git a/docker/Dockerfile b/docker/Dockerfile index 1fe702d401..cec05825d0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -30,6 +30,8 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu121; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ @@ -42,6 +44,8 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ @@ -53,6 +57,8 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip --no-cache-dir install -e 
"python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ From cf0f7eafe69a7bb2aebf2c6c6ac361be8d4ccfe6 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 31 Jan 2025 20:35:55 +0800 Subject: [PATCH 13/52] chore: bump v0.4.2.post1 (#3233) --- docker/Dockerfile.rocm | 2 +- docs/developer/setup_github_runner.md | 4 ++-- docs/start/install.md | 10 +++++----- python/pyproject.toml | 2 +- python/sglang/version.py | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index f04254e54c..af9f9e24df 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,5 +1,5 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.4.2 -t v0.4.2-rocm620 -f Dockerfile.rocm . +# docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm620 -f Dockerfile.rocm . # default base image ARG BASE_IMAGE="rocmshared/vllm-rocm:20250114-tuned-elementwise-layernorm" diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index 779c413977..96c9cae015 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm620 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm620 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/start/install.md b/docs/start/install.md index 90964ac6b6..a5012d6fc7 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -14,7 +14,7 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/ ## Method 2: From source ``` # Use the last release branch -git clone -b v0.4.2 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -28,7 +28,7 @@ Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: ``` # Use the last release branch -git clone -b v0.4.2 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -54,7 +54,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.4.2 -t v0.4.2-rocm620 -f Dockerfile.rocm . 
+docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm620 -f Dockerfile.rocm . alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -63,11 +63,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.4.2-rocm620 \ + v0.4.2.post1-rocm620 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.4.2-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.2.post1-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/python/pyproject.toml b/python/pyproject.toml index 11c984f82d..8442ff5d2d 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.4.2" +version = "0.4.2.post1" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" diff --git a/python/sglang/version.py b/python/sglang/version.py index df12433297..d1b3e6d0ae 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.4.2" +__version__ = "0.4.2.post1" From 656f7fc1bc6bd128b227404dd2900d2b63073dcb Mon Sep 17 00:00:00 2001 From: Jhin <47354855+jhinpan@users.noreply.github.com> Date: Fri, 31 Jan 2025 10:30:40 -0600 Subject: [PATCH 14/52] Docs: Quick fix for Speculative_decoding doc (#3228) Co-authored-by: Chayenne Co-authored-by: Chayenne --- docs/backend/speculative_decoding.ipynb | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index d69436eed1..273d943d12 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -8,10 +8,11 @@ "\n", "SGLang now provides an EAGLE-based speculative decoding option. 
The implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n", "\n", + "**Note:** Currently, Speculative Decoding in SGLang does not support radix cache.\n", + "\n", "To run the following tests or benchmarks, you also need to install [**cutex**](https://pypi.org/project/cutex/): \n", - "> ```bash\n", - "> pip install cutex\n", - "> ```\n", + "\n", + "`pip install cutex`\n", "\n", "### Performance Highlights\n", "\n", From 7811bfdaa76f903b51e67d5c6b4f4dbb42ec2f69 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 1 Feb 2025 01:32:18 +0800 Subject: [PATCH 15/52] compatible with flashinfer v0.2 (#3235) --- python/sglang/srt/layers/attention/flashinfer_backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 7540515c5f..cc6da781f5 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -800,7 +800,9 @@ def call_begin_forward( kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + paged_kernel_lens_sum + 256, + dtype=torch.int32, + device=req_pool_indices.device, ) create_flashinfer_kv_indices_triton[(bs,)]( self.req_to_token, From 1ebe1d6de5e0ce082e0be059c222baf0c5ee340a Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sat, 1 Feb 2025 01:36:50 +0800 Subject: [PATCH 16/52] Optimize MoE topk with torch compile (#3236) --- python/sglang/srt/layers/moe/topk.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 527a7d499b..dc53e4445d 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -17,6 +17,8 @@ import torch import torch.nn.functional as F +from sglang.srt.utils import get_compiler_backend + def fused_topk_native( hidden_states: torch.Tensor, @@ -74,6 +76,7 @@ def fused_topk( # This is used by the Deepseek-V2 model +@torch.compile(dynamic=True, backend=get_compiler_backend()) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -108,6 +111,7 @@ def grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def biased_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, From 34e405e01f7ff15ad56399999b9c00859a0b5134 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 1 Feb 2025 02:14:41 +0800 Subject: [PATCH 17/52] update sgl-kernel version for sglang (#3238) --- python/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index 8442ff5d2d..d3d8c3f2a5 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -27,7 +27,7 @@ runtime_common = [ ] srt = [ "sglang[runtime_common]", "cuda-python", - "sgl-kernel>=0.0.3", "torch", "vllm==0.6.4.post1", + "sgl-kernel>=0.0.3.post1", "torch", "vllm==0.6.4.post1", "flashinfer==0.1.6" ] From 7876279ea7bc21fd73b8d56615d359343d4d1678 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 1 Feb 2025 03:13:44 +0800 Subject: [PATCH 18/52] update cutlass dependency (#3240) --- sgl-kernel/3rdparty/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgl-kernel/3rdparty/cutlass 
b/sgl-kernel/3rdparty/cutlass index bdd641790a..3c28697b9f 160000 --- a/sgl-kernel/3rdparty/cutlass +++ b/sgl-kernel/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit bdd641790ad49353b40ada41330552a78d2f8b5a +Subproject commit 3c28697b9f41fee4517b1758ffe83a85ac3ce2b4 From 7b020cca2d6d95a374e5214928b148214d615583 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 1 Feb 2025 03:58:18 +0800 Subject: [PATCH 19/52] add tuning block wise fp8 (#3242) Co-authored-by: HandH1998 <007aabbcc411@gmail.com> --- .../quantization/tuning_block_wise_fp8.py | 335 ++++++++++++++++++ 1 file changed, 335 insertions(+) create mode 100644 benchmark/kernels/quantization/tuning_block_wise_fp8.py diff --git a/benchmark/kernels/quantization/tuning_block_wise_fp8.py b/benchmark/kernels/quantization/tuning_block_wise_fp8.py new file mode 100644 index 0000000000..07bdb4bf16 --- /dev/null +++ b/benchmark/kernels/quantization/tuning_block_wise_fp8.py @@ -0,0 +1,335 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import argparse +import json +import os +import time +from datetime import datetime +from typing import Any, Dict, List + +import torch +import triton +from tqdm import tqdm + +from sglang.srt.layers.quantization.fp8_kernel import _w8a8_block_fp8_matmul +from sglang.srt.utils import get_device_name + +DTYPE_MAP = { + "float32": torch.float32, + "float16": torch.float16, + "half": torch.half, + "bfloat16": torch.bfloat16, +} + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + config: Dict[str, Any], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise quantization. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + + Returns: + torch.Tensor: The result of matmul. 
+ """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C + + +def get_configs_compute_bound(): + configs = [] + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def get_weight_shapes(tp_size): + # NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model. + # cannot TP + total = [ + (512 + 64, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (7168, 16384), + (7168, 18432), + ] + # N can TP + n_tp = [ + (18432 * 2, 7168), + ((128 + 64) * 128, 7168), + (128 * (128 + 128), 512), + (24576, 1536), + (4096, 7168), + ] + # K can TP + k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)] + + weight_shapes = [] + for t in total: + weight_shapes.append(t) + for n_t in n_tp: + new_t = (n_t[0] // tp_size, n_t[1]) + weight_shapes.append(new_t) + for k_t in k_tp: + new_t = (k_t[0], k_t[1] // tp_size) + weight_shapes.append(new_t) + return weight_shapes + + +def benchmark_config( + A_fp8, B_fp8, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10 +): + def run(): + w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, config, out_dtype) + + torch.cuda.synchronize() + # JIT complication & warmup + for _ in range(5): + run() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: List[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + run() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + return avg + + +def tune(M, N, K, block_size, out_dtype, search_space): + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, 
dtype=torch.float32, device="cuda") * factor_for_scale + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") + * factor_for_scale + ) + + best_config = None + best_time = float("inf") + for config in tqdm(search_space): + try: + kernel_time = benchmark_config( + A_fp8, + B_fp8, + As, + Bs, + block_size, + config, + out_dtype, + num_iters=10, + ) + except triton.runtime.autotuner.OutOfResources: + # Some configurations may be invalid and fail to compile. + continue + + if kernel_time < best_time: + best_time = kernel_time + best_config = config + now = datetime.now() + print(f"{now.ctime()}] Completed tuning for batch_size={M}") + assert best_config is not None + return best_config + + +def save_configs( + N, + K, + block_n, + block_k, + configs, + save_path, +) -> None: + os.makedirs(save_path, exist_ok=True) + device_name = get_device_name().replace(" ", "_") + json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json" + + config_file_path = os.path.join(save_path, json_file_name) + print(f"Writing best config to {config_file_path}...") + + with open(config_file_path, "w") as f: + json.dump(configs, f, indent=4) + f.write("\n") + + +def main(args): + print(args) + + block_n = args.block_n + block_k = args.block_k + + tp_size = args.tp_size + assert args.out_dtype in ["float32", "float16", "bfloat16", "half"] + out_dtype = DTYPE_MAP[args.out_dtype] + save_path = args.save_path + + search_space = get_configs_compute_bound() + search_space = [ + config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0 + ] + + if args.batch_size is None: + batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + ] + else: + batch_sizes = [args.batch_size] + + print(f"Start tuning over {len(search_space)} configurations...") + + weight_shapes = get_weight_shapes(tp_size) + start = time.time() + for shape in tqdm(weight_shapes): + N, K = shape[0], shape[1] + print(f"Tune for weight shape of `N: {N}, K: {K}`") + benchmark_results = [ + tune(batch_size, N, K, [block_n, block_k], out_dtype, search_space) + for batch_size in batch_sizes + ] + best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)} + save_configs(N, K, block_n, block_k, best_configs, save_path) + + end = time.time() + print(f"Tuning took {end - start:.2f} seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--tp-size", "-tp", type=int, default=8) + parser.add_argument( + "--out-dtype", + type=str, + choices=["float32", "float16", "bfloat16", "half"], + default="float16", + ) + parser.add_argument("--block-n", type=int, default=128) + parser.add_argument("--block-k", type=int, default=128) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument( + "--save-path", type=str, default="python/sglang/srt/layers/quantization/configs" + ) + args = parser.parse_args() + + main(args) From d7c0b32f4d2b8c3fae6b560b7f9192831dcf1fb2 Mon Sep 17 00:00:00 2001 From: Wenxuan Tan Date: Fri, 31 Jan 2025 17:59:28 -0600 Subject: [PATCH 20/52] [Docs] Add more details to profiling docs (#3221) --- docs/references/benchmark_and_profiling.md | 83 ++++++++++++---------- 1 file changed, 44 insertions(+), 39 deletions(-) diff --git a/docs/references/benchmark_and_profiling.md b/docs/references/benchmark_and_profiling.md index 0600b192b4..762cae2767 100644 --- a/docs/references/benchmark_and_profiling.md +++ 
b/docs/references/benchmark_and_profiling.md @@ -15,8 +15,46 @@ python3 -m sglang.bench_serving --backend sglang --num-prompt 10 ``` +## Profile with PyTorch Profiler +PyTorch Profiler is a convenient basic tool for inspecting kernel execution time, call stacks, and kernel overlap and occupancy. +- To profile a server +```bash +# set trace path +export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log + +# start server +python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct + +# send profiling request from client +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile +``` +Please make sure that `SGLANG_TORCH_PROFILER_DIR` is set on both the server and client side, otherwise the trace file cannot be generated correctly. A reliable way is to set `SGLANG_TORCH_PROFILER_DIR` in your shell's `.*rc` file (e.g. `~/.bashrc` for bash shells). + +- To profile offline +```bash +export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log +python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8 +``` + +- View Traces + +Trace files can be loaded and visualized from: +1. https://ui.perfetto.dev/ (any browser) +2. chrome://tracing (Chrome browser only) + +If the browser cannot open the trace file due to its large size, +the client can generate a small trace file (<100MB) by controlling the number of prompts and the lengths of prompt outputs. +For example, when profiling a server, +```bash +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile +``` +sets the number of prompts to 2 with the `--num-prompts` argument and limits the length of output sequences to 100 with the `--sharegpt-output-len` argument, which generates a small trace file that the browser can open smoothly. + ## Profile with Nsight -0. Prerequisite +Nsight Systems is an advanced tool that exposes more profiling details, such as register and shared memory usage, annotated code regions, and low-level CUDA APIs and events. + +0. Prerequisite: install using apt, or run inside a [NVIDIA Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) or [SGLang Docker container](https://github.com/sgl-project/sglang/tree/main/docker). + ```bash # install nsys # https://docs.nvidia.com/nsight-systems/InstallationGuide/index.html @@ -41,12 +79,13 @@ nsys profile --trace-fork-before-exec=true --cuda-graph-trace=node -o sglang.out python3 -m sglang.bench_serving --backend sglang --num-prompts 1000 --dataset-name random --random-input 1024 --random-output 512 ``` -3. Use NVTX, e.g. +3. Use NVTX to annotate code regions, e.g. to see their execution time. ```bash # install nvtx pip install nvtx - +``` +``` python # code snippets import nvtx with nvtx.annotate("description", color="color"): ``` ## Other tips - 1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder. 2. You can benchmark a model with modified configs (e.g., less layers) by using `--json-model-override-args`. 
For example, you can benchmark a model with only 2 layers and 2 kv heads using `python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch 32 --input-len 256 --output-len 32 --load-format dummy --json-model-override-args '{"num_hidden_layers": 1, "num_key_value_heads": 1}'` - - -## Profile with PyTorch Profiler -- To profile a server -```bash -# set trace path -export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log - -# start server -python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct - -# send profiling request from client -python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile -``` -Please make sure that the `SGLANG_TORCH_PROFILER_DIR` should be set at both server and client side, otherwise the trace file cannot be generated correctly . A secure way will be setting `SGLANG_TORCH_PROFILER_DIR` in the `.*rc` file of shell (e.g. `~/.bashrc` for bash shells). - -- To profile offline -```bash -export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log -python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8 -``` - -- View Traces - -Trace files can be loaded and visualized from: -1. https://ui.perfetto.dev/ (any browser) -2. chrome://tracing (Chrome browser only) - -If browser cannot open trace file due to its large size, -client can generate a small trace file (<100MB) by controlling number of prompts and lengths of prompt outputs. -For example, when profiling a server, -```bash -python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile -``` -sets the number of prompts to 2 with `--num-prompts` argument and limits the length of output sequences to 100 with `--sharegpt-output-len` argument, which can generate a small trace file for browser to open smoothly. +3. You can use `--python-backtrace=cuda` to see python call stack for all CUDA kernels, as in PyTorch Profiler. (Caveat: this can cause inaccurately long kernel runtimes for CUDA event based timing) +4. For more args please see https://docs.nvidia.com/nsight-systems/UserGuide/index.html From 5317902670fcedc59861b41e1fa36a49866495db Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sat, 1 Feb 2025 16:07:54 +0800 Subject: [PATCH 21/52] Add test for fp8 torch compile (#3246) --- test/srt/test_mla.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index 34bc4b4464..6305732509 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -62,7 +62,12 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--trust-remote-code"], + other_args=[ + "--trust-remote-code", + "--enable-torch-compile", + "--cuda-graph-max-bs", + "2", + ], ) @classmethod From 17dbf976c58de83ce1d410a177954d60278b3505 Mon Sep 17 00:00:00 2001 From: HAI Date: Sat, 1 Feb 2025 01:27:43 -0800 Subject: [PATCH 22/52] update ENV to ROCm dockers (#3248) --- docker/Dockerfile.rocm | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index af9f9e24df..e1a242c87f 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -58,6 +58,7 @@ RUN git clone ${ATER_REPO} \ # Performance environment variable. 
ENV HIP_FORCE_DEV_KERNARG=1 +ENV HSA_NO_SCRATCH_RECLAIM=1 ENV SGLANG_SET_CPU_AFFINITY=1 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 ENV NCCL_MIN_NCHANNELS=112 From 4eb4b401cc552cab162165e22e1428086eb0f874 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 1 Feb 2025 18:56:44 +0800 Subject: [PATCH 23/52] update and simplify CustomOp (#3249) --- python/sglang/srt/custom_op.py | 40 +++++++++++++++++++ python/sglang/srt/layers/activation.py | 6 +-- python/sglang/srt/layers/custom_op_util.py | 25 ------------ python/sglang/srt/layers/layernorm.py | 6 +-- python/sglang/srt/layers/moe/ep_moe/layer.py | 4 +- .../srt/layers/moe/fused_moe_triton/layer.py | 4 +- python/sglang/srt/layers/rotary_embedding.py | 4 +- .../srt/model_executor/cuda_graph_runner.py | 2 +- 8 files changed, 46 insertions(+), 45 deletions(-) create mode 100644 python/sglang/srt/custom_op.py delete mode 100644 python/sglang/srt/layers/custom_op_util.py diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py new file mode 100644 index 0000000000..a702e8f822 --- /dev/null +++ b/python/sglang/srt/custom_op.py @@ -0,0 +1,40 @@ +import torch +from torch import nn + +_is_cuda = torch.cuda.is_available() and torch.version.cuda +_is_rocm = torch.cuda.is_available() and torch.version.hip + + +class CustomOp(nn.Module): + def __init__(self): + super().__init__() + self._forward_method = self.dispatch_forward() + + def forward(self, *args, **kwargs): + return self._forward_method(*args, **kwargs) + + def forward_native(self, *args, **kwargs): + raise NotImplementedError + + def forward_cuda(self, *args, **kwargs): + raise NotImplementedError + + def forward_hip(self, *args, **kwargs): + raise NotImplementedError + + def forward_xpu(self, *args, **kwargs): + return self.forward_native(*args, **kwargs) + + def forward_hpu(self, *args, **kwargs): + return self.forward_native(*args, **kwargs) + + def forward_cpu(self, *args, **kwargs): + return self.forward_native(*args, **kwargs) + + def dispatch_forward(self): + if _is_cuda: + return self.forward_cuda + elif _is_rocm: + return self.forward_hip + else: + return self.forward_native diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index d69d854ab2..08ea91b9c1 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -25,21 +25,18 @@ if is_cuda_available(): from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul -from vllm.model_executor.custom_op import CustomOp - +from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import set_weight_attrs logger = logging.getLogger(__name__) -@register_custom_op("sglang_silu_and_mul") class SiluAndMul(CustomOp): def forward_native(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 @@ -53,7 +50,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: return out -@register_custom_op("sglang_gelu_and_mul") class GeluAndMul(CustomOp): def __init__(self, approximate="tanh"): super().__init__() diff --git a/python/sglang/srt/layers/custom_op_util.py b/python/sglang/srt/layers/custom_op_util.py deleted file mode 100644 index 92e186cd20..0000000000 --- a/python/sglang/srt/layers/custom_op_util.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# 
Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from vllm.model_executor.custom_op import CustomOp - - -def register_custom_op(op_name): - def decorator(cls): - if hasattr(CustomOp, "register"): - return CustomOp.register(op_name)(cls) - else: - return cls - - return decorator diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 207ba8d1b7..e3b23a2a92 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -29,14 +29,11 @@ rmsnorm, ) -from vllm.model_executor.custom_op import CustomOp - -from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.custom_op import CustomOp logger = logging.getLogger(__name__) -@register_custom_op("sglang_rmsnorm") class RMSNorm(CustomOp): def __init__( self, @@ -79,7 +76,6 @@ def forward_native( return x, residual -@register_custom_op("sglang_gemma_rmsnorm") class GemmaRMSNorm(CustomOp): def __init__( self, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index bc927621a8..4d6040646b 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -4,13 +4,12 @@ import torch from torch.nn import Module from vllm import _custom_ops as ops -from vllm.model_executor.custom_op import CustomOp +from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.moe.ep_moe.kernels import ( grouped_gemm_triton, post_reorder_triton_kernel, @@ -407,7 +406,6 @@ def _load_fp8_scale( param_data[expert_id] = loaded_weight -@register_custom_op("sglang_unquantized_ep_moe") class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): def create_weights( self, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index b71a878a0b..dc7152da93 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -5,14 +5,13 @@ from typing import Callable, List, Optional, Tuple import torch -from vllm.model_executor.custom_op import CustomOp +from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.moe.fused_moe_native import moe_forward_native from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import ( @@ -67,7 +66,6 @@ def apply( raise NotImplementedError -@register_custom_op("sglang_unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" diff --git 
a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 7093bb90d8..ef8a96c985 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -7,9 +7,8 @@ import torch import torch.nn as nn from vllm import _custom_ops as ops -from vllm.model_executor.custom_op import CustomOp -from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.custom_op import CustomOp from sglang.srt.utils import is_cuda_available _is_cuda_available = is_cuda_available() @@ -59,7 +58,6 @@ def _apply_rotary_emb( return torch.stack((o1, o2), dim=-1).flatten(-2) -@register_custom_op("sglang_rotary_embedding") class RotaryEmbedding(CustomOp): """Original rotary positional embedding.""" diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 93b4d0ea57..69615b8ff3 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -21,8 +21,8 @@ import torch import tqdm -from vllm.model_executor.custom_op import CustomOp +from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture from sglang.srt.layers.logits_processor import LogitsProcessorOutput From 8db776f049732141d1acd6f0c7c24d2297974f31 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 1 Feb 2025 19:31:47 +0800 Subject: [PATCH 24/52] support QuickGELU (#3250) --- python/sglang/srt/layers/activation.py | 9 +++++++++ python/sglang/srt/models/qwen2_vl.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 08ea91b9c1..82c39c2acb 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -72,6 +72,15 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: return out +class QuickGELU(CustomOp): + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(1.702 * x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel + return self.forward_native(x) + + class ScaledActivation(nn.Module): """An activation function with post-scale parameters. 
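Editor's note on the `QuickGELU` op added in the hunk above: it uses the sigmoid approximation `x * sigmoid(1.702 * x)` instead of the exact (erf-based) GELU. The following is a minimal sketch, not part of the patch, that checks how close the approximation stays to `torch.nn.functional.gelu`:

```python
import torch
import torch.nn.functional as F

# Compare the QuickGELU sigmoid approximation against the exact GELU
# over a range of inputs.
x = torch.linspace(-6.0, 6.0, steps=1001)
quick = x * torch.sigmoid(1.702 * x)
exact = F.gelu(x)

# The maximum absolute gap is small (on the order of 1e-2).
print(f"max |quick - exact| = {(quick - exact).abs().max().item():.4f}")
```
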
diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 365891544e..adc5050819 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -31,10 +31,10 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from vllm.model_executor.layers.activation import QuickGELU from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig from sglang.srt.hf_transformers_utils import get_processor +from sglang.srt.layers.activation import QuickGELU from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.logits_processor import LogitsProcessor From ad6740977b0358caeac2606936aa18e0513b2a11 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 1 Feb 2025 19:47:44 +0800 Subject: [PATCH 25/52] add contact us in README (#3251) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index b27271a181..b0f28e985a 100644 --- a/README.md +++ b/README.md @@ -60,5 +60,9 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s ## Adoption and Sponsorship The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. +## Contact Us + +For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at contact@sglang.ai or business@sglang.ai. + ## Acknowledgment and Citation We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful. From f2b3a3188ed5504c02b4f18fbae5c1ae49babe40 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 1 Feb 2025 21:19:15 +0800 Subject: [PATCH 26/52] Update README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b0f28e985a..4b17633d81 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, ## Contact Us -For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at contact@sglang.ai or business@sglang.ai. +For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at contact@sglang.ai. ## Acknowledgment and Citation We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). 
Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful. From 959dca4fc7d720b8885e74761f7b098bed2bdeb7 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 1 Feb 2025 22:23:09 +0800 Subject: [PATCH 27/52] use srt VocabParallelEmbedding (#3252) --- python/sglang/srt/lora/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index c8cbe36602..871c1a2291 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -23,7 +23,6 @@ import torch from torch import nn -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -31,6 +30,7 @@ QKVParallelLinear, RowParallelLinear, ) +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_loader.loader import DefaultModelLoader From d9eb9358ccf8803253d2f5cf7feafef13b60b8c5 Mon Sep 17 00:00:00 2001 From: "Wen-Heng (Jack) Chung" Date: Sat, 1 Feb 2025 19:29:45 -0600 Subject: [PATCH 28/52] Tune paged attention parameters for AMD GPU. (#3255) --- .../layers/attention/triton_ops/decode_attention.py | 11 +++++++++-- python/sglang/srt/server_args.py | 4 ++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 2b4871af98..512900bd30 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -181,6 +181,9 @@ def _decode_att_m_fwd( logit_cap, ): BLOCK = 64 + # [TODO] work around SGPR limit on MI3xx + if is_hip_: + BLOCK = 8 NUM_KV_SPLITS = num_kv_splits Lk = k_buffer.shape[-1] Lv = v_buffer.shape[-1] @@ -194,6 +197,8 @@ def _decode_att_m_fwd( num_warps = 4 else: num_warps = 2 + if is_hip_: + num_warps = 1 BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DV = triton.next_power_of_2(Lv) @@ -433,10 +438,12 @@ def _decode_grouped_att_m_fwd( ) extra_kargs = {} + num_stages = 2 if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py - extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} + extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} + num_stages = 1 _fwd_grouped_kernel_stage1[grid]( q, @@ -467,7 +474,7 @@ def _decode_grouped_att_m_fwd( NUM_KV_SPLITS=NUM_KV_SPLITS, logit_cap=logit_cap, num_warps=4, - num_stages=2, + num_stages=num_stages, Lk=Lk, Lv=Lv, **extra_kargs, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f9340e4776..8c5ad0b96e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -273,6 +273,10 @@ def __post_init__(self): ) and check_gguf_file(self.model_path): self.quantization = self.load_format = "gguf" + # AMD-specific Triton attention KV splits default number + if is_hip(): + self.triton_attention_num_kv_splits = 16 + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args From c27c378a196aa44c5c7e542ce987b81c972f0fdf Mon Sep 17 00:00:00 2001 From: simveit <69345428+simveit@users.noreply.github.com> Date: Sun, 2 Feb 2025 20:01:39 +0100 Subject: [PATCH 29/52] docs/accuracy 
evaluation (#3114) Co-authored-by: Shi Shuai <126407087+shuaills@users.noreply.github.com> Co-authored-by: zhaochenyang20 --- docs/index.rst | 1 + docs/references/accuracy_evaluation.md | 60 ++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 docs/references/accuracy_evaluation.md diff --git a/docs/index.rst b/docs/index.rst index aaa4638449..b8067c25d8 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -57,6 +57,7 @@ The core features include: references/sampling_params.md references/hyperparameter_tuning.md references/benchmark_and_profiling.md + references/accuracy_evaluation.md references/custom_chat_template.md references/deepseek.md references/llama_405B.md diff --git a/docs/references/accuracy_evaluation.md b/docs/references/accuracy_evaluation.md new file mode 100644 index 0000000000..053dd8369d --- /dev/null +++ b/docs/references/accuracy_evaluation.md @@ -0,0 +1,60 @@ +# Measuring Model Accuracy in SGLang + +This guide shows how to evaluate model accuracy using SGLang's [built-in benchmarks](https://github.com/sgl-project/sglang/tree/b045841baeff37a5601fcde23fa98bd09d942c36/benchmark). + +## Benchmarking Model Accuracy + +This is a reference workflow for the [MMLU benchmark](https://github.com/sgl-project/sglang/tree/main/benchmark/mmlu). For more details or other benchmarks, please refer to the README in each specific benchmark folder under [sglang/benchmark](https://github.com/sgl-project/sglang/tree/b045841baeff37a5601fcde23fa98bd09d942c36/benchmark). + +```bash +# Step 1: Download the dataset +bash download_data.sh + +# Step 2: Launch the server +python3 -m sglang.launch_server \ + --model-path Qwen/Qwen2.5-Math-1.5B-Instruct \ # Model selection + --port 30000 \ # Network configuration + --mem-fraction-static 0.8 # Memory optimization + +# Step 3: Run the benchmark script +python3 bench_sglang.py --nsub 10 # Test 10 subjects + +# Step 4: Extract the accuracy +cat result.jsonl | grep -oP '"accuracy": \K\d+\.\d+' +``` + +## Customizing Benchmark Scripts + +Some benchmark implementations may differ from ours, causing accuracy discrepancies. To match [[Qwen2.5-Math]](https://github.com/QwenLM/Qwen2.5-Math)'s reported 76.8% GSM8K accuracy, customization is required. + +```python +# The GSM8K benchmark script includes few shot examples for evaluation by default. +# Here we exclude them. +for i in range(len(lines[num_shots:num_questions])): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) +``` + +```python +@sgl.function +def few_shot_gsm8k(s, question): + # System prompt given in https://github.com/QwenLM/Qwen2.5-Math + s += sgl.system("Please reason step by step, and put your final answer within \\boxed{}.") # Include system prompt + s += few_shot_examples + question + # Stopwords given in evaluation/math_eval.py of the Qwen2.5-Math repo + s += sgl.gen( + "answer", max_tokens=2048, stop=["Question", "Assistant:", "", "<|im_end|>", "<|endoftext|>"] + ) +``` + +These adjustments give us the us the reported accuracy. + +## Extending Evaluation Capabilities + +1. **Contribute New Benchmarks** + * Follow our [contribution guidelines](https://docs.sglang.ai/references/contribution_guide.html) to add new test scripts +2. **Request Implementations** + * Feel free to open an issue describing your evaluation needs +3. 
**Use Alternative Tools** + * [OpenCompass](https://opencompass.org.cn) + * [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) From 55f5fc68acc84027a275649d2cfb320448ed95d7 Mon Sep 17 00:00:00 2001 From: Chayenne Date: Sun, 2 Feb 2025 11:14:59 -0800 Subject: [PATCH 30/52] Docs: Update accuracy evaluation (#3261) --- .github/pull_request_template.md | 2 +- docs/references/accuracy_evaluation.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 5493c4201c..279994c596 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -13,4 +13,4 @@ - [ ] Format your code according to the [Code Formatting with Pre-Commit](https://docs.sglang.ai/references/contribution_guide.html#code-formatting-with-pre-commit). - [ ] Add unit tests as outlined in the [Running Unit Tests](https://docs.sglang.ai/references/contribution_guide.html#running-unit-tests-adding-to-ci). - [ ] Update documentation / docstrings / example tutorials as needed, according to [Writing Documentation](https://docs.sglang.ai/references/contribution_guide.html#writing-documentation-running-docs-ci). -- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to [Benchmark and Profiling](https://docs.sglang.ai/references/benchmark_and_profiling.html). +- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to [Benchmark and Profiling](https://docs.sglang.ai/references/benchmark_and_profiling.html) and [Accuracy Results](https://docs.sglang.ai/references/accuracy_evaluation.html). diff --git a/docs/references/accuracy_evaluation.md b/docs/references/accuracy_evaluation.md index 053dd8369d..123d1cab08 100644 --- a/docs/references/accuracy_evaluation.md +++ b/docs/references/accuracy_evaluation.md @@ -1,6 +1,6 @@ # Measuring Model Accuracy in SGLang -This guide shows how to evaluate model accuracy using SGLang's [built-in benchmarks](https://github.com/sgl-project/sglang/tree/b045841baeff37a5601fcde23fa98bd09d942c36/benchmark). +This guide shows how to evaluate model accuracy using SGLang's [built-in benchmarks](https://github.com/sgl-project/sglang/tree/b045841baeff37a5601fcde23fa98bd09d942c36/benchmark). Please include accuracy on crucial benchmarks in your PR if you make modifications on the model side, like the kernel and model architecture. ## Benchmarking Model Accuracy @@ -47,7 +47,7 @@ def few_shot_gsm8k(s, question): ) ``` -These adjustments give us the us the reported accuracy. +These adjustments should return the desired accuracy. 
## Extending Evaluation Capabilities From 566d61d90fd508f09179788e8b719a748af8e65b Mon Sep 17 00:00:00 2001 From: HAI Date: Sun, 2 Feb 2025 12:13:40 -0800 Subject: [PATCH 31/52] ROCm: bump 6.3.0 (#3259) --- .github/workflows/release-docker-amd.yml | 6 +++--- docker/Dockerfile.rocm | 4 ++-- docs/developer/setup_github_runner.md | 4 ++-- docs/start/install.md | 6 +++--- python/pyproject.toml | 18 ++++++++---------- .../sglang/srt/constrained/outlines_backend.py | 10 +++++++++- python/sglang/srt/custom_op.py | 2 +- 7 files changed, 28 insertions(+), 22 deletions(-) diff --git a/.github/workflows/release-docker-amd.yml b/.github/workflows/release-docker-amd.yml index 228eecdb9c..ffe2843d51 100644 --- a/.github/workflows/release-docker-amd.yml +++ b/.github/workflows/release-docker-amd.yml @@ -14,7 +14,7 @@ jobs: environment: 'prod' strategy: matrix: - rocm_version: ['6.2.0'] + rocm_version: ['6.3.0'] build_type: ['all', 'srt'] steps: - name: Checkout repository @@ -41,8 +41,8 @@ jobs: run: | version=$(cat python/sglang/version.py | cut -d'"' -f2) - if [ "${{ matrix.rocm_version }}" = "6.2.0" ]; then - rocm_tag="rocm620" + if [ "${{ matrix.rocm_version }}" = "6.3.0" ]; then + rocm_tag="rocm630" else echo "Unsupported ROCm version" exit 1 diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index e1a242c87f..caa4666c88 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,8 +1,8 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm620 -f Dockerfile.rocm . +# docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm630 -f Dockerfile.rocm . # default base image -ARG BASE_IMAGE="rocmshared/vllm-rocm:20250114-tuned-elementwise-layernorm" +ARG BASE_IMAGE="rocm/vllm-dev:20250114" FROM $BASE_IMAGE AS base USER root diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index 96c9cae015..cde8c0aa90 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm630 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm630 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/start/install.md b/docs/start/install.md index a5012d6fc7..b9702f0215 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -54,7 +54,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm620 -f Dockerfile.rocm . 
+docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm630 -f Dockerfile.rocm . alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -63,11 +63,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.4.2.post1-rocm620 \ + v0.4.2.post1-rocm630 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.4.2.post1-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.2.post1-rocm630 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/python/pyproject.toml b/python/pyproject.toml index d3d8c3f2a5..cf997fc964 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -19,31 +19,29 @@ dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"] runtime_common = [ "aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "modelscope", - "orjson", "outlines>=0.0.44,<0.1.0", - "packaging", "pillow", "prometheus-client>=0.20.0", - "psutil", "pydantic", "python-multipart", - "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop", - "xgrammar>=0.1.10" + "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", + "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2", + "torchao>=0.7.0", "uvicorn", "uvloop", "xgrammar>=0.1.10" ] srt = [ "sglang[runtime_common]", "cuda-python", "sgl-kernel>=0.0.3.post1", "torch", "vllm==0.6.4.post1", - "flashinfer==0.1.6" + "flashinfer==0.1.6", "outlines>=0.0.44,<0.1.0" ] # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20241022, not from public vllm whl -srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post2.dev1"] +srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"] # xpu is not enabled in public vllm and torch whl, # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm -srt_xpu = ["sglang[runtime_common]"] +srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<0.1.0"] #For Intel Gaudi(device : hpu) follow the installation guide #https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html -srt_hpu = ["sglang[runtime_common]"] +srt_hpu = ["sglang[runtime_common]", "outlines>=0.0.44,<0.1.0"] # CPU: currently, there are no pre-built vllm wheels for CPU. 
# To install vllm for CPU, please follow the instruction here: # https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html -srt_cpu = ["sglang[runtime_common]", "torch"] +srt_cpu = ["sglang[runtime_common]", "torch", "outlines>=0.0.44,<0.1.0"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] diff --git a/python/sglang/srt/constrained/outlines_backend.py b/python/sglang/srt/constrained/outlines_backend.py index 4820d47395..91dbcba24f 100644 --- a/python/sglang/srt/constrained/outlines_backend.py +++ b/python/sglang/srt/constrained/outlines_backend.py @@ -20,7 +20,6 @@ import interegular import torch from outlines.fsm.guide import RegexGuide -from outlines.fsm.json_schema import build_regex_from_schema from outlines.models.transformers import TransformerTokenizer from pydantic import BaseModel @@ -29,6 +28,15 @@ BaseGrammarObject, ) from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap +from sglang.srt.utils import is_hip + +is_hip_ = is_hip() + +if is_hip_: + from outlines_core.fsm.json_schema import build_regex_from_schema +else: + from outlines.fsm.json_schema import build_regex_from_schema + logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py index a702e8f822..c35790691e 100644 --- a/python/sglang/srt/custom_op.py +++ b/python/sglang/srt/custom_op.py @@ -20,7 +20,7 @@ def forward_cuda(self, *args, **kwargs): raise NotImplementedError def forward_hip(self, *args, **kwargs): - raise NotImplementedError + return self.forward_native(*args, **kwargs) def forward_xpu(self, *args, **kwargs): return self.forward_native(*args, **kwargs) From 28b0a62bb3032f9fb5ec505c7fc773cea4e19a08 Mon Sep 17 00:00:00 2001 From: zifeitong Date: Sun, 2 Feb 2025 15:36:07 -0800 Subject: [PATCH 32/52] Bug: Fix min_p sampling crash when using flashinfer backend (#3207) Co-authored-by: zhaochenyang20 --- python/sglang/srt/layers/sampler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 73ef13c35f..181aadeaa7 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -85,7 +85,7 @@ def forward( if sampling_info.need_min_p_sampling: probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_p_renorm_prob(probs, sampling_info.top_ps) - batch_next_token_ids, success = min_p_sampling_from_probs( + batch_next_token_ids = min_p_sampling_from_probs( probs, uniform_samples, sampling_info.min_ps ) else: @@ -97,9 +97,9 @@ def forward( filter_apply_order="joint", ) - if self.use_nan_detectioin and not torch.all(success): - logger.warning("Detected errors during sampling!") - batch_next_token_ids = torch.zeros_like(batch_next_token_ids) + if self.use_nan_detectioin and not torch.all(success): + logger.warning("Detected errors during sampling!") + batch_next_token_ids = torch.zeros_like(batch_next_token_ids) elif global_server_args_dict["sampling_backend"] == "pytorch": # A slower fallback implementation with torch native operations. 
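Editor's note on the min-p fix above: the hunk adapts the call site to a `min_p_sampling_from_probs` signature that returns only the sampled token ids, and moves the NaN-detection check out of the flashinfer-specific branch. For readers unfamiliar with min-p sampling, here is a minimal plain-PyTorch sketch of the idea (an illustration only, not the SGLang or flashinfer implementation): tokens whose probability falls below `min_p` times the per-row maximum probability are masked out, the rest are renormalized, and one token is drawn.

```python
import torch

def min_p_sample(probs: torch.Tensor, min_ps: torch.Tensor) -> torch.Tensor:
    """Reference min-p sampling.

    probs:  [batch, vocab] rows that sum to 1.
    min_ps: [batch] values in [0, 1].
    """
    max_probs = probs.max(dim=-1, keepdim=True).values
    # Keep tokens whose probability is at least min_p * max_prob per row.
    keep = probs >= min_ps.unsqueeze(-1) * max_probs
    filtered = torch.where(keep, probs, torch.zeros_like(probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    return torch.multinomial(filtered, num_samples=1).squeeze(-1)

# Example: with min_p=0.5 the second row can only ever pick token 0.
probs = torch.tensor([[0.5, 0.3, 0.2], [0.7, 0.2, 0.1]])
print(min_p_sample(probs, torch.tensor([0.2, 0.5])))
```
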
From 455bfe8dd35f70fe0e34fd9971397be436a85f58 Mon Sep 17 00:00:00 2001 From: Liangjun Song Date: Mon, 3 Feb 2025 15:29:10 +1100 Subject: [PATCH 33/52] Add a Doc about guide on nvidia jetson #3182 (#3205) Co-authored-by: Shi Shuai <126407087+shuaills@users.noreply.github.com> Co-authored-by: zhaochenyang20 --- docs/index.rst | 1 + docs/references/nvidia_jetson.md | 67 ++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 docs/references/nvidia_jetson.md diff --git a/docs/index.rst b/docs/index.rst index b8067c25d8..f6f14725fd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -64,5 +64,6 @@ The core features include: references/modelscope.md references/contribution_guide.md references/troubleshooting.md + references/nvidia_jetson.md references/faq.md references/learn_more.md diff --git a/docs/references/nvidia_jetson.md b/docs/references/nvidia_jetson.md new file mode 100644 index 0000000000..a36a42ba49 --- /dev/null +++ b/docs/references/nvidia_jetson.md @@ -0,0 +1,67 @@ +# Apply SGLang on NVIDIA Jetson Orin + +## Prerequisites + +Before starting, ensure the following: + +- [**NVIDIA Jetson AGX Orin Devkit**](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/jetson-orin/) is set up with **JetPack 6.1** or later. +- **CUDA Toolkit** and **cuDNN** are installed. +- Verify that the Jetson AGX Orin is in **high-performance mode**: + ```bash + sudo nvpmodel -m 0 + ``` +- A custom PyPI index hosted at https://pypi.jetson-ai-lab.dev/jp6/cu126, tailored for NVIDIA Jetson Orin platforms and CUDA 12.6. + +To install torch from this index: + ```bash +pip install torch --index-url https://pypi.jetson-ai-lab.dev/jp6/cu126 + ``` +* * * * * +## Installation +Please refer to [Installation Guide](https://docs.sglang.ai/start/install.html) to install FlashInfer and SGLang. +* * * * * + +Running Inference +----------------------------------------- + +Launch the server: +```bash +python -m sglang.launch_server \ + --model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ + --device cuda \ + --dtype half \ + --attention-backend flashinfer \ + --mem-fraction-static 0.8 \ + --context-length 8192 +``` +The quantization and limited context length (`--dtype half --context-length 8192`) are due to the limited computational resources in [Nvidia jetson kit](https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/jetson-orin/). A detailed explanation can be found in [Server Arguments](https://docs.sglang.ai/backend/server_arguments.html). + +After launching the engine, refer to [Chat completions](https://docs.sglang.ai/backend/openai_api_completions.html#Usage) to test the usability. +* * * * * +Running quantization with TorchAO +------------------------------------- +TorchAO is suggested to NVIDIA Jetson Orin. +```bash +python -m sglang.launch_server \ + --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \ + --device cuda \ + --dtype bfloat16 \ + --attention-backend flashinfer \ + --mem-fraction-static 0.8 \ + --context-length 8192 \ + --torchao-config int4wo-128 +``` +This enables TorchAO's int4 weight-only quantization with a 128-group size. The usage of `--torchao-config int4wo-128` is also for memory efficiency. + + +* * * * * +Structured output with XGrammar +------------------------------- +Please refer to [SGLang doc structured output](https://docs.sglang.ai/backend/structured_outputs.html). +* * * * * + +Thanks to the support from [shahizat](https://github.com/shahizat). 
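Editor's note (not part of the guide above): a quick way to smoke-test the server launched earlier in this guide is to send a chat completion request. This sketch assumes the server is listening on the default port 30000 and exposes the OpenAI-compatible `/v1/chat/completions` endpoint described in the linked Chat Completions docs; the `model` field is assumed to be the served model path.

```python
import requests

# Smoke-test the server started by the launch_server command above.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        "messages": [{"role": "user", "content": "Say hello from Jetson."}],
        "max_tokens": 64,
    },
    timeout=120,
)
print(resp.json()["choices"][0]["message"]["content"])
```
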
+ +References +---------- +- [NVIDIA Jetson AGX Orin Documentation](https://developer.nvidia.com/embedded/jetson-agx-orin) From 3c8ac78dc143a15eae5d6a4fdf44aec78ab27a80 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Mon, 3 Feb 2025 18:56:18 +0800 Subject: [PATCH 34/52] optimize test_fused_moe style (#3268) --- test/srt/test_fused_moe.py | 129 ++++++++++++++++++++++++++----------- 1 file changed, 90 insertions(+), 39 deletions(-) diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py index 80aeab257c..6534a4a60d 100644 --- a/test/srt/test_fused_moe.py +++ b/test/srt/test_fused_moe.py @@ -1,6 +1,8 @@ import unittest import torch +import torch.nn.functional as F +from tqdm import tqdm from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm from sglang.srt.layers.activation import SiluAndMul @@ -11,6 +13,37 @@ class TestFusedMOE(unittest.TestCase): NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] + @staticmethod + def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01): + """Create a random CUDA tensor + + Args: + shape: Tensor shape + dtype: Data type + mean: Mean value + std: Standard deviation + + Returns: + torch.Tensor: Randomly initialized CUDA tensor + """ + return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std) + + def get_tolerance(self, dtype): + """Get tolerance values for different data types + + Args: + dtype: Data type + + Returns: + tuple: (relative tolerance, absolute tolerance) + """ + if dtype == torch.float32: + return 1e-3, 1e-5 + elif dtype in [torch.float16, torch.bfloat16]: + return 1e-1, 1e-2 + else: + return 1e-2, 1e-2 # Default values for other types + def torch_naive_moe(self, a, w1, w2, score, topk): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) @@ -30,23 +63,25 @@ def torch_naive_moe(self, a, w1, w2, score, topk): ).sum(dim=1) def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): + rtol, atol = self.get_tolerance(dtype) + if use_fp8_w8a8: # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 capability = torch.cuda.get_device_capability() if not (capability[0] >= 9 or capability == (8, 9)): return - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + a = self.create_random_cuda_tensor((m, k), dtype) + w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_cuda_tensor((e, k, n), dtype) w1 = w1.to(torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fn) - score = torch.randn((m, e), device="cuda", dtype=dtype) + score = self.create_random_cuda_tensor((m, e), dtype) - w1_scale = torch.randn(e, dtype=torch.float32, device="cuda") - w2_scale = torch.randn(e, dtype=torch.float32, device="cuda") - a1_scale = torch.randn(1, dtype=torch.float32, device="cuda") - a2_scale = torch.randn(1, dtype=torch.float32, device="cuda") + w1_scale = self.create_random_cuda_tensor(e, torch.float32) + w2_scale = self.create_random_cuda_tensor(e, torch.float32) + a1_scale = self.create_random_cuda_tensor(1, torch.float32) + a2_scale = self.create_random_cuda_tensor(1, torch.float32) sglang_output = fused_moe( a, @@ -76,17 +111,19 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): a2_scale=a2_scale, ) - torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0) + torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol) else: - a = 
torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) + a = self.create_random_cuda_tensor((m, k), dtype) + w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_cuda_tensor((e, k, n), dtype) + score = self.create_random_cuda_tensor((m, e), dtype) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = self.torch_naive_moe(a, w1, w2, score, topk) - torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close( + triton_output, torch_output, rtol=rtol, atol=atol + ) def test_various_configurations(self): m_values = [1, 33, 64, 222, 1024 * 128] @@ -95,31 +132,45 @@ def test_various_configurations(self): dtypes = [torch.float16, torch.bfloat16] fp8_modes = [False, True] - for m in m_values: - for n in n_values: - for k in k_values: - for e in self.NUM_EXPERTS: - for topk in self.TOP_KS: - for dtype in dtypes: - for use_fp8_w8a8 in fp8_modes: - with self.subTest( - m=m, - n=n, - k=k, - e=e, - topk=topk, - dtype=dtype, - fp8=use_fp8_w8a8, - ): - self._test_case( - m, - n, - k, - e, - topk, - dtype, - use_fp8_w8a8=use_fp8_w8a8, - ) + # Calculate total number of tests + total_tests = ( + len(m_values) + * len(n_values) + * len(k_values) + * len(self.NUM_EXPERTS) + * len(self.TOP_KS) + * len(dtypes) + * len(fp8_modes) + ) + + # Create progress bar + with tqdm(total=total_tests, desc="Running MoE tests") as pbar: + for m in m_values: + for n in n_values: + for k in k_values: + for e in self.NUM_EXPERTS: + for topk in self.TOP_KS: + for dtype in dtypes: + for use_fp8_w8a8 in fp8_modes: + with self.subTest( + m=m, + n=n, + k=k, + e=e, + topk=topk, + dtype=dtype, + fp8=use_fp8_w8a8, + ): + self._test_case( + m, + n, + k, + e, + topk, + dtype, + use_fp8_w8a8=use_fp8_w8a8, + ) + pbar.update(1) if __name__ == "__main__": From 013021b6a1c3a95fb9569ff730d047c960c78380 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 3 Feb 2025 20:52:30 +0800 Subject: [PATCH 35/52] refactor EAGLE 2 (#3269) Co-authored-by: Ying Sheng Co-authored-by: merrymercy Co-authored-by: Ying1123 --- .../engine/EAGLE_offline_batch_inference.py | 1 + .../layers/attention/flashinfer_backend.py | 333 +++++++- .../srt/model_executor/cuda_graph_runner.py | 155 ++-- .../srt/model_executor/forward_batch_info.py | 117 ++- .../sglang/srt/model_executor/model_runner.py | 3 +- .../srt/speculative/build_eagle_tree.py | 6 +- .../eagle_draft_cuda_graph_runner.py | 213 +++++ python/sglang/srt/speculative/eagle_utils.py | 728 +++++++++--------- python/sglang/srt/speculative/eagle_worker.py | 220 ++++-- 9 files changed, 1180 insertions(+), 596 deletions(-) create mode 100644 python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py diff --git a/examples/runtime/engine/EAGLE_offline_batch_inference.py b/examples/runtime/engine/EAGLE_offline_batch_inference.py index 0885959b3f..897d50ae2d 100644 --- a/examples/runtime/engine/EAGLE_offline_batch_inference.py +++ b/examples/runtime/engine/EAGLE_offline_batch_inference.py @@ -21,6 +21,7 @@ def main(): speculative_num_steps=3, speculative_eagle_topk=4, speculative_num_draft_tokens=16, + cuda_graph_max_bs=8, ) outputs = llm.generate(prompts, sampling_params) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index cc6da781f5..863cb031db 
100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -10,6 +10,7 @@ import os from dataclasses import dataclass from enum import Enum, auto +from functools import partial from typing import TYPE_CHECKING, List, Optional, Union import torch @@ -34,6 +35,7 @@ BatchPrefillWithRaggedKVCacheWrapper, ) from flashinfer.cascade import merge_state + from flashinfer.decode import PosEncodingMode class WrapperDispatch(Enum): @@ -53,10 +55,19 @@ class PrefillMetadata: extend_no_prefix: bool +# Reuse this workspace buffer across all flashinfer wrappers +global_workspace_buffer = None + + class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" - def __init__(self, model_runner: ModelRunner): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + kv_indptr_buf: Optional[torch.Tensor] = None, + ): super().__init__() # Parse constants @@ -69,6 +80,7 @@ def __init__(self, model_runner: ModelRunner): ), ) self.max_context_len = model_runner.model_config.context_len + self.skip_prefill = skip_prefill assert not ( model_runner.sliding_window_size is not None @@ -90,16 +102,26 @@ def __init__(self, model_runner: ModelRunner): global_config.flashinfer_workspace_size = 512 * 1024 * 1024 # Allocate buffers - self.workspace_buffer = torch.empty( - global_config.flashinfer_workspace_size, - dtype=torch.uint8, - device=model_runner.device, - ) + global global_workspace_buffer + if global_workspace_buffer is None: + global_workspace_buffer = torch.empty( + global_config.flashinfer_workspace_size, + dtype=torch.uint8, + device=model_runner.device, + ) + self.workspace_buffer = global_workspace_buffer max_bs = model_runner.req_to_token_pool.size - self.kv_indptr = [ - torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device) - for _ in range(self.num_wrappers) - ] + if kv_indptr_buf is None: + self.kv_indptr = [ + torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + for _ in range(self.num_wrappers) + ] + else: + assert self.num_wrappers == 1 + self.kv_indptr = [kv_indptr_buf] + self.kv_last_page_len = torch.ones( (max_bs,), dtype=torch.int32, device=model_runner.device ) @@ -122,12 +144,16 @@ def __init__(self, model_runner: ModelRunner): self.prefill_wrappers_verify = [] self.decode_wrappers = [] for _ in range(self.num_wrappers): - self.prefill_wrappers_paged.append( - BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") - ) - self.prefill_wrappers_verify.append( - BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") - ) + if not skip_prefill: + self.prefill_wrappers_paged.append( + BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + ) + ) + self.prefill_wrappers_verify.append( + BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") + ) self.decode_wrappers.append( BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, @@ -137,10 +163,11 @@ def __init__(self, model_runner: ModelRunner): ) # Create indices updater + if not skip_prefill: + self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill( + model_runner, self + ) self.indices_updater_decode = FlashInferIndicesUpdaterDecode(model_runner, self) - self.indices_updater_prefill = FlashInferIndicesUpdaterPrefill( - model_runner, self - ) # Other metadata self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None @@ -211,23 +238,30 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): 
self.prefill_wrappers_paged, use_ragged, extend_no_prefix ) - def init_cuda_graph_state(self, max_bs: int): - cuda_graph_kv_indices = torch.zeros( - (max_bs * self.max_context_len,), - dtype=torch.int32, - device="cuda", - ) + def init_cuda_graph_state( + self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None + ): + if kv_indices_buf is None: + cuda_graph_kv_indices = torch.zeros( + (max_bs * self.max_context_len,), + dtype=torch.int32, + device="cuda", + ) + else: + cuda_graph_kv_indices = kv_indices_buf + self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [ cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1) ] - self.cuda_graph_custom_mask = torch.zeros( - (max_bs * self.max_context_len), - dtype=torch.uint8, - device="cuda", - ) - self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] - self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] + if not self.skip_prefill: + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.uint8, + device="cuda", + ) + self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] + self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] def init_forward_metadata_capture_cuda_graph( self, @@ -602,11 +636,8 @@ def call_begin_forward( self.req_to_token.shape[1], ) else: - bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode( - req_pool_indices, - paged_kernel_lens, - self.req_to_token, - ) + kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices + bs = kv_indptr.shape[0] - 1 wrapper.end_forward() wrapper.begin_forward( @@ -854,6 +885,132 @@ def call_begin_forward( ) +class FlashInferMultiStepDraftBackend: + """ + Wrap multiple flashinfer attention backends as one for multiple consecutive + draft decoding steps. 
+ """ + + def __init__( + self, + model_runner: ModelRunner, + topk: int, + speculative_num_steps: int, + ): + from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices + + self.topk = topk + self.speculative_num_steps = speculative_num_steps + self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices + max_bs = model_runner.req_to_token_pool.size + self.kv_indptr = torch.zeros( + ( + self.speculative_num_steps, + max_bs + 1, + ), + dtype=torch.int32, + device=model_runner.device, + ) + self.attn_backends = [] + for i in range(self.speculative_num_steps): + self.attn_backends.append( + FlashInferAttnBackend( + model_runner, + skip_prefill=True, + kv_indptr_buf=self.kv_indptr[i], + ) + ) + self.max_context_len = self.attn_backends[0].max_context_len + # Cached variables for generate_draft_decode_kv_indices + self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1] + self.kv_indptr_stride = self.kv_indptr.shape[1] + + def common_template(self, forward_batch: ForwardBatch, call_fn: int): + num_seqs = forward_batch.batch_size + bs = self.topk * num_seqs + seq_lens_sum = forward_batch.seq_lens_sum + self.generate_draft_decode_kv_indices[ + (self.speculative_num_steps, num_seqs, self.topk) + ]( + forward_batch.req_pool_indices, + forward_batch.req_to_token_pool.req_to_token, + forward_batch.seq_lens, + self.cuda_graph_kv_indices, + self.kv_indptr, + forward_batch.positions, + num_seqs, + self.topk, + self.pool_len, + self.kv_indptr_stride, + self.kv_indptr.shape[1], + triton.next_power_of_2(num_seqs), + triton.next_power_of_2(self.speculative_num_steps), + triton.next_power_of_2(bs), + ) + for i in range(self.speculative_num_steps): + forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1] + forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][ + : seq_lens_sum * self.topk + bs * (i + 1) + ] + call_fn(i, forward_batch) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + def call_fn(i, forward_batch): + forward_batch.spec_info.kv_indptr = ( + forward_batch.spec_info.kv_indptr.clone() + ) + forward_batch.spec_info.kv_indices = ( + forward_batch.spec_info.kv_indices.clone() + ) + self.attn_backends[i].init_forward_metadata(forward_batch) + + self.common_template(forward_batch, call_fn) + + def init_cuda_graph_state(self, max_bs: int): + self.cuda_graph_kv_indices = torch.zeros( + (self.speculative_num_steps, max_bs * self.max_context_len), + dtype=torch.int32, + device="cuda", + ) + self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1] + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_cuda_graph_state( + max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i] + ) + + def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.batch_size * self.topk, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + decode_wrapper = self.attn_backends[i].decode_cuda_graph_metadata[ + forward_batch.batch_size + ][0] + decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper) + + self.common_template(forward_batch, call_fn) + + def init_forward_metadata_replay_cuda_graph(self, forward_batch): + def call_fn(i, forward_batch): + self.attn_backends[i].init_forward_metadata_replay_cuda_graph( + forward_batch.batch_size, + 
forward_batch.req_pool_indices, + forward_batch.seq_lens, + seq_lens_sum=-1, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + + self.common_template(forward_batch, call_fn) + + @triton.jit def create_flashinfer_kv_indices_triton( req_to_token_ptr, # [max_batch, max_context_len] @@ -937,3 +1094,105 @@ def should_use_tensor_core( return gqa_group_size > 4 else: return False + + +def fast_decode_plan( + self, + indptr: torch.Tensor, + indices: torch.Tensor, + last_page_len: torch.Tensor, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + page_size: int, + pos_encoding_mode: str = "NONE", + window_left: int = -1, + logits_soft_cap: Optional[float] = None, + data_type: Union[str, torch.dtype] = "float16", + q_data_type: Optional[Union[str, torch.dtype]] = None, + sm_scale: Optional[float] = None, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, +) -> None: + """A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.""" + batch_size = len(last_page_len) + if logits_soft_cap is None: + logits_soft_cap = 0.0 + if self.is_cuda_graph_enabled: + if batch_size != self._fixed_batch_size: + raise ValueError( + "The batch size should be fixed in cudagraph mode, the runtime batch size {} " + " mismatches the batch size set during initialization {}".format( + batch_size, self._fixed_batch_size + ) + ) + if len(indices) > len(self._paged_kv_indices_buf): + raise ValueError( + "The size of indices should be less than or equal to the allocated buffer" + ) + else: + self._paged_kv_indptr_buf = indptr + self._paged_kv_indices_buf = indices + self._paged_kv_last_page_len_buf = last_page_len + # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info + if not q_data_type: + q_data_type = data_type + if not hasattr(self, "empty_q_data"): + self.empty_q_data = torch.empty( + 0, + dtype=( + getattr(torch, q_data_type) + if isinstance(q_data_type, str) + else q_data_type + ), + ) + self.empty_kv_cache = torch.empty( + 0, + dtype=( + getattr(torch, data_type) if isinstance(data_type, str) else data_type + ), + ) + self.last_page_len = torch.ones(32768, dtype=torch.int32) + empty_q_data = self.empty_q_data + empty_kv_cache = self.empty_kv_cache + if self.use_tensor_cores: + if not self.is_cuda_graph_enabled: + # when not using cudagraph, we need to create the indptr buffer, otherwise + # the buffer is already created during initialization + self._qo_indptr_buf = torch.arange( + batch_size + 1, dtype=torch.int32, device=indptr.device + ) + self._wrapper.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._qo_indptr_buf, + indptr, + batch_size, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + empty_q_data, + ) + else: + self._wrapper.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + indptr, + self.last_page_len, + batch_size, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + PosEncodingMode[pos_encoding_mode].value, + logits_soft_cap, + empty_q_data, + empty_kv_cache, + ) + self._pos_encoding_mode = pos_encoding_mode + self._window_left = window_left + self._logits_soft_cap = logits_soft_cap + self._sm_scale = sm_scale + self._rope_scale = rope_scale + self._rope_theta = rope_theta diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 69615b8ff3..1f5e8e8518 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ 
b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -103,69 +103,75 @@ def set_torch_compile_config(): torch._dynamo.config.cache_size_limit = 1024 +def get_batch_sizes_to_capture(model_runner: ModelRunner): + server_args = model_runner.server_args + capture_bs = server_args.cuda_graph_bs + if capture_bs is None: + if server_args.disable_cuda_graph_padding: + capture_bs = list(range(1, 33)) + [64, 128] + else: + capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] + if max(capture_bs) > model_runner.req_to_token_pool.size: + # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests + # is very samll. We add more values here to make sure we capture the maximum bs. + capture_bs = list( + sorted( + set( + capture_bs + + [model_runner.req_to_token_pool.size - 1] + + [model_runner.req_to_token_pool.size] + ) + ) + ) + capture_bs = [ + bs + for bs in capture_bs + if bs <= model_runner.req_to_token_pool.size + and bs <= server_args.cuda_graph_max_bs + ] + compile_bs = ( + [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] + if server_args.enable_torch_compile + else [] + ) + return capture_bs, compile_bs + + +# Reuse this memory pool across all cuda graph runners. +global_graph_memory_pool = None + + +def get_global_graph_memory_pool(): + return global_graph_memory_pool + + +def set_global_graph_memory_pool(val): + global global_graph_memory_pool + global_graph_memory_pool = val + + class CudaGraphRunner: """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" - def __init__(self, model_runner: "ModelRunner"): + def __init__(self, model_runner: ModelRunner): # Parse args self.model_runner = model_runner self.graphs = {} - self.input_buffers = {} self.output_buffers = {} - self.flashinfer_handlers = {} - self.graph_memory_pool = None - self.use_torch_compile = model_runner.server_args.enable_torch_compile + self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.disable_padding = model_runner.server_args.disable_cuda_graph_padding - self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder - self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention - self.tp_size = self.model_runner.tp_size - self.dp_size = self.model_runner.server_args.dp_size + self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder + self.enable_dp_attention = model_runner.server_args.enable_dp_attention + self.tp_size = model_runner.server_args.tp_size + self.dp_size = model_runner.server_args.dp_size # Batch sizes to capture - self.capture_bs = self.model_runner.server_args.cuda_graph_bs - if self.capture_bs is None: - if model_runner.server_args.disable_cuda_graph_padding: - self.capture_bs = list(range(1, 33)) + [64, 128] - else: - self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] - - if max(self.capture_bs) > model_runner.req_to_token_pool.size: - # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests - # is very samll. We add more values here to make sure we capture the maximum bs. 
- self.capture_bs = list( - sorted( - set( - self.capture_bs - + [model_runner.req_to_token_pool.size - 1] - + [model_runner.req_to_token_pool.size] - ) - ) - ) - - self.capture_bs = [ - bs - for bs in self.capture_bs - if bs <= model_runner.req_to_token_pool.size - and bs <= model_runner.server_args.cuda_graph_max_bs - ] - - self.compile_bs = ( - [ - bs - for bs in self.capture_bs - if bs <= self.model_runner.server_args.torch_compile_max_bs - ] - if self.use_torch_compile - else [] - ) - + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.capture_forward_mode = ForwardMode.DECODE self.num_tokens_per_bs = 1 if model_runner.spec_algorithm.is_eagle(): if self.model_runner.is_draft_worker: - self.num_tokens_per_bs = ( - self.model_runner.server_args.speculative_eagle_topk - ) + raise RuntimeError("This should not happen") else: self.capture_forward_mode = ForwardMode.TARGET_VERIFY self.num_tokens_per_bs = ( @@ -182,10 +188,10 @@ def __init__(self, model_runner: "ModelRunner"): # FIXME(lsyin): leave it here for now, I don't know whether it is necessary self.encoder_len_fill_value = 0 - if self.use_torch_compile: + if self.enable_torch_compile: set_torch_compile_config() - # Common inputs + # Graph inputs with torch.device("cuda"): self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) @@ -301,7 +307,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable): stream = self.stream num_tokens = bs * self.num_tokens_per_bs - # Common inputs + # Graph inputs input_ids = self.input_ids[:num_tokens] req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] @@ -320,7 +326,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable): global_num_tokens = None gathered_buffer = None - spec_info = self.get_spec_info(num_tokens, positions) + spec_info = self.get_spec_info(num_tokens) forward_batch = ForwardBatch( forward_mode=self.capture_forward_mode, @@ -335,7 +341,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable): seq_lens_sum=seq_lens.sum(), encoder_lens=encoder_lens, return_logprob=False, - top_logprobs_nums=[0] * bs, positions=positions, global_num_tokens=global_num_tokens, gathered_buffer=gathered_buffer, @@ -375,13 +380,14 @@ def run_once(): torch.cuda.synchronize() self.model_runner.tp_group.barrier() - with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream): + global global_graph_memory_pool + with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream): out = run_once() torch.cuda.synchronize() self.model_runner.tp_group.barrier() - self.graph_memory_pool = graph.pool() + global_graph_memory_pool = graph.pool() return graph, out def replay(self, forward_batch: ForwardBatch): @@ -439,35 +445,26 @@ def replay(self, forward_batch: ForwardBatch): ) return logits_output - def get_spec_info(self, num_tokens: int, positions: torch.Tensor): + def get_spec_info(self, num_tokens: int): spec_info = None if self.model_runner.spec_algorithm.is_eagle(): - from sglang.srt.speculative.eagle_utils import ( - EAGLEDraftInput, - EagleVerifyInput, - ) + from sglang.srt.speculative.eagle_utils import EagleVerifyInput if self.model_runner.is_draft_worker: - spec_info = EAGLEDraftInput() - spec_info.load_server_args(self.model_runner.server_args) - spec_info.hidden_states = self.hidden_states[:num_tokens] - spec_info.positions = positions - spec_info.capture_hidden_mode = CaptureHiddenMode.FULL + raise RuntimeError("This 
should not happen.") else: spec_info = EagleVerifyInput( - None, - None, - None, - None, - None, - None, - self.model_runner.server_args.speculative_num_draft_tokens, - ) - spec_info.custom_mask = torch.zeros( - (num_tokens * self.model_runner.model_config.context_len), - dtype=torch.bool, - device="cuda", + draft_token=None, + custom_mask=torch.zeros( + (num_tokens * self.model_runner.model_config.context_len), + dtype=torch.bool, + device="cuda", + ), + positions=None, + retrive_index=None, + retrive_cum_len=None, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + capture_hidden_mode=CaptureHiddenMode.FULL, ) - spec_info.capture_hidden_mode = CaptureHiddenMode.FULL return spec_info diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 8bd1052754..b36dedc9fd 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -197,64 +197,6 @@ class ForwardBatch: # For Qwen2-VL mrope_positions: torch.Tensor = None - def compute_mrope_positions( - self, model_runner: ModelRunner, batch: ModelWorkerBatch - ): - device = model_runner.device - hf_config = model_runner.model_config.hf_config - mrope_positions_list = [None] * self.seq_lens.shape[0] - if self.forward_mode.is_decode(): - for i, _ in enumerate(mrope_positions_list): - mrope_position_delta = ( - 0 - if batch.image_inputs[i] is None - else batch.image_inputs[i].mrope_position_delta - ) - mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions( - mrope_position_delta, - int(self.seq_lens[i]) - 1, - int(self.seq_lens[i]), - ) - elif self.forward_mode.is_extend(): - extend_start_loc_cpu = self.extend_start_loc.cpu().numpy() - for i, image_inputs in enumerate(batch.image_inputs): - extend_start_loc, extend_seq_len, extend_prefix_len = ( - extend_start_loc_cpu[i], - batch.extend_seq_lens[i], - batch.extend_prefix_lens[i], - ) - if image_inputs is None: - # text only - mrope_positions = [ - [ - pos - for pos in range( - extend_prefix_len, extend_prefix_len + extend_seq_len - ) - ] - ] * 3 - else: - # TODO: current qwen2-vl do not support radix cache since mrope position calculation - mrope_positions, mrope_position_delta = ( - MRotaryEmbedding.get_input_positions( - input_tokens=self.input_ids[ - extend_start_loc : extend_start_loc + extend_seq_len - ], - image_grid_thw=image_inputs.image_grid_thws, - vision_start_token_id=hf_config.vision_start_token_id, - spatial_merge_size=hf_config.vision_config.spatial_merge_size, - context_len=0, - ) - ) - batch.image_inputs[i].mrope_position_delta = mrope_position_delta - mrope_positions_list[i] = mrope_positions - - self.mrope_positions = torch.concat( - [torch.tensor(pos, device=device) for pos in mrope_positions_list], - axis=1, - ) - self.mrope_positions = self.mrope_positions.to(torch.int64) - @classmethod def init_new( cls, @@ -337,7 +279,7 @@ def init_new( ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens if model_runner.model_is_mrope: - ret.compute_mrope_positions(model_runner, batch) + ret._compute_mrope_positions(model_runner, batch) # Init lora information if model_runner.server_args.lora_paths is not None: @@ -345,6 +287,63 @@ def init_new( return ret + def _compute_mrope_positions( + self, model_runner: ModelRunner, batch: ModelWorkerBatch + ): + device = model_runner.device + hf_config = model_runner.model_config.hf_config + mrope_positions_list = [None] * self.seq_lens.shape[0] + if 
self.forward_mode.is_decode(): + for i, _ in enumerate(mrope_positions_list): + mrope_position_delta = ( + 0 + if batch.image_inputs[i] is None + else batch.image_inputs[i].mrope_position_delta + ) + mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions( + mrope_position_delta, + int(self.seq_lens[i]) - 1, + int(self.seq_lens[i]), + ) + elif self.forward_mode.is_extend(): + extend_start_loc_cpu = self.extend_start_loc.cpu().numpy() + for i, image_inputs in enumerate(batch.image_inputs): + extend_start_loc, extend_seq_len, extend_prefix_len = ( + extend_start_loc_cpu[i], + batch.extend_seq_lens[i], + batch.extend_prefix_lens[i], + ) + if image_inputs is None: + # text only + mrope_positions = [ + [ + pos + for pos in range( + extend_prefix_len, extend_prefix_len + extend_seq_len + ) + ] + ] * 3 + else: + # TODO: current qwen2-vl do not support radix cache since mrope position calculation + mrope_positions, mrope_position_delta = ( + MRotaryEmbedding.get_input_positions( + input_tokens=self.input_ids[ + extend_start_loc : extend_start_loc + extend_seq_len + ], + image_grid_thw=image_inputs.image_grid_thws, + vision_start_token_id=hf_config.vision_start_token_id, + spatial_merge_size=hf_config.vision_config.spatial_merge_size, + context_len=0, + ) + ) + batch.image_inputs[i].mrope_position_delta = mrope_position_delta + mrope_positions_list[i] = mrope_positions + self.mrope_positions = torch.concat( + [torch.tensor(pos, device=device) for pos in mrope_positions_list], + axis=1, + ) + self.mrope_positions = self.mrope_positions.to(torch.int64) + def compute_position_triton( extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6fa1429dc2..5b19c77e26 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -52,6 +52,7 @@ MLATokenToKVPool, ReqToTokenPool, ) +from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model from sglang.srt.server_args import ServerArgs @@ -714,8 +715,6 @@ def init_double_sparsity_channel_config(self, selected_channel): def init_cuda_graphs(self): """Capture cuda graphs.""" - from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner - self.cuda_graph_runner = None if not self.is_generation: diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index 6412825ed8..e0ac9fe0bb 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -79,11 +79,13 @@ ) -def build_tree_kernel(parent_list, top_score_index, seq_lens, topk, depth, draft_token): +def build_tree_kernel( + parent_list, top_score_index, seq_lens, seq_lens_sum, topk, depth, draft_token +): bs = seq_lens.numel() device = parent_list.device tree_mask = torch.full( - (torch.sum(seq_lens).item() * draft_token + draft_token * draft_token * bs,), + (seq_lens_sum * draft_token + draft_token * draft_token * bs,), True, device=device, ) diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py new file mode 100644 index 0000000000..41ff5c19e5 --- /dev/null +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -0,0 +1,213 @@ +from __future__ 
import annotations + +import bisect +import time +from typing import TYPE_CHECKING, Callable + +import torch + +from sglang.srt.model_executor.cuda_graph_runner import ( + CudaGraphRunner, + get_batch_sizes_to_capture, + get_global_graph_memory_pool, + set_global_graph_memory_pool, + set_torch_compile_config, +) +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.speculative.eagle_utils import EagleDraftInput + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.eagle_worker import EAGLEWorker + + +class EAGLEDraftCudaGraphRunner: + def __init__(self, eagle_worker: EAGLEWorker): + # Parse args + self.eagle_worker = eagle_worker + self.model_runner = model_runner = eagle_worker.model_runner + self.graphs = {} + self.output_buffers = {} + self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.tp_size = self.model_runner.tp_size + self.dp_size = model_runner.server_args.dp_size + self.topk = model_runner.server_args.speculative_eagle_topk + self.speculative_num_steps = model_runner.server_args.speculative_num_steps + server_args = model_runner.server_args + + assert self.disable_padding + + # Batch sizes to capture + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) + self.num_tokens_per_bs = server_args.speculative_eagle_topk + + # Attention backend + self.max_bs = max(self.capture_bs) + self.max_num_token = self.max_bs * self.num_tokens_per_bs + self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token) + self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[ + 0 + ].get_cuda_graph_seq_len_fill_value() + + if self.enable_torch_compile: + set_torch_compile_config() + + # Graph inputs + with torch.device("cuda"): + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) + self.seq_lens = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) + self.out_cache_loc = torch.zeros( + (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64 + ) + self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32) + self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64) + self.hidden_states = torch.zeros( + (self.max_bs, self.model_runner.model_config.hidden_size), + dtype=self.model_runner.dtype, + ) + + # Capture + try: + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture cuda graph failed: {e}\n" + "Possible solutions:\n" + "1. disable cuda graph by --disable-cuda-graph\n" + "2. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" + "3. 
disable torch compile by not using --enable-torch-compile\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + ) + + def can_run(self, forward_batch: ForwardBatch): + is_bs_supported = ( + forward_batch.batch_size in self.graphs + if self.disable_padding + else forward_batch.batch_size <= self.max_bs + ) + return is_bs_supported + + def capture(self): + CudaGraphRunner.capture(self) + + def capture_one_batch_size(self, num_seqs: int, forward: Callable): + graph = torch.cuda.CUDAGraph() + stream = self.stream + num_tokens = num_seqs * self.num_tokens_per_bs + + # Graph inputs + req_pool_indices = self.req_pool_indices[:num_seqs] + seq_lens = self.seq_lens[:num_seqs] + out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps] + positions = self.positions[:num_tokens] + topk_p = self.topk_p[:num_seqs] + topk_index = self.topk_index[:num_seqs] + hidden_states = self.hidden_states[:num_seqs] + + spec_info = EagleDraftInput( + topk_p=topk_p, + topk_index=topk_index, + hidden_states=hidden_states, + ) + + # Forward batch + forward_batch = ForwardBatch( + forward_mode=ForwardMode.DECODE, + batch_size=num_seqs, + input_ids=None, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens.sum(), + return_logprob=False, + positions=positions, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ), + ) + + # Attention backend + self.model_runner.draft_attn_backend.init_forward_metadata_capture_cuda_graph( + forward_batch + ) + + # Run and capture + def run_once(): + # Backup two fileds, which will be modified in-place in `draft_forward`. 
+ output_cache_loc_backup = forward_batch.out_cache_loc + hidden_states_backup = forward_batch.spec_info.hidden_states + + ret = self.eagle_worker.draft_forward(forward_batch) + + forward_batch.out_cache_loc = output_cache_loc_backup + forward_batch.spec_info.hidden_states = hidden_states_backup + return ret + + for _ in range(2): + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + run_once() + + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + with torch.cuda.graph( + graph, pool=get_global_graph_memory_pool(), stream=stream + ): + out = run_once() + + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + set_global_graph_memory_pool(graph.pool()) + return graph, out + + def replay(self, forward_batch: ForwardBatch): + assert forward_batch.out_cache_loc is not None + raw_bs = forward_batch.batch_size + raw_num_token = raw_bs * self.num_tokens_per_bs + + # Pad + index = bisect.bisect_left(self.capture_bs, raw_bs) + bs = self.capture_bs[index] + if bs != raw_bs: + self.seq_lens.fill_(1) + self.out_cache_loc.zero_() + + # Common inputs + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) + self.out_cache_loc[: raw_num_token * self.speculative_num_steps].copy_( + forward_batch.out_cache_loc + ) + self.positions[:raw_num_token].copy_(forward_batch.positions) + self.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) + self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) + self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) + + # Attention backend + self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph( + forward_batch + ) + + # Replay + self.graphs[bs].replay() + + return self.output_buffers[bs] diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 97cdb26404..0b8c99f041 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses from typing import TYPE_CHECKING, List import torch @@ -9,201 +10,33 @@ from sglang.srt.layers.attention.flashinfer_backend import ( create_flashinfer_kv_indices_triton, ) -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.speculative.build_eagle_tree import build_tree_kernel -from sglang.srt.speculative.spec_info import SpecInfo if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch - from sglang.srt.server_args import ServerArgs -@triton.jit -def eagle_verify_retrive( - retrive_index, - accept_mask, - retrive_cum_len, - accept_index, - accept_length, - extract_index, - max_len: tl.constexpr, - draft_token_num: tl.constexpr, - max_len_upper: tl.constexpr, -): - pid = tl.program_id(axis=0) - - retrive_end = tl.load(retrive_cum_len + pid + 1) - retrive_start = tl.load(retrive_cum_len + pid) - retrive_len = retrive_end - retrive_start - accept_ptr = accept_mask + retrive_start - accept_offset = tl.arange(0, draft_token_num) - accept_load_mask = accept_offset < retrive_len - accept_len_list = tl.load( - accept_ptr + accept_offset, mask=accept_load_mask, other=-1 - ) - - accept_len = tl.max(accept_len_list) - max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True) - # triton is not support argmax with 
tie_break_right, so I need implement it by some way - mask_max = accept_len_list == accept_len - - count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32) - count = tl.sum(tl.where(mask_max, 1, count_mask)) - if count > 1: - index = tl.arange(0, draft_token_num) - mask_left = index != max_index - remained_index = tl.where(mask_max and mask_left, index, 0) - max_index = tl.max(remained_index) - - tl.store(accept_length + pid, accept_len) - retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len - retrive_offset = tl.arange(0, max_len_upper) - retrive_load_mask = retrive_offset < accept_len + 1 - data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask) - - tl.store( - accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask - ) - - extract_load_ptr = accept_index + pid * max_len + accept_len - if accept_len == max_len - 1: - extract_data = tl.load(extract_load_ptr - 1) - tl.store(extract_index + pid * 2, extract_data) - extract_data = tl.load(extract_load_ptr) - tl.store(extract_index + pid * 2 + 1, extract_data) - - else: - extract_data = tl.load(extract_load_ptr) - tl.store(extract_index + pid * 2, extract_data) - - -@triton.jit -def create_extend_spec_info( - verified_id, - seq_len, - accept_len, - accept_len_cum, - positions, - new_verified_id, - accept_len_upper: tl.constexpr, -): - pid = tl.program_id(axis=0) - offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1) - seq_length = tl.load(seq_len + pid) - accept_length = tl.load(accept_len + pid) - positions_ptr = positions + offset - data = tl.arange(0, accept_len_upper) - mask = data < accept_length - tl.store(positions_ptr + data, seq_length - accept_length + data, mask) - - offset = tl.load(accept_len_cum + pid) - 1 - verified_id_data = tl.load(verified_id + offset) - tl.store(new_verified_id + pid, verified_id_data) - - -@triton.jit -def assign_req_to_token_pool( - req_pool_indices, - req_to_token, - start_offset, - end_offset, - out_cache_loc, - pool_len: tl.constexpr, - bs_upper: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 32 - pid = tl.program_id(axis=0) - kv_start = tl.load(start_offset + pid) - kv_end = tl.load(end_offset + pid) - token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len - - length_offset = tl.arange(0, bs_upper) - start = tl.load(start_offset + length_offset, mask=length_offset < pid) - end = tl.load(end_offset + length_offset, mask=length_offset < pid) - out_offset = tl.sum(end - start, axis=0) - - out_cache_ptr = out_cache_loc + out_offset - - save_offset = tl.arange(0, BLOCK_SIZE) + kv_start - load_offset = tl.arange(0, BLOCK_SIZE) - - num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) - for _ in range(num_loop): - mask = save_offset < kv_end - data = tl.load(out_cache_ptr + load_offset, mask=mask) - tl.store(token_pool + save_offset, data, mask=mask) - save_offset += BLOCK_SIZE - load_offset += BLOCK_SIZE +@dataclasses.dataclass +class EagleDraftInput: + # The inputs for decode + # shape: (b, topk) + topk_p: torch.Tensor = None + topk_index: torch.Tensor = None + # shape: (b, hidden_size) + hidden_states: torch.Tensor = None + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL + # Inputs for extend + # shape: (b,) + verified_id: torch.Tensor = None + accept_length: torch.Tensor = None + accept_length_cpu: List[int] = None -@triton.jit -def generate_draft_decode_kv_indices( - req_pool_indices, - req_to_token, - paged_kernel_lens, - kv_indices, - iters: tl.constexpr, - topk: tl.constexpr, - pool_len: 
tl.constexpr, - bs_upper: tl.constexpr, - iter_upper: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 128 - bid = tl.program_id(axis=0) - topk_id = tl.program_id(axis=1) - - load_offset = tl.arange(0, bs_upper) - seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid) - seq_len = tl.load(paged_kernel_lens + bid) - cum_seq_len = tl.sum(seq_lens) - - kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters) - kv_ptr = kv_indices + kv_offset - token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len - - kv_offset = tl.arange(0, BLOCK_SIZE) - num_loop = tl.cdiv(seq_len, BLOCK_SIZE) - for _ in range(num_loop): - mask = kv_offset < seq_len - data = tl.load(token_pool_ptr + kv_offset, mask=mask) - tl.store(kv_ptr + kv_offset, data, mask=mask) - kv_offset += BLOCK_SIZE - - extend_offset = tl.arange(0, iter_upper) - extend_data = tl.load( - token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id, - mask=extend_offset < iters, - ) - tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) - - -class EAGLEDraftInput(SpecInfo): - def __init__(self): - self.prev_mode = ForwardMode.DECODE - - self.scores: torch.Tensor = None - self.score_list: List[torch.Tensor] = [] - self.token_list: List[torch.Tensor] = [] - self.origin_score_list: List[torch.Tensor] = [] # used for sampling - self.parents_list: List[torch.Tensor] = [] - self.cache_list: List[torch.Tenor] = [] - self.iter = 0 - - # shape: (b, hidden_size) - self.hidden_states: torch.Tensor = None - # shape: (b,) - self.verified_id: torch.Tensor = None - # shape: (b, vocab_size) - self.sample_output: torch.Tensor = None - - self.positions: torch.Tensor = None - self.accept_length: torch.Tensor = None - self.accept_length_cpu: List[int] = None - - def load_server_args(self, server_args: ServerArgs): - self.topk: int = server_args.speculative_eagle_topk - self.num_verify_token: int = server_args.speculative_num_draft_tokens - self.spec_steps = server_args.speculative_num_steps + # Inputs for the attention backends + # shape: (b + 1,) + kv_indptr: torch.Tensor = None + kv_indices: torch.Tensor = None def prepare_for_extend(self, batch: ScheduleBatch): req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) @@ -231,91 +64,7 @@ def prepare_for_extend(self, batch: ScheduleBatch): assert len(batch.extend_lens) == 1 batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id)) - def filter_batch( - self, - new_indices: torch.Tensor, - ): - self.sample_output = self.sample_output[: len(new_indices)] - self.hidden_states = self.hidden_states[: len(new_indices)] - self.verified_id = self.verified_id[: len(new_indices)] - - def prepare_for_decode(self, batch: ScheduleBatch): - prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab) - top = torch.topk(prob, self.topk, dim=-1) - topk_index, topk_p = ( - top.indices, - top.values, - ) # shape: (b * top_k, top_k) or (b, top_k) - - if self.prev_mode.is_decode(): - scores = torch.mul( - self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk) - ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) - topk_cs = torch.topk( - scores.flatten(start_dim=1), self.topk, dim=-1 - ) # (b, topk) - topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values - - selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange( - 0, batch.batch_size() * self.topk, step=self.topk, device="cuda" - ).repeat_interleave(self.topk) - - batch.spec_info.hidden_states = batch.spec_info.hidden_states[ 
- selected_input_index, : - ] - - topk_index = topk_index.reshape(-1, self.topk**2) - batch.input_ids = torch.gather( - topk_index, index=topk_cs_index, dim=1 - ).flatten() - batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids)) - - self.scores = topk_cs_p - self.score_list.append(scores) # (b, topk, topk) - self.token_list.append(topk_index) # (b, topk * topk) - self.origin_score_list.append(topk_p.reshape(topk_index.shape)) - self.parents_list.append( - topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk) - ) # shape: (b, topk) - else: - # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND - batch.spec_info.hidden_states = ( - batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0) - ) - - batch.input_ids = topk_index.flatten() - batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel()) - - self.scores = topk_p # shape: (b, topk) - self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk) - self.token_list.append(topk_index) # shape: (b, topk) - self.origin_score_list.append(topk_p) - self.parents_list.append( - torch.arange(-1, self.topk, dtype=torch.long, device="cuda") - .unsqueeze(0) - .repeat(self.scores.shape[0], 1) - ) # shape: (b, topk + 1) - self.cache_list.append(batch.out_cache_loc) - self.positions = ( - batch.seq_lens[:, None] - + torch.full( - [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long - ) - ).flatten() - - bs = len(batch.seq_lens) - assign_req_to_token_pool[(bs,)]( - batch.req_pool_indices, - batch.req_to_token_pool.req_to_token, - batch.seq_lens + self.topk * self.iter, - batch.seq_lens + self.topk * (self.iter + 1), - batch.out_cache_loc, - batch.req_to_token_pool.req_to_token.shape[1], - triton.next_power_of_2(bs), - ) - self.iter += 1 - - def prepare_extend_after_decode(self, batch: ScheduleBatch): + def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps): batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) accept_length_cpu = batch.spec_info.accept_length_cpu batch.extend_lens = [x + 1 for x in accept_length_cpu] @@ -348,86 +97,13 @@ def prepare_extend_after_decode(self, batch: ScheduleBatch): torch.cumsum(self.accept_length, axis=0, dtype=torch.int), self.positions, new_verified_id, - triton.next_power_of_2(self.spec_steps + 1), + triton.next_power_of_2(speculative_num_steps + 1), ) batch.seq_lens_sum = sum(seq_lens_cpu) batch.input_ids = self.verified_id self.verified_id = new_verified_id - def prepare_for_verify(self, batch: ScheduleBatch): - score_list = torch.cat(self.score_list, dim=1).flatten( - 1 - ) # b, n, topk; n= 1+(self.iter-1)*self.topk - ss_token_list = torch.cat( - self.token_list, dim=1 - ) # b, (self.topk+(self.iter-1)*self.topk) - origin_token_list = torch.cat(self.origin_score_list, dim=1) - top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1) - top_scores_index = top_scores.indices - top_scores_index = torch.sort(top_scores_index).values - - draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) - scores = torch.gather(origin_token_list, index=top_scores_index, dim=1) - draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1) - parent_list = torch.cat(self.parents_list[:-1], dim=1) - - tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( - parent_list, - top_scores_index, - batch.seq_lens, - self.topk, - self.iter - 1, - self.num_verify_token, - ) - - return EagleVerifyInput( - draft_tokens.flatten(), - scores.flatten(), - tree_mask, - position, - 
retrive_index, - retrive_cum_len, - self.num_verify_token, - ) - - def generate_attn_arg_decode( - self, - req_pool_indices: torch.Tensor, - paged_kernel_lens: torch.Tensor, - req_to_token: torch.Tensor, - ): - seq_num = req_pool_indices.numel() - bs = self.topk * req_pool_indices.numel() - seq_len = self.positions.reshape(-1).contiguous() - - cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") - cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0) - total_len = torch.sum(paged_kernel_lens).item() - - kv_indices = torch.empty( - (total_len * self.topk + seq_num * self.iter * self.topk,), - dtype=torch.int32, - device="cuda", - ) - - generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)]( - req_pool_indices, - req_to_token, - paged_kernel_lens, - kv_indices, - self.iter, - self.topk, - req_to_token.shape[1], - triton.next_power_of_2(seq_num), - triton.next_power_of_2(self.spec_steps), - ) - return bs, kv_indices, cum_kv_seq_len - - def clear_draft_cache(self, batch): - draft_cache = torch.cat(self.cache_list, dim=0) - batch.token_to_kv_pool.free(draft_cache) - def generate_attn_arg_prefill( self, req_pool_indices: torch.Tensor, @@ -454,12 +130,18 @@ def generate_attn_arg_prefill( return kv_indices, cum_kv_seq_len, qo_indptr, None - def merge_batch(self, spec_info: EAGLEDraftInput): + def filter_batch(self, new_indices: torch.Tensor): + self.topk_p = self.topk_p[: len(new_indices)] + self.topk_index = self.topk_index[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] + + def merge_batch(self, spec_info: EagleDraftInput): if self.hidden_states is None: self.hidden_states = spec_info.hidden_states self.verified_id = spec_info.verified_id - self.sample_output = spec_info.sample_output - self.prev_mode = spec_info.prev_mode + self.topk_p = spec_info.topk_p + self.topk_index = spec_info.topk_index return if spec_info.hidden_states is None: return @@ -467,32 +149,68 @@ def merge_batch(self, spec_info: EAGLEDraftInput): [self.hidden_states, spec_info.hidden_states], axis=0 ) self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) - self.sample_output = torch.cat([self.sample_output, spec_info.sample_output]) - - -class EagleVerifyInput(SpecInfo): - def __init__( - self, - draft_token: torch.Tensor, - draft_score: torch.Tensor, - tree_mask: torch.Tensor, - positions: torch.Tensor, - retrive_index: torch.Tensor, - retrive_cum_len: torch.Tensor, - draft_token_num: int, + self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) + self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) + + +@dataclasses.dataclass +class EagleVerifyInput: + draft_token: torch.Tensor + custom_mask: torch.Tensor + positions: torch.Tensor + retrive_index: torch.Tensor + retrive_cum_len: torch.Tensor + draft_token_num: int + capture_hidden_mode: CaptureHiddenMode + + @classmethod + def create( + cls, + verified_id: torch.Tensor, + score_list: List[torch.Tensor], + token_list: List[torch.Tensor], + parents_list: List[torch.Tensor], + seq_lens: torch.Tensor, + seq_lens_sum: int, + topk: int, + spec_steps: int, + num_verify_token: int, ): - self.draft_token = draft_token - self.draft_score = draft_score - self.custom_mask = tree_mask - self.positions = positions - self.retrive_index = retrive_index - self.retrive_cum_len = retrive_cum_len - self.draft_token_num = draft_token_num + score_list = torch.cat(score_list, dim=1).flatten( + 1 + ) # b, n, topk; n= 1 + 
(num_steps-1) * self.topk + ss_token_list = torch.cat( + token_list, dim=1 + ) # b, (self.topk + (num_steps-1) * self.topk) + top_scores = torch.topk(score_list, num_verify_token - 1, dim=-1) + top_scores_index = top_scores.indices + top_scores_index = torch.sort(top_scores_index).values + draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) + draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1) + parent_list = torch.cat(parents_list[:-1], dim=1) + tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( + parent_list, + top_scores_index, + seq_lens, + seq_lens_sum, + topk, + spec_steps, + num_verify_token, + ) + return cls( + draft_tokens.flatten(), + tree_mask, + position, + retrive_index, + retrive_cum_len, + num_verify_token, + CaptureHiddenMode.FULL, + ) def prepare_for_verify(self, batch: ScheduleBatch): batch.input_ids = self.draft_token batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) - bs = batch.seq_lens.numel() + bs = batch.batch_size() assign_req_to_token_pool[(bs,)]( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, @@ -573,7 +291,6 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten triton.next_power_of_2(max_draft_len), ) - draft_input = EAGLEDraftInput() new_accept_index = [] unfinished_index = [] finished_extend_len = {} # {rid:accept_length + 1} @@ -625,10 +342,11 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten ) batch.seq_lens.add_(accept_length + 1) + draft_input = EagleDraftInput() if len(new_accept_index) > 0: new_accept_index = torch.tensor(new_accept_index, device="cuda") - draft_input.verified_id = predict[new_accept_index] draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index] + draft_input.verified_id = predict[new_accept_index] draft_input.accept_length = accept_length[unfinished_index] draft_input.accept_length_cpu = [ accept_length_cpu[i] for i in unfinished_index @@ -646,3 +364,269 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten finished_extend_len, accept_length_cpu, ) + + +@triton.jit +def eagle_verify_retrive( + retrive_index, + accept_mask, + retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_len: tl.constexpr, + draft_token_num: tl.constexpr, + max_len_upper: tl.constexpr, +): + pid = tl.program_id(axis=0) + + retrive_end = tl.load(retrive_cum_len + pid + 1) + retrive_start = tl.load(retrive_cum_len + pid) + retrive_len = retrive_end - retrive_start + accept_ptr = accept_mask + retrive_start + accept_offset = tl.arange(0, draft_token_num) + accept_load_mask = accept_offset < retrive_len + accept_len_list = tl.load( + accept_ptr + accept_offset, mask=accept_load_mask, other=-1 + ) + + accept_len = tl.max(accept_len_list) + max_index = tl.argmax(accept_len_list, axis=0, tie_break_left=True) + # triton is not support argmax with tie_break_right, so I need implement it by some way + mask_max = accept_len_list == accept_len + + count_mask = tl.full(shape=[draft_token_num], value=0, dtype=tl.int32) + count = tl.sum(tl.where(mask_max, 1, count_mask)) + if count > 1: + index = tl.arange(0, draft_token_num) + mask_left = index != max_index + remained_index = tl.where(mask_max and mask_left, index, 0) + max_index = tl.max(remained_index) + + tl.store(accept_length + pid, accept_len) + retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len + retrive_offset = tl.arange(0, max_len_upper) + retrive_load_mask = 
retrive_offset < accept_len + 1 + data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask) + + tl.store( + accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask + ) + + extract_load_ptr = accept_index + pid * max_len + accept_len + if accept_len == max_len - 1: + extract_data = tl.load(extract_load_ptr - 1) + tl.store(extract_index + pid * 2, extract_data) + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2 + 1, extract_data) + + else: + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2, extract_data) + + +@triton.jit +def create_extend_spec_info( + verified_id, + seq_len, + accept_len, + accept_len_cum, + positions, + new_verified_id, + accept_len_upper: tl.constexpr, +): + pid = tl.program_id(axis=0) + offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1) + seq_length = tl.load(seq_len + pid) + accept_length = tl.load(accept_len + pid) + positions_ptr = positions + offset + data = tl.arange(0, accept_len_upper) + mask = data < accept_length + tl.store(positions_ptr + data, seq_length - accept_length + data, mask) + + offset = tl.load(accept_len_cum + pid) - 1 + verified_id_data = tl.load(verified_id + offset) + tl.store(new_verified_id + pid, verified_id_data) + + +@triton.jit +def assign_req_to_token_pool( + req_pool_indices, + req_to_token, + start_offset, + end_offset, + out_cache_loc, + pool_len: tl.constexpr, + bs_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 32 + pid = tl.program_id(axis=0) + kv_start = tl.load(start_offset + pid) + kv_end = tl.load(end_offset + pid) + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + + length_offset = tl.arange(0, bs_upper) + start = tl.load(start_offset + length_offset, mask=length_offset < pid) + end = tl.load(end_offset + length_offset, mask=length_offset < pid) + out_offset = tl.sum(end - start, axis=0) + + out_cache_ptr = out_cache_loc + out_offset + + save_offset = tl.arange(0, BLOCK_SIZE) + kv_start + load_offset = tl.arange(0, BLOCK_SIZE) + + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = save_offset < kv_end + data = tl.load(out_cache_ptr + load_offset, mask=mask) + tl.store(token_pool + save_offset, data, mask=mask) + save_offset += BLOCK_SIZE + load_offset += BLOCK_SIZE + + +@triton.jit +def assign_draft_cache_locs( + req_pool_indices, + req_to_token, + seq_lens, + out_cache_loc, + pool_len: tl.constexpr, + topk: tl.constexpr, + speculative_num_steps: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 32 + pid = tl.program_id(axis=0) + kv_start = tl.load(seq_lens + pid) + kv_end = tl.load(seq_lens + pid) + topk * speculative_num_steps + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps + + num_loop = tl.cdiv(topk * speculative_num_steps, BLOCK_SIZE) + for i in range(num_loop): + save_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + kv_start + load_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = save_offset < kv_end + data = tl.load(out_cache_ptr + load_offset, mask=mask) + tl.store(token_pool + save_offset, data, mask=mask) + + +@triton.jit +def generate_draft_decode_kv_indices( + req_pool_indices, + req_to_token, + paged_kernel_lens, + kv_indices, + kv_indptr, + positions, + num_seqs: tl.constexpr, + topk: tl.constexpr, + pool_len: tl.constexpr, + kv_indices_stride: tl.constexpr, + kv_indptr_stride: tl.constexpr, + bs_upper: tl.constexpr, + iter_upper: tl.constexpr, + 
num_tokens_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 128 + iters = tl.program_id(axis=0) + bid = tl.program_id(axis=1) + topk_id = tl.program_id(axis=2) + + kv_indices += kv_indices_stride * iters + kv_indptr += kv_indptr_stride * iters + iters += 1 + + load_offset = tl.arange(0, bs_upper) + seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid) + seq_len = tl.load(paged_kernel_lens + bid) + cum_seq_len = tl.sum(seq_lens) + + kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters) + kv_ptr = kv_indices + kv_offset + token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len + + kv_offset = tl.arange(0, BLOCK_SIZE) + num_loop = tl.cdiv(seq_len, BLOCK_SIZE) + for _ in range(num_loop): + mask = kv_offset < seq_len + data = tl.load(token_pool_ptr + kv_offset, mask=mask) + tl.store(kv_ptr + kv_offset, data, mask=mask) + kv_offset += BLOCK_SIZE + + extend_offset = tl.arange(0, iter_upper) + extend_data = tl.load( + token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id, + mask=extend_offset < iters, + ) + tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) + + # Update kv_indptr + bs_offset = tl.arange(0, num_tokens_upper) + + zid = bid * topk + topk_id + if zid == 0: + zid = num_seqs * topk + positions = tl.load(positions + bs_offset, mask=bs_offset < zid) + base = tl.sum(positions) + tl.store(kv_indptr + zid, base + zid * iters) + + +@torch.compile +def select_top_k_tokens( + i: int, + topk_p: torch.Tensor, + topk_index: torch.Tensor, + hidden_states: torch.Tensor, + scores: torch.Tensor, + topk: int, +): + if i == 0: + # The first step after extend + input_ids = topk_index.flatten() + hidden_states = hidden_states.repeat_interleave(topk, dim=0) + scores = topk_p # shape: (b, topk) + + tree_info = ( + topk_p.unsqueeze(1), # shape: (b, 1, topk) + topk_index, # shape: (b, topk) + torch.arange(-1, topk, dtype=torch.long, device="cuda") + .unsqueeze(0) + .repeat(topk_p.shape[0], 1), # shape: (b, topk + 1) + ) + + else: + # The later decode steps + expand_scores = torch.mul( + scores.unsqueeze(2), topk_p.reshape(-1, topk, topk) + ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) + + topk_cs_p, topk_cs_index = fast_topk( + expand_scores.flatten(start_dim=1), topk, dim=-1 + ) # (b, topk) + scores = topk_cs_p # shape: (b, topk) + + topk_index = topk_index.reshape(-1, topk**2) + input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten() + + selected_input_index = topk_cs_index.flatten() // topk + torch.arange( + 0, hidden_states.shape[0], step=topk, device="cuda" + ).repeat_interleave(topk) + hidden_states = hidden_states[selected_input_index, :] + + tree_info = ( + expand_scores, # shape: (b, topk, topk) + topk_index, # shape: (b, topk * topk) + topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk) + ) + + return input_ids, hidden_states, scores, tree_info + + +def fast_topk(values, topk, dim): + if topk == 1: + # Use max along the specified dimension to get both value and index + max_value, max_index = torch.max(values, dim=dim) + return max_value.unsqueeze(1), max_index.unsqueeze(1) + else: + # Use topk for efficiency with larger k values + return torch.topk(values, topk, dim=dim) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 06a4372fce..b5a3de6cae 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -1,3 +1,5 @@ +import 
logging +import time from typing import List, Optional, Union import torch @@ -12,8 +14,18 @@ ) from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs -from sglang.srt.speculative.eagle_utils import EAGLEDraftInput -from sglang.srt.utils import rank0_print +from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( + EAGLEDraftCudaGraphRunner, +) +from sglang.srt.speculative.eagle_utils import ( + EagleDraftInput, + EagleVerifyInput, + assign_draft_cache_locs, + fast_topk, + select_top_k_tokens, +) + +logger = logging.getLogger(__name__) class EAGLEWorker(TpModelWorker): @@ -40,41 +52,47 @@ def __init__( is_draft_worker=True, ) self.target_worker = target_worker - self.server_args = server_args self.finish_extend_len = [] + # Parse arguments + self.topk = server_args.speculative_eagle_topk + self.speculative_num_steps = server_args.speculative_num_steps + self.server_args = server_args + # Share the embedding and lm_head embed, head = self.target_worker.model_runner.model.get_embed_and_head() self.model_runner.model.set_embed_and_head(embed, head) self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph - self.model_runner.init_cuda_graphs() - def forward_draft_decode(self, batch: ScheduleBatch): - batch.spec_info.prepare_for_decode(batch) - batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST - model_worker_batch = batch.get_model_worker_batch() - forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - logits_output = self.model_runner.forward(forward_batch) - self.capture_for_decode(logits_output, forward_batch) + # Create multi-step attn backends and cuda graph runners + from sglang.srt.layers.attention.flashinfer_backend import ( + FlashInferMultiStepDraftBackend, + ) - def forward_draft_extend(self, batch: ScheduleBatch): - self._set_mem_pool(batch, self.model_runner) - batch.spec_info.prepare_for_extend(batch) - batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST - model_worker_batch = batch.get_model_worker_batch() - forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - logits_output = self.model_runner.forward(forward_batch) - self.capture_for_decode(logits_output, forward_batch) - self._set_mem_pool(batch, self.target_worker.model_runner) + self.draft_attn_backend = FlashInferMultiStepDraftBackend( + self.model_runner, + self.topk, + self.speculative_num_steps, + ) + self.model_runner.draft_attn_backend = self.draft_attn_backend + self.init_cuda_graphs() + + def init_cuda_graphs(self): + """Capture cuda graphs.""" + self.cuda_graph_runner = None + + if self.server_args.disable_cuda_graph: + return + + tic = time.time() + logger.info("Capture cuda graph begin. This can take up to several minutes.") + self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self) + logger.info(f"Capture cuda graph end. 
Time elapsed: {time.time() - tic:.2f} s") def forward_batch_speculative_generation(self, batch: ScheduleBatch): if batch.forward_mode.is_decode(): # Draft - self._set_mem_pool(batch, self.model_runner) - for i in range(self.server_args.speculative_num_steps): - self.forward_draft_decode(batch) - batch.spec_info.clear_draft_cache(batch) - self._set_mem_pool(batch, self.target_worker.model_runner) + spec_info: EagleVerifyInput = self.draft(batch) # Verify ( @@ -84,8 +102,7 @@ def forward_batch_speculative_generation(self, batch: ScheduleBatch): self.finish_extend_len, accept_length_cpu, model_worker_batch, - ) = self.verify(batch) - next_draft_input.load_server_args(self.server_args) + ) = self.verify(batch, spec_info) batch.spec_info = next_draft_input # if it is None, means all requsets are finished if batch.spec_info.verified_id is not None: @@ -107,29 +124,145 @@ def forward_batch_speculative_generation(self, batch: ScheduleBatch): ) # Forward with the draft model. - spec_info = EAGLEDraftInput() - spec_info.load_server_args(self.server_args) - spec_info.hidden_states = logits_output.hidden_states - spec_info.verified_id = next_token_ids - batch.spec_info = spec_info + batch.spec_info = EagleDraftInput( + hidden_states=logits_output.hidden_states, + verified_id=next_token_ids, + ) self.forward_draft_extend(batch) return logits_output, next_token_ids, model_worker_batch, 0 - def verify(self, batch: ScheduleBatch): - verify_input = batch.spec_info.prepare_for_verify(batch) - verify_input.prepare_for_verify(batch) + def draft(self, batch: ScheduleBatch): + self._set_mem_pool(batch, self.model_runner) + + # Parse args + num_seqs = batch.batch_size() + spec_info = batch.spec_info + + # Allocate cache locations + out_cache_loc = batch.alloc_token_slots( + num_seqs * self.topk * self.speculative_num_steps + ) + assign_draft_cache_locs[(num_seqs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + self.topk, + self.speculative_num_steps, + ) + + batch.out_cache_loc = out_cache_loc + batch.seq_lens_sum = torch.sum(batch.seq_lens).item() + spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0) + + # Get forward batch + spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run( + forward_batch + ) + + if can_cuda_graph: + score_list, token_list, parents_list = self.cuda_graph_runner.replay( + forward_batch + ) + else: + # Initialize attention backend + self.draft_attn_backend.init_forward_metadata(forward_batch) + + # Run forward steps + score_list, token_list, parents_list = self.draft_forward(forward_batch) + + ret = EagleVerifyInput.create( + spec_info.verified_id, + score_list, + token_list, + parents_list, + batch.seq_lens, + batch.seq_lens_sum, + self.topk, + self.speculative_num_steps, + self.server_args.speculative_num_draft_tokens, + ) + + # Free cache locations + batch.token_to_kv_pool.free(out_cache_loc) + self._set_mem_pool(batch, self.target_worker.model_runner) + return ret + + def draft_forward(self, forward_batch: ForwardBatch): + # Parse args + spec_info = forward_batch.spec_info + out_cache_loc = forward_batch.out_cache_loc + topk_p, topk_index, hidden_states = ( + spec_info.topk_p, + spec_info.topk_index, + spec_info.hidden_states, + ) + + # Return values + 
score_list: List[torch.Tensor] = [] + token_list: List[torch.Tensor] = [] + parents_list: List[torch.Tensor] = [] + + # Forward multiple steps + scores = None + for i in range(self.speculative_num_steps): + input_ids, hidden_states, scores, tree_info = select_top_k_tokens( + i, topk_p, topk_index, hidden_states, scores, self.topk + ) + score_list.append(tree_info[0]) + token_list.append(tree_info[1]) + parents_list.append(tree_info[2]) + + # Set inputs + forward_batch.input_ids = input_ids + forward_batch.out_cache_loc = out_cache_loc[ + forward_batch.batch_size + * self.topk + * i : forward_batch.batch_size + * self.topk + * (i + 1) + ] + forward_batch.positions.add_(1) + forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i] + spec_info.hidden_states = hidden_states + + # Run forward + logits_output = self.model_runner.model.forward( + forward_batch.input_ids, forward_batch.positions, forward_batch + ) + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) + hidden_states = logits_output.hidden_states + + return score_list, token_list, parents_list + + def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): + spec_info.prepare_for_verify(batch) batch.forward_mode = ForwardMode.TARGET_VERIFY - batch.spec_info = verify_input - batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL + batch.spec_info = spec_info model_worker_batch = batch.get_model_worker_batch() logits_output, _ = self.target_worker.forward_batch_generation( model_worker_batch, skip_sample=True ) - verify_input.hidden_states = logits_output.hidden_states - res = verify_input.verify(batch, logits_output) + spec_info.hidden_states = logits_output.hidden_states + res = spec_info.verify(batch, logits_output) batch.forward_mode = ForwardMode.DECODE return res + (model_worker_batch,) + def forward_draft_extend(self, batch: ScheduleBatch): + self._set_mem_pool(batch, self.model_runner) + batch.spec_info.prepare_for_extend(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + logits_output = self.model_runner.forward(forward_batch) + self.capture_for_decode(logits_output, forward_batch) + self._set_mem_pool(batch, self.target_worker.model_runner) + def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): batch.token_to_kv_pool = runner.token_to_kv_pool batch.req_to_token_pool = runner.req_to_token_pool @@ -139,7 +272,7 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch): self._set_mem_pool(batch, self.model_runner) batch.forward_mode = ForwardMode.DRAFT_EXTEND - batch.spec_info.prepare_extend_after_decode(batch) + batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps) batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) @@ -155,13 +288,10 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch): def capture_for_decode( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch ): - sample_output = torch.softmax( - logits_output.next_token_logits, dim=-1 - ) # TODO(kavioyu): Support more sampling methods + probs = torch.softmax(logits_output.next_token_logits, dim=-1) spec_info = forward_batch.spec_info - spec_info.sample_output = sample_output + spec_info.topk_p, 
spec_info.topk_index = fast_topk(probs, self.topk, dim=-1) spec_info.hidden_states = logits_output.hidden_states - spec_info.prev_mode = forward_batch.forward_mode # Don't support prefix share now. def finish_request(self, reqs: Union[Req, List[Req]]): From 00fa7d0417bf8d49b332c0cfc9a2647d6a9307fe Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Mon, 3 Feb 2025 21:34:44 +0800 Subject: [PATCH 36/52] add copyright for sgl-kernel (#3270) --- sgl-kernel/setup.py | 15 +++++++++++++++ .../epilogue/epilogue_per_row_per_col_scale.h | 15 +++++++++++++++ .../gemm/gemm_universal_base_compat.h | 15 +++++++++++++++ .../gemm/gemm_with_epilogue_visitor.h | 15 +++++++++++++++ sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu | 15 +++++++++++++++ .../sgl-kernel/csrc/fused_add_rms_norm_kernel.cu | 15 +++++++++++++++ .../src/sgl-kernel/csrc/int8_gemm_kernel.cu | 15 +++++++++++++++ .../csrc/lightning_attention_decode_kernel.cu | 15 +++++++++++++++ .../src/sgl-kernel/csrc/moe_align_kernel.cu | 15 +++++++++++++++ .../src/sgl-kernel/csrc/trt_reduce_internal.cu | 15 +++++++++++++++ .../src/sgl-kernel/csrc/trt_reduce_kernel.cu | 15 +++++++++++++++ .../src/sgl-kernel/include/sgl_kernels_ops.h | 15 +++++++++++++++ .../sgl-kernel/include/trt_reduce_internal.cuh | 15 +++++++++++++++ sgl-kernel/src/sgl-kernel/include/utils.h | 15 +++++++++++++++ sgl-kernel/src/sgl-kernel/ops/utils.py | 15 +++++++++++++++ sgl-kernel/src/sgl-kernel/torch_extension.cc | 15 +++++++++++++++ 16 files changed, 240 insertions(+) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 90c3cbc1d3..9a93ae9922 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -1,3 +1,18 @@ +# Copyright 2025 SGLang Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import multiprocessing import os import sys diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h index c83cf49ad8..f5cd438156 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + // Adapted from // https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h index 33e82decc2..3de9ff078b 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Adapted from // https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h #pragma once diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h index 674e191a07..11fc872505 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Adapted from // https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu index 3e33e143c0..36b9585f34 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Adapted from // https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h // https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu index f0f3a51744..a4ae14ae59 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu index c77851c32b..4a8130d667 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include diff --git a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu index e62a154cb1..e9fc1c0ecd 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include #include diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index 19e9850b51..d51ca51759 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // Adapted from https://github.com/vllm-project/vllm/blob/v0.6.5/csrc/moe/moe_align_sum_kernels.cu #include diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu index 2ee0c98c91..fa9e3a2c5d 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // reference: // https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu /* diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu index fd0483e39e..af129de52e 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + // reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h #include diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index c5cc30c188..1fdcc9c35a 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #pragma once #include diff --git a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh index 46522348aa..f4b01230cf 100644 --- a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh +++ b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // reference: // https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp /* diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h index 55594f7b27..b714df7754 100644 --- a/sgl-kernel/src/sgl-kernel/include/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #pragma once #include diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/src/sgl-kernel/ops/utils.py index 31a6bbf991..683748da0f 100644 --- a/sgl-kernel/src/sgl-kernel/ops/utils.py +++ b/sgl-kernel/src/sgl-kernel/ops/utils.py @@ -1,3 +1,18 @@ +# Copyright 2025 SGLang Team. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, Tuple import torch diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc index 01f93199cc..aaed142a1e 100644 --- a/sgl-kernel/src/sgl-kernel/torch_extension.cc +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -1,3 +1,18 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + #include #include From d54cee1441753414ab2c8c381d39e21c4681193d Mon Sep 17 00:00:00 2001 From: kushanam <42385577+kushanam@users.noreply.github.com> Date: Mon, 3 Feb 2025 12:12:09 -0800 Subject: [PATCH 37/52] adding Triton configs for DeepSeekV3 on Blackwell (#3272) --- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++++++++++++++++++ 14 files changed, 2044 insertions(+) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 
python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..2840e9f472 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..77ba0d7477 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..0a5d7bfdba --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git 
a/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..cb91a279d4 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..7febe3d272 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..9d7658bfc4 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 
16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..03dba5ad15 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + 
"64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..9a5ff48b89 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, 
+ "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..386928de13 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..9c908e8040 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json 
b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..f78e7060e6 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..1d3ce5c94c --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..3ab5796ee1 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": 
{ + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..3cb7eaa07c --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + 
"num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + } +} From 897e2e253af5a738518ee8f044c7d894cd2a339d Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Tue, 4 Feb 2025 04:41:26 +0800 Subject: [PATCH 38/52] add Nebius for Adoption and Sponsorship (#3274) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 4b17633d81..48290ae893 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487) ## Adoption and Sponsorship -The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. +The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, Nebius, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. ## Contact Us From 4b6f62e2bc52a528551e9a21e7b0a4945c6115bb Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Tue, 4 Feb 2025 05:31:30 +0800 Subject: [PATCH 39/52] add Atlas Cloud for Adoption and Sponsorship (#3276) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 48290ae893..f524465797 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487) ## Adoption and Sponsorship -The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, Nebius, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. +The project is supported by (alphabetically): AMD, Atlas Cloud, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS CORP, Meituan, Nebius, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. 
## Contact Us From 7b5a374114c7c6095fe2dd7898ed73da534e05eb Mon Sep 17 00:00:00 2001 From: simveit <69345428+simveit@users.noreply.github.com> Date: Tue, 4 Feb 2025 00:39:41 +0100 Subject: [PATCH 40/52] Update server args doc (#3273) Co-authored-by: Shi Shuai <126407087+shuaills@users.noreply.github.com> --- docs/backend/server_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 7e8f4ca0a5..35a2a8c4e7 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -159,7 +159,7 @@ Please consult the documentation below to learn more about the parameters you ma * `disable_radix_cache`: Disable [Radix](https://lmsys.org/blog/2024-01-17-sglang/) backend for prefix caching. * `disable_jump_forward`: Disable [jump-forward](https://lmsys.org/blog/2024-02-05-compressed-fsm/#our-method-jump-forward-decoding-with-a-compressed-finite-state-machine) for outlines grammar backend. -* `disable_cuda_graph`: Disable [cuda graph](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for model forward. +* `disable_cuda_graph`: Disable [cuda graph](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for model forward. Use if encountering uncorrectable CUDA ECC errors. * `disable_cuda_graph_padding`: Disable cuda graph when padding is needed. In other case still use cuda graph. * `disable_outlines_disk_cache`: Disable disk cache for outlines grammar backend. * `disable_custom_all_reduce`: Disable usage of custom all reduce kernel. From 70817a7eae0055cb3c98c7827b73b8058fb342f4 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Mon, 3 Feb 2025 22:09:13 -0800 Subject: [PATCH 41/52] [Feature] Define backends and add Triton backend for Lora (#3161) Co-authored-by: Ying Sheng --- benchmark/lora/launch_server.py | 14 +- benchmark/lora/lora_bench.py | 5 + docs/backend/server_arguments.md | 1 + python/sglang/srt/lora/backend/__init__.py | 8 + .../sglang/srt/lora/backend/base_backend.py | 95 +++++++ .../srt/lora/backend/flashinfer_backend.py | 88 +++++++ .../sglang/srt/lora/backend/triton_backend.py | 61 +++++ python/sglang/srt/lora/lora.py | 237 ++++++++++-------- python/sglang/srt/lora/lora_manager.py | 68 +++-- python/sglang/srt/lora/triton_ops/__init__.py | 5 + .../sglang/srt/lora/triton_ops/qkv_lora_b.py | 182 ++++++++++++++ .../srt/lora/triton_ops/sgemm_lora_a.py | 143 +++++++++++ .../srt/lora/triton_ops/sgemm_lora_b.py | 159 ++++++++++++ .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/server_args.py | 11 +- python/sglang/test/runners.py | 2 + test/srt/models/test_lora_backend.py | 183 ++++++++++++++ test/srt/run_suite.py | 1 + 18 files changed, 1129 insertions(+), 135 deletions(-) create mode 100644 python/sglang/srt/lora/backend/__init__.py create mode 100644 python/sglang/srt/lora/backend/base_backend.py create mode 100644 python/sglang/srt/lora/backend/flashinfer_backend.py create mode 100644 python/sglang/srt/lora/backend/triton_backend.py create mode 100644 python/sglang/srt/lora/triton_ops/__init__.py create mode 100644 python/sglang/srt/lora/triton_ops/qkv_lora_b.py create mode 100644 python/sglang/srt/lora/triton_ops/sgemm_lora_a.py create mode 100644 python/sglang/srt/lora/triton_ops/sgemm_lora_b.py create mode 100644 test/srt/models/test_lora_backend.py diff --git a/benchmark/lora/launch_server.py b/benchmark/lora/launch_server.py index f139f0df6f..418155dbf5 100644 --- a/benchmark/lora/launch_server.py +++ 
b/benchmark/lora/launch_server.py @@ -1,10 +1,10 @@ import argparse import os -NUM_LORAS = 8 +NUM_LORAS = 4 LORA_PATH = { - "base": "mistralai/Mistral-7B-Instruct-v0.3", - "lora": "/home/ying/test_lora", + "base": "meta-llama/Llama-2-7b-hf", + "lora": "winddude/wizardLM-LlaMA-LoRA-7B", } @@ -21,7 +21,8 @@ def launch_server(args): cmd += f"{lora_name}={lora_path} " cmd += f"--disable-radix --disable-cuda-graph " cmd += f"--max-loras-per-batch {args.max_loras_per_batch} " - cmd += f"--max-running-requests {args.max_running_requests}" + cmd += f"--max-running-requests {args.max_running_requests} " + cmd += f"--lora-backend {args.lora_backend}" print(cmd) os.system(cmd) @@ -42,6 +43,11 @@ def launch_server(args): type=int, default=8, ) + parser.add_argument( + "--lora-backend", + type=str, + default="triton", + ) args = parser.parse_args() launch_server(args) diff --git a/benchmark/lora/lora_bench.py b/benchmark/lora/lora_bench.py index 713cbbf76c..b5af65a7dd 100644 --- a/benchmark/lora/lora_bench.py +++ b/benchmark/lora/lora_bench.py @@ -183,6 +183,7 @@ async def benchmark( api_url=api_url, prompt_len=test_prompt_len, output_len=test_output_len, + lora_name="dummy", # the lora_name argument will not be used extra_request_body=extra_request_body, ) test_output = await request_func(request_func_input=test_input) @@ -206,6 +207,7 @@ async def benchmark( api_url=api_url, prompt_len=prompt_len, output_len=output_len, + lora_name="dummy", extra_request_body=extra_request_body, ) tasks.append( @@ -255,6 +257,9 @@ async def benchmark( "Output token throughput (tok/s):", metrics.output_throughput ) ) + print( + "{:<40} {:<10.2f}".format("Total throughput (tok/s):", metrics.total_throughput) + ) print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) print( "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 35a2a8c4e7..d6b12b1056 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -124,6 +124,7 @@ Please consult the documentation below to learn more about the parameters you ma * `lora_paths`: You may provide a list of adapters to your model as a list. Each batch element will get model response with the corresponding lora adapter applied. Currently `cuda_graph` and `radix_attention` are not supportet with this option so you need to disable them manually. We are still working on through these [issues](https://github.com/sgl-project/sglang/issues/2929). * `max_loras_per_batch`: Maximum number of LoRAs in a running batch including base model. +* `lora_backend`: The backend of running GEMM kernels for Lora modules, can be one of `triton` or `flashinfer`. Defaults to be `triton`. 
## Kernel backend diff --git a/python/sglang/srt/lora/backend/__init__.py b/python/sglang/srt/lora/backend/__init__.py new file mode 100644 index 0000000000..ed377b4b4a --- /dev/null +++ b/python/sglang/srt/lora/backend/__init__.py @@ -0,0 +1,8 @@ +from .base_backend import BaseLoraBackend +from .flashinfer_backend import FlashInferLoraBackend +from .triton_backend import TritonLoraBackend + +__all__ = [ + "FlashInferLoraBackend", + "TritonLoraBackend", +] diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py new file mode 100644 index 0000000000..d6c72a14e7 --- /dev/null +++ b/python/sglang/srt/lora/backend/base_backend.py @@ -0,0 +1,95 @@ +from typing import Tuple, Union + +import torch + +from sglang.srt.lora.lora import LoraBatchInfo + + +def get_fuse_output_scaling_add_from_name(name: str) -> bool: + mapping = { + "triton": True, + "flashinfer": False, + } + return mapping.get(name, False) + + +def get_fuse_qkv_lora_b_from_name(name: str) -> bool: + mapping = { + "triton": True, + "flashinfer": False, + } + return mapping.get(name, False) + + +class BaseLoraBackend: + """Base class for different Lora backends. + Each backend has its own implementation of Lora kernels. + + Args: + name: name of backend + batch_info: information of current batch for use + fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward, + and the operation of scaling and adding will be fused into kernel + """ + + def __init__(self, name: str, batch_info: LoraBatchInfo = None): + self.name = name + self.batch_info = batch_info + self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name) + self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name) + + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + """Run segment Gemm of lora a modules with current backend. + The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html. + + Args: + x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths + weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank + usually input_dim is much larger than r + Returns: + result with shape (s, r) + """ + pass + + def run_lora_b_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + """Run segment Gemm of lora b modules with current backend. + The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html. + + Args: + x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank + weights: a set of lora weights with shape (num_lora, output_dim, r) + usually output_dim is much larger than r + Returns: + result with shape (s, output_dim) + """ + pass + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]], + *args, + **kwargs + ) -> torch.Tensor: + """Run the lora pass for QKV Layer. + + Args: + x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths + qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim) + qkv_lora_b: lora_b module for qkv. 
+ If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r) + If passed in as a tuple of two tensors containing: + a lora_b module for q, with shape (1, num_lora, output_dim_q, r) + and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r) + Returns: + result with shape (s, output_dim_q + 2 * output_dim_kv) + """ + pass + + def set_batch_info(self, batch_info: LoraBatchInfo): + self.batch_info = batch_info diff --git a/python/sglang/srt/lora/backend/flashinfer_backend.py b/python/sglang/srt/lora/backend/flashinfer_backend.py new file mode 100644 index 0000000000..5374a3e0a6 --- /dev/null +++ b/python/sglang/srt/lora/backend/flashinfer_backend.py @@ -0,0 +1,88 @@ +from typing import Tuple + +import torch +from flashinfer import SegmentGEMMWrapper + +from sglang.srt.lora.backend import BaseLoraBackend +from sglang.srt.lora.lora import LoraBatchInfo + + +class FlashInferLoraBackend(BaseLoraBackend): + + def __init__(self, name: str, batch_info: LoraBatchInfo = None): + super().__init__(name, batch_info) + + # Set up SGemm Wrapper from flashinfer + # FIXME wait for flashinfer segment gemm update + workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") + self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) + + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + + return self.segment_gemm.run( + x=x, + weights=weights, + batch_size=self.batch_info.bs, + weight_column_major=True, + seg_indptr=self.batch_info.seg_indptr, + weight_indices=self.batch_info.weight_indices, + ) + + def run_lora_b_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + + return self.segment_gemm.run( + x=x, + weights=weights, + batch_size=self.batch_info.bs, + weight_column_major=True, + seg_indptr=self.batch_info.seg_indptr, + weight_indices=self.batch_info.weight_indices, + ) + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: Tuple[torch.Tensor], + *args, + **kwargs, + ) -> torch.Tensor: + + # Shape of lora_a_output: (s, 3 * r) + lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a) + + q_lora_b, kv_lora_b = qkv_lora_b + lora_rank = kv_lora_b.shape[-1] + output_dim_q = q_lora_b.shape[-2] + output_dim_kv = kv_lora_b.shape[-2] + lora_output = torch.empty( + (x.shape[0], output_dim_q + 2 * output_dim_kv), + device=x.device, + dtype=x.dtype, + ) + + # q + lora_output[:, :output_dim_q] = self.run_lora_b_sgemm( + x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0] + ) + + # kv + lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = ( + self.run_lora_b_sgemm( + x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(), + weights=kv_lora_b[0], + ) + ) + + lora_output[ + :, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv + ] = self.run_lora_b_sgemm( + x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(), + weights=kv_lora_b[1], + ) + + return lora_output diff --git a/python/sglang/srt/lora/backend/triton_backend.py b/python/sglang/srt/lora/backend/triton_backend.py new file mode 100644 index 0000000000..357040bf9d --- /dev/null +++ b/python/sglang/srt/lora/backend/triton_backend.py @@ -0,0 +1,61 @@ +import torch + +from sglang.srt.lora.backend import BaseLoraBackend +from sglang.srt.lora.lora import LoraBatchInfo +from sglang.srt.lora.triton_ops import ( + qkv_lora_b_fwd, + sgemm_lora_a_fwd, + sgemm_lora_b_fwd, +) + + +class 
TritonLoraBackend(BaseLoraBackend): + + def __init__(self, name: str, batch_info: LoraBatchInfo = None): + super().__init__(name, batch_info) + + def run_lora_a_sgemm( + self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + return sgemm_lora_a_fwd(x, weights, self.batch_info) + + def run_lora_b_sgemm( + self, + x: torch.Tensor, + weights: torch.Tensor, + base_output: torch.Tensor = None, + scaling: float = 1.0, + *args, + **kwargs + ) -> torch.Tensor: + return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling) + + def run_qkv_lora( + self, + x: torch.Tensor, + qkv_lora_a: torch.Tensor, + qkv_lora_b: torch.Tensor, + output_offset: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor = None, + scaling: float = 1.0, + *args, + **kwargs + ) -> torch.Tensor: + + # x: (s, input_dim) + # qkv_lora_a: (num_lora, 3 * r, input_dim) + # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r) + assert isinstance(qkv_lora_b, torch.Tensor) + + lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info) + lora_output = qkv_lora_b_fwd( + lora_a_output, + qkv_lora_b, + self.batch_info, + output_offset, + max_qkv_out_dim, + base_output, + scaling, + ) + return lora_output diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 871c1a2291..9de3b9236b 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -18,8 +18,8 @@ # LoRA layers class inheritance adapted from: # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py - import re +from dataclasses import dataclass import torch from torch import nn @@ -34,14 +34,32 @@ from sglang.srt.model_loader.loader import DefaultModelLoader +@dataclass +class LoraBatchInfo: + # Batch size + bs: int + + # Lengths of each sequence in shape (bs,) + seg_lens: torch.Tensor + + # Indice pointers of each sequence in shape (bs + 1, ) + seg_indptr: torch.Tensor + + # Maximum sequence length of current batch + max_len: int + + # The index of lora adapter used by each sequence, in shape (bs,) + weight_indices: torch.Tensor + + class BaseLayerWithLoRA(nn.Module): - def __init__(self, base_layer, segment_gemm, lora_rank, scaling): + def __init__(self, base_layer, lora_rank, scaling, lora_backend): super().__init__() self.base_layer = base_layer - self.segment_gemm = segment_gemm self.lora_rank = lora_rank self.scaling = scaling self.set_lora = False + self.lora_backend = lora_backend def forward(self, x: torch.Tensor): return self.base_layer.forward(x) @@ -52,17 +70,17 @@ def set_lora_info(self, *args): class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__( - self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling + self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) self.weight = base_layer.weight class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__( - self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling + self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: # TODO @@ -88,136 +106,127 @@ def forward(self, input_: torch.Tensor): class 
MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): def __init__( - self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling + self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) - def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices): + def set_lora_info( + self, + A_buffer, + B_buffer, + ): self.set_lora = True self.A_buffer = A_buffer self.B_buffer = B_buffer - self.bs = bs - self.seg_indptr = seg_indptr - self.weight_indices = weight_indices def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.segment_gemm.run( - x=x, - weights=self.A_buffer, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, - ) - # FIXME + lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer) + + output_dim = base_output.shape[-1] lora_output = torch.empty_like(base_output) - output_dim = lora_output.shape[-1] // 2 - for i in range(2): - left = output_dim * i - right = left + output_dim - lora_output[:, left:right] = self.segment_gemm.run( - x=lora_a_output[ - :, self.lora_rank * i : self.lora_rank * (i + 1) - ].contiguous(), - weights=self.B_buffer[:, left:right, :].contiguous(), - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm( + x=lora_a_output[:, 0 : self.lora_rank].contiguous(), + weights=self.B_buffer[0], + ) + + lora_output[:, output_dim : 2 * output_dim] = ( + self.lora_backend.run_lora_b_sgemm( + x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(), + weights=self.B_buffer[1], ) + ) + return base_output + lora_output * self.scaling class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): - def __init__( - self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling + def init__( + self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) def set_lora_info( - self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices + self, + A_buffer_qkv, + B_buffer_q, + B_buffer_kv, ): self.set_lora = True self.A_buffer_qkv = A_buffer_qkv - self.B_buffer_q = B_buffer_q - self.B_buffer_kv = B_buffer_kv - self.bs = bs - self.seg_indptr = seg_indptr - self.weight_indices = weight_indices + + if self.lora_backend.fuse_qkv_lora_b: + assert ( + B_buffer_q.shape[-1] == B_buffer_kv.shape[-1] + ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b" + output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2] + + # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r) + self.B_buffer_qkv = torch.cat( + (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2 + ).contiguous() + + # Offsets of q/k/v in output dimension + self.output_offset = torch.tensor( + [ + 0, + output_dim_q, + output_dim_q + output_dim_kv, + output_dim_q + 2 * output_dim_kv, + ], + dtype=torch.int32, + device=B_buffer_q.device, + ) + # For computing number of launched blocks + self.max_qkv_out_dim = max(output_dim_q, output_dim_kv) + else: + self.B_buffer_qkv = ( + B_buffer_q, + B_buffer_kv, + ) + self.output_offset = 
None + self.max_qkv_out_dim = None def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_a_output = self.segment_gemm.run( - x=x, - weights=self.A_buffer_qkv, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + lora_output = self.lora_backend.run_qkv_lora( + x, + self.A_buffer_qkv, + self.B_buffer_qkv, + output_offset=self.output_offset, + max_qkv_out_dim=self.max_qkv_out_dim, + base_output=base_output, + scaling=self.scaling, ) - # FIXME parallelize qkv - lora_output = torch.empty_like(base_output) - # q - output_dim_q = self.B_buffer_q.shape[-2] - lora_output[:, :output_dim_q] = self.segment_gemm.run( - x=lora_a_output[:, : self.lora_rank].contiguous(), - weights=self.B_buffer_q, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling ) - # kv - output_dim_kv = self.B_buffer_kv.shape[-2] // 2 - for i in range(2): - left = output_dim_kv * i - right = left + output_dim_kv - lora_output[:, output_dim_q + left : output_dim_q + right] = ( - self.segment_gemm.run( - x=lora_a_output[ - :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2) - ].contiguous(), - weights=self.B_buffer_kv[:, left:right, :].contiguous(), - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, - ) - ) - return base_output + lora_output * self.scaling class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__( - self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling + self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend ) -> None: - super().__init__(base_layer, segment_gemm, lora_rank, scaling) + super().__init__(base_layer, lora_rank, scaling, lora_backend) - def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices): + def set_lora_info(self, A_buffer, B_buffer): self.set_lora = True self.A_buffer = A_buffer self.B_buffer = B_buffer - self.bs = bs - self.seg_indptr = seg_indptr - self.weight_indices = weight_indices def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - lora_output = self.segment_gemm.run( - x=x, - weights=self.A_buffer, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer) + lora_output = self.lora_backend.run_lora_b_sgemm( + lora_a_output, + self.B_buffer[0], + base_output=base_output, + scaling=self.scaling, ) - lora_output = self.segment_gemm.run( - x=lora_output, - weights=self.B_buffer, - batch_size=self.bs, - weight_column_major=True, - seg_indptr=self.seg_indptr, - weight_indices=self.weight_indices, + return ( + lora_output + if self.lora_backend.fuse_output_scaling_add + else base_output + lora_output * self.scaling ) - return base_output + lora_output * self.scaling def forward(self, input_): # duplicate the logic in RowParallelLinear @@ -255,7 +264,7 @@ def forward(self, input_): def get_lora_layer( - layer: nn.Module, segment_gemm, lora_rank, scaling + layer: nn.Module, lora_rank, scaling, lora_backend ) -> BaseLayerWithLoRA: supported_layer_types = { # the order matters @@ -267,7 +276,7 @@ def get_lora_layer( } for src_layer_type, lora_layer_type in supported_layer_types.items(): if isinstance(layer, src_layer_type): # pylint: 
disable=unidiomatic-typecheck - ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling) + ret = lora_layer_type(layer, lora_rank, scaling, lora_backend) return ret raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.") @@ -297,13 +306,14 @@ def offload_from_gpu(self): class LoRAAdapter(nn.Module): - def __init__(self, uid, config, base_hf_config, load_config): + def __init__(self, uid, config, base_hf_config, load_config, lora_backend): super().__init__() self.uid = uid self.config = config assert self.config.hf_config["peft_type"].lower() == "lora" self.base_hf_config = base_hf_config self.load_config = load_config + self.lora_backend = lora_backend self.scaling = self.config.lora_alpha / self.config.r self.layers = nn.ModuleList( @@ -376,20 +386,25 @@ def initialize_weights(self): layer.weights.pop(weight_name) layer.weights.pop(v_name) else: - layer.weights[kv_name] = torch.cat( - ( + layer.weights[kv_name] = torch.stack( + [ layer.weights[weight_name], layer.weights[v_name], - ), - 0, + ], + dim=0, ) layer.weights.pop(weight_name) layer.weights.pop(v_name) elif "gate_proj" in weight_name: up_name = weight_name.replace("gate_proj", "up_proj") gate_up_name = weight_name.replace("gate_proj", "gate_up_proj") - layer.weights[gate_up_name] = torch.cat( - (layer.weights[weight_name], layer.weights[up_name]), 0 - ) + if "lora_A" in weight_name: + layer.weights[gate_up_name] = torch.cat( + (layer.weights[weight_name], layer.weights[up_name]), 0 + ) + else: + layer.weights[gate_up_name] = torch.stack( + [layer.weights[weight_name], layer.weights[up_name]], dim=0 + ) layer.weights.pop(weight_name) layer.weights.pop(up_name) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 0449e25245..404f3f5070 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -20,16 +20,14 @@ import torch -from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer +from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend +from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.utils import is_flashinfer_available, replace_submodule logger = logging.getLogger(__name__) -if is_flashinfer_available(): - from flashinfer import SegmentGEMMWrapper - def get_module_name(name): # Fallback solution of mapping from config module name to module name in model class. @@ -77,6 +75,20 @@ def get_stacked_name(name): return params_mapping.get(name, (name, name)) +def get_backend_from_name(name): + backend_mapping = { + "triton": TritonLoraBackend, + "flashinfer": FlashInferLoraBackend, + } + + if name in backend_mapping: + return backend_mapping[name] + + raise Exception( + f"No supported lora backend called {name}. 
It should be one of {list(backend_mapping.keys())}" + ) + + def get_layer_id(name): match = re.search(r"layers\.(\d+)\.", name) if match is None: @@ -93,6 +105,7 @@ def __init__( max_loras_per_batch, load_config, dtype, + lora_backend, ): self.base_model = base_model self.lora_paths = lora_paths @@ -101,8 +114,9 @@ def __init__( self.load_config = load_config self.dtype = dtype - workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") - self.segment_gemm = SegmentGEMMWrapper(workspace_buffer) + logger.info(f"Using {lora_backend} as backend of Lora kernels.") + backend_type = get_backend_from_name(lora_backend) + self.lora_backend = backend_type(lora_backend) self.init_loras() self.init_lora_memory_pool() @@ -123,7 +137,7 @@ def get_target_modules(self): def set_lora_module(self, module_name, module): lora_module = get_lora_layer( - module, self.segment_gemm, self.max_lora_dim, self.scaling + module, self.max_lora_dim, self.scaling, self.lora_backend ) replace_submodule(self.base_model, module_name, lora_module) return lora_module @@ -162,7 +176,11 @@ def init_loras(self): self.lora_id[name] = len(self.loras) self.loras.append( LoRAAdapter( - name, self.configs[name], self.base_hf_config, self.load_config + name, + self.configs[name], + self.base_hf_config, + self.load_config, + self.lora_backend, ) ) self.loras[-1].initialize_weights() @@ -226,8 +244,9 @@ def init_lora_memory_pool(self): self.B_buffer[module_B] = [ torch.empty( ( + c, self.max_loras_per_batch, - hidden_dim_B * c, + hidden_dim_B, self.max_lora_dim, ), dtype=self.dtype, @@ -263,7 +282,16 @@ def load_lora(self, uid, buffer_id): else: lora_weight_name = self.get_weight_name(name, 1) if lora_weight_name: - self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights) + c = self.loras[-1].get_stacked_multiply(lora_weight_name) + if c > 1: + for j in range(c): + self.B_buffer[lora_weight_name][i][j][buffer_id].copy_( + weights[j] + ) + else: + self.B_buffer[lora_weight_name][i][0][buffer_id].copy_( + weights + ) def prepare_lora_batch(self, forward_batch: ForwardBatch): # load active loras into lora memory pool @@ -292,20 +320,30 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): if cur_uids == set([None]): return - # setup lora in forward modules + # set up batch info shared by all lora moruldes bs = forward_batch.batch_size seg_lens = ( forward_batch.extend_seq_lens if forward_batch.forward_mode.is_extend() else torch.ones(bs, device="cuda") ) - # FIXME: reuse the data rather than recompute seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") seg_indptr[1:] = torch.cumsum(seg_lens, dim=0) + max_len = int(torch.max(seg_lens)) weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda") for i, lora_path in enumerate(forward_batch.lora_paths): weight_indices[i] = self.buffer_id[lora_path] + batch_info = LoraBatchInfo( + bs=bs, + seg_lens=seg_lens, + seg_indptr=seg_indptr, + max_len=max_len, + weight_indices=weight_indices, + ) + self.lora_backend.set_batch_info(batch_info) + + # call set_lora_info for each lora modules for module_name, module in self.lora_modules: layer_id = get_layer_id(module_name) @@ -314,16 +352,10 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): module.set_lora_info( self.A_buffer[weight_name][layer_id], self.B_buffer[weight_name][layer_id], - bs, - seg_indptr, - weight_indices, ) else: module.set_lora_info( self.A_buffer["qkv_proj"][layer_id], self.B_buffer["q_proj"][layer_id], self.B_buffer["kv_proj"][layer_id], - bs, - seg_indptr, - 
weight_indices, ) diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py new file mode 100644 index 0000000000..efc76bb8b4 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -0,0 +1,5 @@ +from .qkv_lora_b import qkv_lora_b_fwd +from .sgemm_lora_a import sgemm_lora_a_fwd +from .sgemm_lora_b import sgemm_lora_b_fwd + +__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"] diff --git a/python/sglang/srt/lora/triton_ops/qkv_lora_b.py b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py new file mode 100644 index 0000000000..3e090f4dc3 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/qkv_lora_b.py @@ -0,0 +1,182 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _qkv_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Parameters of size + K, # K = R + max_qkv_out_dim, # max(output_q_dim, output_kv_dim) + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Offsets of q/k/v slice on output dimension + n_offs, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # For fused output scaling and adding + fuse_scaling_add, + scaling, +): + # This kernel packs 3 sgemms (q/k/v) into a single kernel. + + # x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank + # weights: (num_lora, N_Q + 2 * N_KV, K) + # output: (s, N_Q + 2 * N_KV) + # N_Q >> K, N_KV >> K + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len. 
+ # qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v) + batch_id = tl.program_id(axis=2) + qkv_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + n_start = tl.load(n_offs + qkv_id) + n_size = tl.load(n_offs + qkv_id + 1) - n_start + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id][n_start: n_end][:] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + + x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size) + if fuse_scaling_add: + partial_sum += tl.load(output_ptr, mask=output_mask) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def qkv_lora_b_fwd( + x: torch.Tensor, + qkv_lora_b: torch.Tensor, + batch_info: LoraBatchInfo, + output_offset: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor = None, + scaling: float = 1.0, +) -> torch.Tensor: + + # x: (s, 3 * r) + # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r) + # output_offset = [0, output_dim_q, output_dim_q + output_dim_kv, + # output_dim_q + 2 * output_dim_kv] + # max_qkv_out_dim = max(output_dim_q, output_dim_kv) + # output: (s, output_dim_q + 2 * output_dim_kv) + + # Compute lora_output with shape (s, output_dim) as follows: + # lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], ) + # lora_output[:, output_dim_q: output_dim_q + output_dim_kv] + # = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0]) + # lora_output[:, output_dim_q + output_dim_kv: ] + # = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1]) + + # Get dims + s = x.shape[0] + input_dim = x.shape[1] + r = qkv_lora_b.shape[-1] + output_dim = qkv_lora_b.shape[-2] + assert input_dim == 3 * r + assert output_offset.shape[0] == 4 + + BLOCK_S = 16 + BLOCK_R = 16 + BLOCK_OUT = 64 + + grid_b = ( + triton.cdiv(batch_info.max_len, BLOCK_S) + * triton.cdiv(max_qkv_out_dim, BLOCK_OUT), + 3, # this dimension decides current block computes on q, k or v + batch_info.bs, + ) + + if base_output is None: + output = torch.empty((s, output_dim), device=x.device, 
dtype=x.dtype) + fuse_scaling_add = False + else: + output = base_output + fuse_scaling_add = True + + _qkv_lora_b_kernel[grid_b]( + x, + qkv_lora_b, + output, + r, + max_qkv_out_dim, + x.stride(0), + x.stride(1), + qkv_lora_b.stride(0), + qkv_lora_b.stride(1), + qkv_lora_b.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + output_offset, + BLOCK_S, + BLOCK_OUT, + BLOCK_R, + fuse_scaling_add, + scaling, + ) + + return output diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py new file mode 100644 index 0000000000..305bb8c5f0 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py @@ -0,0 +1,143 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _sgemm_lora_a_kernel( + # Pointers to matrices + x, + weights, + output, + # Matrix dimensions + N, # r + K, # input_dim + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + + # x: (s, K), s is the sum of sequence lengths + # weights: (num_lora, N, K) + # output: (s, N) + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len + batch_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = (x + seg_start * x_stride_0) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def sgemm_lora_a_fwd( + x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo +) -> torch.Tensor: + # x: (s, input_dim) + # weights: (num_lora, r, input_dim) + # output: (s, r) + # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r + # 
input_dim is much larger than r + + assert x.is_contiguous() + assert weights.is_contiguous() + assert len(x.shape) == 2 + assert len(weights.shape) == 3 + + S = x.shape[0] + R = weights.shape[-2] + K = weights.shape[-1] + assert x.shape[-1] == K + + # Block shapes + BLOCK_S = 16 + BLOCK_K = 256 + BLOCK_R = 16 + + grid = ( + triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R), + batch_info.bs, + ) + + output = torch.empty((S, R), device=x.device, dtype=x.dtype) + _sgemm_lora_a_kernel[grid]( + x, + weights, + output, + R, + K, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + BLOCK_S, + BLOCK_R, + BLOCK_K, + ) + return output diff --git a/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py new file mode 100644 index 0000000000..c0bc913630 --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py @@ -0,0 +1,159 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.lora.lora import LoraBatchInfo + + +@triton.jit +def _sgemm_lora_b_kernel( + # Pointers to matrices + x, + weights, + output, + # Matrix dimensions + N, # output_dim + K, # r + # Strides + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + # Information on sequence lengths and weight id + seg_lens, + seg_indptr, + weight_indices, + # Meta parameters + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + # For fused output scaling and adding + fuse_scaling_add, + scaling, +): + # x: (s, K), s is the sum of sequence lengths + # weights: (num_lora, N, K) + # output: (s, N) + + # Current block computes sequence with batch_id, + # which starts from row seg_start of x with length seg_len + batch_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + w_index = tl.load(weight_indices + batch_id) + seg_start = tl.load(seg_indptr + batch_id) + + # The tile in output matrix will have (pid_s, pid_n) as id + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + + # Create pointers for the first block of x and weights[batch_id] + # The pointers will be advanced as we move in the K direction + # and accumulate + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + x_ptrs = (x + seg_start * x_stride_0) + ( + s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + ) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Iteate to compute the block in output matrix + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) + and (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K), + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + # Store result to output matrix + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = (output + seg_start * output_stride_0) + ( + s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = s_offset[:, None] 
< seg_len + if fuse_scaling_add: + partial_sum += tl.load(output_ptr, mask=output_mask) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def sgemm_lora_b_fwd( + x: torch.Tensor, + weights: torch.Tensor, + batch_info: LoraBatchInfo, + base_output: torch.Tensor = None, + scaling: float = 1.0, +) -> torch.Tensor: + # x: (s, r) + # weights: (num_lora, output_dim, r) + # output: (s, output_dim) + # output_dim is much larger than r + + assert x.is_contiguous() + assert weights.is_contiguous() + assert len(x.shape) == 2 + assert len(weights.shape) == 3 + + S = x.shape[0] + N = weights.shape[-2] + R = weights.shape[-1] + assert x.shape[-1] == R + + # Block shapes + BLOCK_S = 16 + BLOCK_R = 16 + BLOCK_N = 256 + + grid = ( + triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N), + batch_info.bs, + ) + + if base_output is None: + output = torch.empty((S, N), device=x.device, dtype=x.dtype) + fuse_scaling_add = False + else: + output = base_output + fuse_scaling_add = True + + _sgemm_lora_b_kernel[grid]( + x, + weights, + output, + N, + R, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + BLOCK_S, + BLOCK_N, + BLOCK_R, + fuse_scaling_add, + scaling, + ) + return output diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5b19c77e26..d125868b09 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -530,6 +530,7 @@ def init_lora_manager(self): max_loras_per_batch=self.server_args.max_loras_per_batch, load_config=self.load_config, dtype=self.dtype, + lora_backend=self.server_args.lora_backend, ) logger.info("LoRA manager ready.") diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8c5ad0b96e..f90495824f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -113,6 +113,7 @@ class ServerArgs: # LoRA lora_paths: Optional[List[str]] = None max_loras_per_batch: int = 8 + lora_backend: str = "triton" # Kernel backend attention_backend: Optional[str] = None @@ -653,13 +654,19 @@ def add_cli_args(parser: argparse.ArgumentParser): nargs="*", default=None, action=LoRAPathAction, - help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}", + help="The list of LoRA adapters. 
You can provide a list of either path in str or renamed path in the format {name}={path}.", ) parser.add_argument( "--max-loras-per-batch", type=int, default=8, - help="Maximum number of adapters for a running batch, include base-only request", + help="Maximum number of adapters for a running batch, include base-only request.", + ) + parser.add_argument( + "--lora-backend", + type=str, + default="triton", + help="Choose the kernel backend for multi-LoRA serving.", ) # Kernel backend diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index bae0fcf2a4..6486b2550d 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -272,6 +272,7 @@ def __init__( port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER, lora_paths: List[str] = None, max_loras_per_batch: int = 4, + lora_backend: str = "triton", disable_cuda_graph: bool = False, disable_radix_cache: bool = False, ): @@ -287,6 +288,7 @@ def __init__( is_embedding=not self.is_generation, lora_paths=lora_paths, max_loras_per_batch=max_loras_per_batch, + lora_backend=lora_backend, disable_cuda_graph=disable_cuda_graph, disable_radix_cache=disable_radix_cache, ) diff --git a/test/srt/models/test_lora_backend.py b/test/srt/models/test_lora_backend.py new file mode 100644 index 0000000000..6d61633004 --- /dev/null +++ b/test/srt/models/test_lora_backend.py @@ -0,0 +1,183 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import multiprocessing as mp +import unittest + +import torch + +from sglang.test.runners import HFRunner, SRTRunner +from sglang.test.test_utils import calculate_rouge_l + +LORA_SETS = [ + {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]}, + # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]} +] +TORCH_DTYPES = [torch.float16] + +PROMPTS = [ + "AI is a field of computer science focused on", + """ + ### Instruction: + Tell me about llamas and alpacas + ### Response: + Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. + ### Question 2: + What do you know about llamas? 
+ ### Answer: + """, +] + +BACKENDS = ["triton", "flashinfer"] + +prefill_tolerance: float = 5e-2 +decode_tolerance: float = 5e-2 +rouge_l_tolerance: float = 1 + + +class TestLoRABackend(unittest.TestCase): + + def run_backend( + self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens, backend + ): + print(f"=================== testing {backend} backend =======================") + base_path = lora_set["base"] + all_lora_paths = lora_set["loras"] + batch_lora_paths = [] + i = 0 + for _ in range(len(prompts)): + batch_lora_paths.append(all_lora_paths[i]) + i = (i + 1) % len(all_lora_paths) + print(f"batch lora paths={batch_lora_paths}") + with SRTRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + tp_size=tp_size, + lora_paths=all_lora_paths, + max_loras_per_batch=3, + lora_backend=backend, + disable_cuda_graph=True, + disable_radix_cache=True, + ) as srt_runner: + srt_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with HFRunner( + base_path, torch_dtype=torch_dtype, model_type="generation" + ) as hf_runner: + hf_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths + ) + + with SRTRunner( + base_path, + tp_size=tp_size, + torch_dtype=torch_dtype, + model_type="generation", + ) as srt_runner: + srt_no_lora_outputs = srt_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + with HFRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + ) as hf_runner: + hf_no_lora_outputs = hf_runner.forward( + prompts, max_new_tokens=max_new_tokens + ) + + for i in range(len(prompts)): + print(f"Prompt {i} with lora path {batch_lora_paths[i]}:") + + # compare input logprobs + hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) + hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i]) + srt_no_lora_logprobs = torch.Tensor( + srt_no_lora_outputs.top_input_logprobs[i] + ) + print( + "max input diff between hf_lora and srt_lora", + torch.max(abs(hf_logprobs - srt_logprobs)), + ) + print( + "max input diff between srt_base and srt_lora", + torch.max(abs(srt_no_lora_logprobs - srt_logprobs)), + ) + print( + "max input diff between srt_base and hf_base", + torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)), + ) + print( + "max input diff between hf_lora and hf_base", + torch.max(abs(hf_logprobs - hf_no_lora_logprobs)), + ) + if hf_logprobs.shape[0] <= 100: + assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), ( + f"prefill logprobs are not all close with model_path={base_path}," + f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}" + f"prefill_tolerance={prefill_tolerance}." + f"{hf_logprobs=}, {srt_logprobs=}" + ) + + # compare output logprobs + hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i]) + srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i]) + print( + "max output diff between hf_lora and srt_lora", + torch.max(abs(hf_logprobs - srt_logprobs)), + "\n", + ) + if hf_logprobs.shape[0] <= 100: + assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), ( + f"decode logprobs are not all close with model_path={base_path}," + f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}" + f"decode_tolerance={decode_tolerance}." 
+ f"{hf_logprobs=}, {srt_logprobs=}" + ) + + # compare output strings + srt_output_str = srt_outputs.output_strs[i].strip(" ") + hf_output_str = hf_outputs.output_strs[i] + print(f"srt_output_str={srt_output_str}") + print(f"hf_output_str={hf_output_str}") + rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str]) + print(f"{rouge_l_scores=}") + assert ( + rouge_l_scores[0] >= rouge_l_tolerance + ), f"ROUGE-L scores of prompt {i} outputs are greater than rouge_l_tolerance={rouge_l_tolerance}" + + def test_all(self): + for lora_set in LORA_SETS: + print(f"Testing lora set {lora_set}: ") + for torch_dtype in TORCH_DTYPES: + tp_size = 1 + max_new_tokens = 32 + for backend in BACKENDS: + self.run_backend( + PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens, backend + ) + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 603bab957b..1fbb7f92f2 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -8,6 +8,7 @@ "models/test_embedding_models.py", "models/test_generation_models.py", "models/test_lora.py", + "models/test_lora_backend.py", "models/test_qwen_models.py", "models/test_reward_models.py", "sampling/penaltylib", From d39899e85c5c29b3aeb2ea36d19f59214de60336 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Tue, 4 Feb 2025 21:41:40 +0800 Subject: [PATCH 42/52] upgrade flashinfer v0.2.0.post2 (#3288) Co-authored-by: pankajroark --- .github/workflows/pr-test.yml | 16 +++--- python/pyproject.toml | 2 +- python/sglang/srt/entrypoints/engine.py | 4 +- .../layers/attention/flashinfer_backend.py | 56 +++++++------------ python/sglang/srt/speculative/eagle_utils.py | 5 ++ python/sglang/srt/speculative/eagle_worker.py | 2 + scripts/ci_install_dependency.sh | 7 ++- test/srt/run_suite.py | 1 - 8 files changed, 42 insertions(+), 51 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 6ed6046ee6..7fd91a5e9a 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -37,7 +37,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -60,7 +60,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -84,7 +84,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -121,7 +121,7 @@ jobs: - name: Install dependencies env: - 
FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -165,7 +165,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -196,7 +196,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -234,7 +234,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -258,7 +258,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer' }} run: | bash scripts/ci_install_dependency.sh diff --git a/python/pyproject.toml b/python/pyproject.toml index cf997fc964..c600ffc0d3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -26,7 +26,7 @@ runtime_common = [ srt = [ "sglang[runtime_common]", "cuda-python", "sgl-kernel>=0.0.3.post1", "torch", "vllm==0.6.4.post1", - "flashinfer==0.1.6", "outlines>=0.0.44,<0.1.0" + "flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<0.1.0" ] # HIP (Heterogeneous-computing Interface for Portability) for AMD diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 098a3d1e32..7f01e312cd 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -316,8 +316,8 @@ def _set_envs_and_config(server_args: ServerArgs): # Check flashinfer version if server_args.attention_backend == "flashinfer": assert_pkg_version( - "flashinfer", - "0.1.6", + "flashinfer_python", + "0.2.0.post2", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.", diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 863cb031db..1f701f9464 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ 
b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -149,6 +149,7 @@ def __init__( BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, "NHD", + backend="fa2", ) ) self.prefill_wrappers_verify.append( @@ -313,7 +314,7 @@ def init_forward_metadata_capture_cuda_graph( paged_kv_indices_buf=self.cuda_graph_kv_indices[i], paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], custom_mask_buf=self.cuda_graph_custom_mask, - qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], + mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], ) ) seq_lens_sum = seq_lens.sum().item() @@ -1155,41 +1156,24 @@ def fast_decode_plan( self.last_page_len = torch.ones(32768, dtype=torch.int32) empty_q_data = self.empty_q_data empty_kv_cache = self.empty_kv_cache - if self.use_tensor_cores: - if not self.is_cuda_graph_enabled: - # when not using cudagraph, we need to create the indptr buffer, otherwise - # the buffer is already created during initialization - self._qo_indptr_buf = torch.arange( - batch_size + 1, dtype=torch.int32, device=indptr.device - ) - self._wrapper.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - self._qo_indptr_buf, - indptr, - batch_size, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - empty_q_data, - ) - else: - self._wrapper.plan( - self._float_workspace_buffer, - self._int_workspace_buffer, - indptr, - self.last_page_len, - batch_size, - num_qo_heads, - num_kv_heads, - head_dim, - page_size, - PosEncodingMode[pos_encoding_mode].value, - logits_soft_cap, - empty_q_data, - empty_kv_cache, - ) + stream = torch.cuda.current_stream() + self._cached_module.plan( + self._float_workspace_buffer, + self._int_workspace_buffer, + self._pin_memory_int_workspace_buffer, + indptr.to("cpu"), + batch_size, + num_qo_heads, + num_kv_heads, + page_size, + self.is_cuda_graph_enabled, + window_left, + logits_soft_cap, + head_dim, + empty_q_data, + empty_kv_cache, + stream.cuda_stream, + ) self._pos_encoding_mode = pos_encoding_mode self._window_left = window_left self._logits_soft_cap = logits_soft_cap diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 0b8c99f041..4abcba9550 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -69,6 +69,7 @@ def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_step accept_length_cpu = batch.spec_info.accept_length_cpu batch.extend_lens = [x + 1 for x in accept_length_cpu] batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend + batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend seq_lens_cpu = batch.seq_lens.tolist() pt = 0 @@ -353,8 +354,12 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten ] if has_finished: draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] + draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[ + unfinished_index + ] else: draft_input.seq_lens_for_draft_extend = batch.seq_lens + draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices logits_output.next_token_logits = logits_output.next_token_logits[accept_index] return ( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index b5a3de6cae..6d84cc3051 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -269,6 +269,7 @@ def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): def 
forward_draft_extend_after_decode(self, batch: ScheduleBatch): seq_lens_backup = batch.seq_lens + req_pool_indices_backup = batch.req_pool_indices self._set_mem_pool(batch, self.model_runner) batch.forward_mode = ForwardMode.DRAFT_EXTEND @@ -284,6 +285,7 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch): # This is because `seq_lens` can be modified in `prepare_extend_after_decode` batch.forward_mode = ForwardMode.DECODE batch.seq_lens = seq_lens_backup + batch.req_pool_indices = req_pool_indices_backup def capture_for_decode( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index 1a059d5ff6..ffe405d5aa 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -4,16 +4,17 @@ set -euxo pipefail # Install the dependency in CI. # Use repo from environment variable, passed from GitHub Actions -FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.4/flashinfer}" +FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.5/flashinfer}" SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" bash "${SCRIPT_DIR}/killall_sglang.sh" pip install --upgrade pip -pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ +pip uninstall flashinfer -y +pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ # Force reinstall flashinfer and torch_memory_saver -pip install flashinfer==0.1.6 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps +pip install flashinfer_python==0.2.0.post2 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps pip install torch_memory_saver --force-reinstall pip install transformers==4.45.2 sentence_transformers accelerate peft diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 1fbb7f92f2..039fde96a7 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -52,7 +52,6 @@ "test_vision_llm.py", "test_vision_openai_server.py", "test_w8a8_quantization.py", - "test_fp8_kvcache.py", "test_fp8_kernel.py", ], "nightly": [ From 2c1a695ff111cadc200cf97d4c2cbfe95ebecb70 Mon Sep 17 00:00:00 2001 From: HAI Date: Tue, 4 Feb 2025 05:44:44 -0800 Subject: [PATCH 43/52] ROCm: sgl-kernel enablement starting with sgl_moe_align_block (#3287) --- docker/Dockerfile.rocm | 3 + docs/start/install.md | 4 +- python/pyproject.toml | 2 +- .../layers/moe/fused_moe_triton/fused_moe.py | 14 +-- sgl-kernel/setup_rocm.py | 92 +++++++++++++++++++ .../src/sgl-kernel/torch_extension_rocm.cc | 29 ++++++ 6 files changed, 131 insertions(+), 13 deletions(-) create mode 100644 sgl-kernel/setup_rocm.py create mode 100644 sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index caa4666c88..480f80854b 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -28,6 +28,9 @@ RUN git clone ${SGL_REPO} \ echo "Using ${SGL_BRANCH} branch."; \ git checkout ${SGL_BRANCH}; \ fi \ + && cd sgl-kernel \ + && python setup_rocm.py install \ + && cd .. 
\ && if [ "$BUILD_TYPE" = "srt" ]; then \ python -m pip --no-cache-dir install -e "python[srt_hip]"; \ else \ diff --git a/docs/start/install.md b/docs/start/install.md index b9702f0215..fc1a936c68 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -32,7 +32,9 @@ git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip -pip install sgl-kernel --force-reinstall --no-deps +cd sgl-kernel +python setup_rocm.py install +cd .. pip install -e "python[all_hip]" ``` diff --git a/python/pyproject.toml b/python/pyproject.toml index c600ffc0d3..f87a2702b3 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -31,7 +31,7 @@ srt = [ # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20241022, not from public vllm whl -srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"] +srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11", "sgl-kernel>=0.0.3.post1"] # xpu is not enabled in public vllm and torch whl, # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<0.1.0"] diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 32c8fcbb62..fab71809b1 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -15,18 +15,10 @@ from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from sglang.srt.utils import ( - direct_register_custom_op, - get_device_name, - is_cuda_available, - is_hip, -) +from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip -is_cuda = is_cuda_available() is_hip_flag = is_hip() -if is_cuda: - from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size - +from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size logger = logging.getLogger(__name__) padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 @@ -415,7 +407,7 @@ def moe_align_block_size( ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) if num_experts >= 224: - if enable_moe_align_block_size_triton or is_hip_flag: + if enable_moe_align_block_size_triton: moe_align_block_size_triton( topk_ids, num_experts, diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py new file mode 100644 index 0000000000..6530cd7c74 --- /dev/null +++ b/sgl-kernel/setup_rocm.py @@ -0,0 +1,92 @@ +# Copyright 2025 SGLang Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import multiprocessing +import os +import sys +from pathlib import Path + +import torch +from setuptools import find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +root = Path(__file__).parent.resolve() + +if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: + sys.argv.extend(["--plat-name", "manylinux2014_x86_64"]) + + +def _get_version(): + with open(root / "pyproject.toml") as f: + for line in f: + if line.startswith("version"): + return line.split("=")[1].strip().strip('"') + + +operator_namespace = "sgl_kernels" +include_dirs = [ + root / "src" / "sgl-kernel" / "include", + root / "src" / "sgl-kernel" / "csrc", +] + +sources = [ + "src/sgl-kernel/torch_extension_rocm.cc", + "src/sgl-kernel/csrc/moe_align_kernel.cu", +] + +cxx_flags = ["-O3"] +libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"] +extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] + +hipcc_flags = [ + "-DNDEBUG", + f"-DOPERATOR_NAMESPACE={operator_namespace}", + "-O3", + "-Xcompiler", + "-fPIC", + "-std=c++17", + "-D__HIP_PLATFORM_AMD__=1", + "--amdgpu-target=gfx942", + "-DENABLE_BF16", + "-DENABLE_FP8", +] + +setup( + name="sgl-kernel", + version=_get_version(), + packages=find_packages(), + package_dir={"": "src"}, + ext_modules=[ + CUDAExtension( + name="sgl_kernel.ops._kernels", + sources=sources, + include_dirs=include_dirs, + extra_compile_args={ + "nvcc": hipcc_flags, + "cxx": cxx_flags, + }, + libraries=libraries, + extra_link_args=extra_link_args, + py_limited_api=True, + ), + ], + cmdclass={ + "build_ext": BuildExtension.with_options( + use_ninja=True, max_jobs=multiprocessing.cpu_count() + ) + }, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, + install_requires=["torch"], +) diff --git a/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc b/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc new file mode 100644 index 0000000000..22f40da109 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc @@ -0,0 +1,29 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "sgl_kernels_ops.h" + +TORCH_LIBRARY_EXPAND(sgl_kernels, m) { + // moe_align_block_size + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " + "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! 
cumsum_buffer) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); +} + +REGISTER_EXTENSION(_kernels) From a07364ccc5008d09c7b58c5ca057bf321358f2f3 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Tue, 4 Feb 2025 23:26:04 +0800 Subject: [PATCH 44/52] Update Triton decode backend interface (#3292) --- .../srt/layers/attention/triton_backend.py | 78 ++++++++++++-- .../attention/triton_ops/decode_attention.py | 101 ++++++++---------- test/srt/test_triton_attention_kernels.py | 27 ++--- 3 files changed, 129 insertions(+), 77 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index fade8ed292..c0f3bdb832 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -5,6 +5,9 @@ import torch from sglang.srt.layers.attention import AttentionBackend +from sglang.srt.layers.attention.flashinfer_backend import ( + create_flashinfer_kv_indices_triton, +) from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -29,6 +32,12 @@ def __init__(self, model_runner: ModelRunner): self.decode_attention_fwd = decode_attention_fwd self.extend_attention_fwd = extend_attention_fwd + max_bs = model_runner.req_to_token_pool.size + self.kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.num_head = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() ) @@ -58,11 +67,32 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) max_extend_len = None + + kv_indptr = self.kv_indptr + bs = len(forward_batch.req_pool_indices) + kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda" + ) + create_flashinfer_kv_indices_triton[(bs,)]( + forward_batch.req_to_token_pool.req_to_token, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + kv_indptr, + None, + kv_indices, + forward_batch.req_to_token_pool.req_to_token.stride(0), + ) + else: attn_logits = None max_extend_len = torch.max(forward_batch.extend_seq_lens).item() - self.forward_metadata = attn_logits, max_extend_len + kv_indptr = None + kv_indices = None + + self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len @@ -73,7 +103,12 @@ def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_attn_logits = torch.empty( (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1), dtype=torch.float32, - device="cuda", + device=self.device, + ) + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.cuda_graph_max_seq_len), + dtype=torch.int32, + device=self.device, ) def init_forward_metadata_capture_cuda_graph( @@ -90,9 +125,25 @@ def init_forward_metadata_capture_cuda_graph( assert forward_mode.is_decode(), "Not supported" assert spec_info is None, "Not supported" + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + seq_lens, + kv_indptr, + None, + kv_indices, + 
self.req_to_token.stride(0), + ) + self.forward_metadata = ( self.cuda_graph_attn_logits, None, + kv_indptr, + kv_indices, ) def init_forward_metadata_replay_cuda_graph( @@ -109,6 +160,20 @@ def init_forward_metadata_replay_cuda_graph( self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) + kv_indptr = self.kv_indptr + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = self.cuda_graph_kv_indices + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices[:bs], + seq_lens[:bs], + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + def get_cuda_graph_seq_len_fill_value(self): return 1 @@ -132,7 +197,7 @@ def forward_extend( layer, forward_batch.out_cache_loc, k, v ) - _, max_extend_len = self.forward_metadata + _, max_extend_len, _, _ = self.forward_metadata self.extend_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k.contiguous(), @@ -170,7 +235,7 @@ def forward_decode( else: o = torch.empty_like(q) - attn_logits, _ = self.forward_metadata + attn_logits, _, kv_indptr, kv_indices = self.forward_metadata if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( @@ -182,9 +247,8 @@ def forward_decode( forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), o.view(-1, layer.tp_q_head_num, layer.v_head_dim), - forward_batch.req_to_token_pool.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, + kv_indptr, + kv_indices, attn_logits, self.num_kv_splits, layer.scaling, diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 512900bd30..f2274322c5 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -49,11 +49,9 @@ def _fwd_kernel_stage1( K_Buffer, V_Buffer, sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, Att_Out, - stride_req_to_tokens_b, stride_qbs, stride_qh, stride_buf_kbs, @@ -82,8 +80,9 @@ def _fwd_kernel_stage1( offs_dv = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lk mask_dv = offs_dv < Lv - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d q = tl.load(Q + off_q, mask=mask_d, other=0.0) @@ -100,7 +99,7 @@ def _fwd_kernel_stage1( for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + kv_indices + cur_batch_kv_start_idx + offs_n, mask=offs_n < split_kv_end, other=0, ) @@ -173,9 +172,8 @@ def _decode_att_m_fwd( k_buffer, v_buffer, att_out, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, num_kv_splits, sm_scale, logit_cap, @@ -188,7 +186,7 @@ def _decode_att_m_fwd( Lk = k_buffer.shape[-1] Lv = v_buffer.shape[-1] - batch, head_num = B_req_idx.shape[0], q.shape[1] + batch, head_num = kv_indptr.shape[0] - 1, q.shape[1] grid = (batch, head_num, NUM_KV_SPLITS) kv_group_num = q.shape[1] // k_buffer.shape[1] @@ -208,11 +206,9 @@ def _decode_att_m_fwd( k_buffer, v_buffer, sm_scale, - Req_to_tokens, - 
B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, att_out, - Req_to_tokens.stride(0), q.stride(0), q.stride(1), k_buffer.stride(0), @@ -241,11 +237,9 @@ def _fwd_grouped_kernel_stage1( K_Buffer, V_Buffer, sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, Att_Out, - stride_req_to_tokens_b, stride_qbs, stride_qh, stride_buf_kbs, @@ -284,8 +278,9 @@ def _fwd_grouped_kernel_stage1( offs_dv = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lk mask_dv = offs_dv < Lv - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) @@ -312,7 +307,7 @@ def _fwd_grouped_kernel_stage1( for start_n in range(split_kv_start, split_kv_end, BLOCK_N): offs_n = start_n + tl.arange(0, BLOCK_N) kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + kv_indices + cur_batch_kv_start_idx + offs_n, mask=offs_n < split_kv_end, other=0, ) @@ -400,9 +395,8 @@ def _decode_grouped_att_m_fwd( k_buffer, v_buffer, att_out, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, num_kv_splits, sm_scale, logit_cap, @@ -426,7 +420,7 @@ def _decode_grouped_att_m_fwd( BLOCK_DPE = 0 BLOCK_DV = triton.next_power_of_2(Lv) - batch, head_num = B_req_idx.shape[0], q.shape[1] + batch, head_num = kv_indptr.shape[0] - 1, q.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[1] BLOCK_H = 16 @@ -450,11 +444,9 @@ def _decode_grouped_att_m_fwd( k_buffer, v_buffer, sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, + kv_indptr, + kv_indices, att_out, - Req_to_tokens.stride(0), q.stride(0), q.stride(1), k_buffer.stride(0), @@ -485,7 +477,7 @@ def _decode_grouped_att_m_fwd( def _fwd_kernel_stage2( Mid_O, O, - B_Seqlen, + kv_indptr, stride_mid_ob, stride_mid_oh, stride_mid_os, @@ -498,7 +490,9 @@ def _fwd_kernel_stage2( cur_batch = tl.program_id(0) cur_head = tl.program_id(1) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load( + kv_indptr + cur_batch + ) offs_d = tl.arange(0, BLOCK_DV) mask_d = offs_d < Lv @@ -542,7 +536,7 @@ def _decode_softmax_reducev_fwd( q, o, v_buffer, - b_seq_len, + kv_indptr, num_kv_splits, ): batch, head_num = q.shape[0], q.shape[1] @@ -561,7 +555,7 @@ def _decode_softmax_reducev_fwd( _fwd_kernel_stage2[grid]( logits, o, - b_seq_len, + kv_indptr, logits.stride(0), logits.stride(1), logits.stride(2), @@ -581,9 +575,8 @@ def decode_attention_fwd_normal( k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -594,14 +587,13 @@ def decode_attention_fwd_normal( k_buffer, v_buffer, attn_logits, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, num_kv_splits, sm_scale, logit_cap, ) - _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits) def decode_attention_fwd_grouped( @@ -609,9 +601,8 @@ def decode_attention_fwd_grouped( k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -622,14 +613,13 @@ def decode_attention_fwd_grouped( k_buffer, v_buffer, attn_logits, - 
req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, num_kv_splits, sm_scale, logit_cap, ) - _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits) def decode_attention_fwd( @@ -637,9 +627,8 @@ def decode_attention_fwd( k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -655,9 +644,8 @@ def decode_attention_fwd( k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -670,9 +658,8 @@ def decode_attention_fwd( k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 2398af9b0a..52a20771b3 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -194,10 +194,12 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): # o will have the same shape as q o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") - req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) - b_req_idx = torch.arange(B, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0) + kv_indices = torch.arange(total_tokens, device="cuda") + attn_logits = torch.empty( (B, H_Q, num_kv_splits, D + 1), dtype=torch.float32, @@ -209,9 +211,8 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -250,10 +251,12 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") - req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) - b_req_idx = torch.arange(B, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0) + kv_indices = torch.arange(total_tokens, device="cuda") + attn_logits = torch.empty( (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, @@ -265,9 +268,8 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): k_buffer, v_buffer, o, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits, num_kv_splits, sm_scale, @@ -284,9 +286,8 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): k_buffer, v_buffer, o_grouped, - req_to_token, - b_req_idx, - b_seq_len, + kv_indptr, + kv_indices, attn_logits1, num_kv_splits, sm_scale, From 6186a8f8897eb1fd2a38d67405a9d698663e9307 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 5 Feb 2025 00:44:35 +0800 Subject: [PATCH 45/52] update flashinfer install index url (#3293) --- docker/Dockerfile | 16 ++++++++-------- docs/start/install.md | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index cec05825d0..264397f851 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -41,26 +41,26 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ && cd 
sglang \ && if [ "$BUILD_TYPE" = "srt" ]; then \ if [ "$CUDA_VERSION" = "12.1.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ - python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer/; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ else \ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ fi; \ else \ if [ "$CUDA_VERSION" = "12.1.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ - python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer/; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ else \ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ diff --git a/docs/start/install.md b/docs/start/install.md index fc1a936c68..38bce59f0e 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -6,7 +6,7 @@ You can install SGLang using any of the methods below. ``` pip install --upgrade pip pip install sgl-kernel --force-reinstall --no-deps -pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ +pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ ``` Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. 
@@ -19,7 +19,7 @@ cd sglang pip install --upgrade pip pip install sgl-kernel --force-reinstall --no-deps -pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ +pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ ``` Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. If you meet with issue like **ImportError: cannot import name `_grouped_size_compiled_for_decode_kernels`**, installing FlashInfer with some older version like 0.1.6 instead of the latest version could solve it. From c7256ca836cf56c55c845ad1ceb38426e64f88b2 Mon Sep 17 00:00:00 2001 From: "Wen-Heng (Jack) Chung" Date: Tue, 4 Feb 2025 12:34:57 -0600 Subject: [PATCH 46/52] [ROCm] Add tuning configs for AMD Radeon Graphics. (#3294) --- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 ++++++++++++++ ...14336,device_name=AMD_Radeon_Graphics.json | 200 ++++++++++++++++++ ...=1792,device_name=AMD_Radeon_Graphics.json | 200 ++++++++++++++++++ ...=3584,device_name=AMD_Radeon_Graphics.json | 200 ++++++++++++++++++ ...me=AMD_Radeon_Graphics,dtype=fp8_w8a8.json | 178 ++++++++++++++++ ...=7168,device_name=AMD_Radeon_Graphics.json | 200 ++++++++++++++++++ ...me=AMD_Radeon_Graphics,dtype=fp8_w8a8.json | 175 +++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 ++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 ++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 ++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 ++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 ++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 ++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 ++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 ++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 ++++++++++++++ 16 files changed, 2793 insertions(+) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 
python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..a7be90051f --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 0000000000..6a976788f9 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, 
+ "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 0000000000..0a46390b2e --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, 
+ "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 0000000000..91011e64c7 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 
16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json new file mode 100644 index 0000000000..bb17743b60 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json @@ -0,0 +1,178 @@ +{ + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16384": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32768": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "65536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "131072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 0000000000..f807d4a5ab --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json new file mode 100644 index 0000000000..92c41a28be --- /dev/null +++ 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json @@ -0,0 +1,175 @@ +{ + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16384": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32768": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "65536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "131072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json 
b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..c098ef2dbb --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file 
mode 100644 index 0000000000..6f5adbb936 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..4225c78eb7 --- /dev/null +++ 
b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..5e6789d00e --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 
+1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..49ac14d2a5 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + 
"waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..dcbb0efc53 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..dfe5c1e43d --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..a87f5de1b1 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 0000000000..468f9e78da --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} From c2723a42a58f0272e41b5e81bb1fadd11ed42f99 Mon Sep 17 00:00:00 2001 From: "Wen-Heng (Jack) Chung" Date: Tue, 4 Feb 2025 17:15:40 -0600 Subject: [PATCH 47/52] [ROCm] Manually unroll _w8a8_block_fp8_matmul kernel on AMD GPU. 
(#3299) --- .../srt/layers/quantization/fp8_kernel.py | 133 +++++++++++++++++- 1 file changed, 132 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index fe57838e59..8443f8dd63 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -220,6 +220,132 @@ def _w8a8_block_fp8_matmul( tl.store(c_ptrs, c, mask=c_mask) +@triton.jit +def _w8a8_block_fp8_matmul_unrolledx4( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and store the result in output + tensor `C`. + """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # manually unroll to 4 iterations + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K) // 4): + # 1st iteration + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # 2nd iteration + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k_start + BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # 3rd iteration + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k_start + BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + 
a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # 4th iteration + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k_start + BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + @functools.lru_cache def get_w8a8_block_fp8_configs( N: int, K: int, block_n: int, block_k: int @@ -324,7 +450,12 @@ def grid(META): triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - _w8a8_block_fp8_matmul[grid]( + # Use manually unrolledx4 kernel on AMD GPU. + kernel = ( + _w8a8_block_fp8_matmul_unrolledx4 if is_hip_ == True else _w8a8_block_fp8_matmul + ) + + kernel[grid]( A, B, C, From 4885b908021768a5220034290edd412c349c925f Mon Sep 17 00:00:00 2001 From: kk <43161300+kkHuang-amd@users.noreply.github.com> Date: Wed, 5 Feb 2025 10:58:17 +0800 Subject: [PATCH 48/52] Use forward_cuda to execute custom op for hip platform (#3305) Co-authored-by: wunhuang --- python/sglang/srt/custom_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py index c35790691e..d770e9c085 100644 --- a/python/sglang/srt/custom_op.py +++ b/python/sglang/srt/custom_op.py @@ -20,7 +20,7 @@ def forward_cuda(self, *args, **kwargs): raise NotImplementedError def forward_hip(self, *args, **kwargs): - return self.forward_native(*args, **kwargs) + return self.forward_cuda(*args, **kwargs) def forward_xpu(self, *args, **kwargs): return self.forward_native(*args, **kwargs) From 7ab84948d87d2c264cccc4ae8c1db339b9efea6a Mon Sep 17 00:00:00 2001 From: "Wen-Heng (Jack) Chung" Date: Tue, 4 Feb 2025 21:12:20 -0600 Subject: [PATCH 49/52] [ROCm] Logic to decide whether to used manually unrolled kernel. (#3306) --- python/sglang/srt/layers/quantization/fp8_kernel.py | 13 ++++++++++--- python/sglang/srt/utils.py | 7 +++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 8443f8dd63..ddd614fdfd 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -22,7 +22,7 @@ import triton import triton.language as tl -from sglang.srt.utils import get_device_name, is_hip +from sglang.srt.utils import get_device_core_count, get_device_name, is_hip is_hip_ = is_hip() fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn @@ -450,9 +450,16 @@ def grid(META): triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) - # Use manually unrolledx4 kernel on AMD GPU. + # Use manually unrolledx4 kernel on AMD GPU when the grid size is small. 
+ # Empirical testing shows the sweet spot lies when it's less than the # of + # compute units available on the device. + num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv( + N, config["BLOCK_SIZE_N"] + ) kernel = ( - _w8a8_block_fp8_matmul_unrolledx4 if is_hip_ == True else _w8a8_block_fp8_matmul + _w8a8_block_fp8_matmul_unrolledx4 + if (is_hip_ == True and num_workgroups <= get_device_core_count()) + else _w8a8_block_fp8_matmul ) kernel[grid]( diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index ebb346bbc6..b1c49f5273 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1046,6 +1046,13 @@ def get_device_name(device_id: int = 0) -> str: return torch.hpu.get_device_name(device_id) +def get_device_core_count(device_id: int = 0) -> int: + if hasattr(torch, "cuda") and torch.cuda.is_available(): + return torch.cuda.get_device_properties(device_id).multi_processor_count + + return 0 + + def get_device_capability(device_id: int = 0) -> Tuple[int, int]: major, minor = None, None if hasattr(torch, "cuda") and torch.cuda.is_available(): From 76fa2d152c112688b9c9450cb1fcb10a0f85207b Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 5 Feb 2025 00:36:49 -0800 Subject: [PATCH 50/52] Fix lora flashinfer import bug on ROCM (#3312) --- python/sglang/srt/lora/backend/flashinfer_backend.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/lora/backend/flashinfer_backend.py b/python/sglang/srt/lora/backend/flashinfer_backend.py index 5374a3e0a6..91c15be3c0 100644 --- a/python/sglang/srt/lora/backend/flashinfer_backend.py +++ b/python/sglang/srt/lora/backend/flashinfer_backend.py @@ -1,10 +1,13 @@ from typing import Tuple import torch -from flashinfer import SegmentGEMMWrapper from sglang.srt.lora.backend import BaseLoraBackend from sglang.srt.lora.lora import LoraBatchInfo +from sglang.srt.utils import is_flashinfer_available + +if is_flashinfer_available(): + from flashinfer import SegmentGEMMWrapper class FlashInferLoraBackend(BaseLoraBackend): From 7aad8d1854010de2cdfd057524d44fe538b160e1 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 5 Feb 2025 17:35:02 +0800 Subject: [PATCH 51/52] chore: bump v0.4.2.post2 (#3313) --- docker/Dockerfile.rocm | 2 +- docs/developer/setup_github_runner.md | 4 ++-- docs/start/install.md | 12 ++++++------ python/pyproject.toml | 2 +- python/sglang/version.py | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 480f80854b..01bc0137c2 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,5 +1,5 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm630 -f Dockerfile.rocm . +# docker build --build-arg SGL_BRANCH=v0.4.2.post2 -t v0.4.2.post2-rocm630 -f Dockerfile.rocm . 
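For reference, the kernel-selection rule that patches 47 and 49 add to `fp8_kernel.py` can be restated as a short host-side sketch. The names `cdiv`, `device_core_count`, `choose_kernel`, and the returned string labels are illustrative stand-ins, not identifiers from the tree; the actual code dispatches between the two Triton kernels shown in the diffs above.

```python
import torch


def cdiv(a: int, b: int) -> int:
    # Ceiling division; plays the role of triton.cdiv in the patch.
    return (a + b - 1) // b


def device_core_count(device_id: int = 0) -> int:
    # Mirrors the helper added to sglang.srt.utils: CU/SM count, or 0 without a GPU.
    if torch.cuda.is_available():
        return torch.cuda.get_device_properties(device_id).multi_processor_count
    return 0


def choose_kernel(M: int, N: int, config: dict, is_hip: bool) -> str:
    # One workgroup per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
    num_workgroups = cdiv(M, config["BLOCK_SIZE_M"]) * cdiv(N, config["BLOCK_SIZE_N"])
    if is_hip and num_workgroups <= device_core_count():
        return "unrolled_x4"  # stands in for _w8a8_block_fp8_matmul_unrolledx4
    return "default"  # stands in for _w8a8_block_fp8_matmul
```

Per the comment in the diff, the threshold is purely empirical: the unrolled kernel wins only while the grid is small enough that every workgroup can map onto its own compute unit.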
# default base image ARG BASE_IMAGE="rocm/vllm-dev:20250114" diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index cde8c0aa90..16b442554e 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm630 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post2-rocm630 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm630 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post2-rocm630 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/start/install.md b/docs/start/install.md index 38bce59f0e..19b4ab56ac 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -6,7 +6,7 @@ You can install SGLang using any of the methods below. ``` pip install --upgrade pip pip install sgl-kernel --force-reinstall --no-deps -pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ +pip install "sglang[all]>=0.4.2.post2" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/ ``` Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. @@ -14,7 +14,7 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/ ## Method 2: From source ``` # Use the last release branch -git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2.post2 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -28,7 +28,7 @@ Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: ``` # Use the last release branch -git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2.post2 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -56,7 +56,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm630 -f Dockerfile.rocm . +docker build --build-arg SGL_BRANCH=v0.4.2.post2 -t v0.4.2.post2-rocm630 -f Dockerfile.rocm . 
alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -65,11 +65,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.4.2.post1-rocm630 \ + v0.4.2.post2-rocm630 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.4.2.post1-rocm630 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.2.post2-rocm630 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/python/pyproject.toml b/python/pyproject.toml index f87a2702b3..d71cf0153e 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.4.2.post1" +version = "0.4.2.post2" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" diff --git a/python/sglang/version.py b/python/sglang/version.py index d1b3e6d0ae..615d4c40d1 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.4.2.post1" +__version__ = "0.4.2.post2" From de5533341ee3c1b7667b1eb1f209b6825335d136 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Wed, 5 Feb 2025 18:12:22 +0800 Subject: [PATCH 52/52] Update Triton extend backend interface (#3309) --- .../attention/double_sparsity_backend.py | 4 +- .../srt/layers/attention/triton_backend.py | 68 +++- .../triton_ops/double_sparsity_attention.py | 340 +++++++++++++++++- .../attention/triton_ops/extend_attention.py | 57 ++- test/srt/test_triton_attention_kernels.py | 27 +- 5 files changed, 427 insertions(+), 69 deletions(-) diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index a5e54f32d5..c807e8753f 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): # Lazy import to avoid the initialization of cuda context from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import ( + extend_attention_fwd, flash_decode_attention_fwd, flash_decode_sparse_attention_fwd, ) - from sglang.srt.layers.attention.triton_ops.extend_attention import ( - extend_attention_fwd, - ) super().__init__() diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index c0f3bdb832..3475df7219 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -37,6 +37,9 @@ def __init__(self, model_runner: ModelRunner): (max_bs + 1,), dtype=torch.int32, device=model_runner.device ) self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.qo_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, 
device=model_runner.device
+        )
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -54,6 +57,9 @@ def __init__(self, model_runner: ModelRunner):

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+
         if forward_batch.forward_mode.is_decode():
             attn_logits = torch.empty(
                 (
@@ -68,31 +74,59 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):

             max_extend_len = None

-            kv_indptr = self.kv_indptr
-            bs = len(forward_batch.req_pool_indices)
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
-                forward_batch.req_to_token_pool.req_to_token,
+                self.req_to_token,
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
                 kv_indptr,
                 None,
                 kv_indices,
-                forward_batch.req_to_token_pool.req_to_token.stride(0),
+                self.req_to_token.stride(0),
             )
+            qo_indptr = None
+            custom_mask = None
         else:
+            kv_indptr[1 : bs + 1] = torch.cumsum(
+                forward_batch.extend_prefix_lens, dim=0
+            )
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.extend_prefix_lens.sum().item(),
+                dtype=torch.int32,
+                device=self.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.extend_prefix_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            custom_mask = None
+
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
-            kv_indptr = None
-            kv_indices = None
-
-        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
+
+        self.forward_metadata = (
+            attn_logits,
+            max_extend_len,
+            kv_indptr,
+            kv_indices,
+            qo_indptr,
+            custom_mask,
+        )

     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
@@ -144,6 +178,8 @@ def init_forward_metadata_capture_cuda_graph(
             None,
             kv_indptr,
             kv_indices,
+            None,
+            None,
         )

     def init_forward_metadata_replay_cuda_graph(
@@ -197,7 +233,9 @@ def forward_extend(
                 layer, forward_batch.out_cache_loc, k, v
             )

-        _, max_extend_len, _, _ = self.forward_metadata
+        _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
+            self.forward_metadata
+        )
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -205,11 +243,9 @@ def forward_extend(
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
-            forward_batch.extend_seq_lens,
-            forward_batch.extend_start_loc,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
             max_extend_len,
             layer.scaling,
             layer.logit_cap,
@@ -235,7 +271,7 @@ def forward_decode(
         else:
             o = torch.empty_like(q)

-        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata

         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
diff --git a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
index 79e148e9c9..db0fb6b4db 100644
--- a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
@@ -3,6 +3,13 @@
 import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import is_hip
+
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+is_hip_ = is_hip()

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
     return


-import torch
-
-
 def flash_decode_attention_fwd(
     q,
     k_buffer,
@@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd(
     )

     sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
+
+
+# Extend attention kernel for Double Sparsity
+# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+@triton.jit
+def _fwd_kernel(
+    Q_Extend,
+    K_Extend,
+    V_Extend,
+    O_Extend,
+    K_Buffer,
+    V_Buffer,
+    Req_to_tokens,
+    B_req_idx,
+    B_Seq_Len,
+    B_Start_Loc_Extend,
+    B_Seq_Len_Extend,
+    sm_scale,
+    kv_group_num,
+    stride_qbs,
+    stride_qh,
+    stride_kbs,
+    stride_kh,
+    stride_vbs,
+    stride_vh,
+    stride_obs,
+    stride_oh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_req_to_tokens_b,
+    logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    cur_seq = tl.program_id(0)
+    cur_head = tl.program_id(1)
+    cur_block_m = tl.program_id(2)
+    cur_kv_head = cur_head // kv_group_num
+
+    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
+    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
+    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
+
+    cur_seq_prefix_start_in_loc = 0
+    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
+    offs_m = tl.arange(0, BLOCK_M)
+    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
+    mask_d = offs_d < Lq
+    mask_dv = offs_dv < Lv
+
+    offs_q = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_qbs
+        + cur_head * stride_qh
+        + offs_d[None, :]
+    )
+    q = tl.load(
+        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
+    )
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
+    # stage 1: compute scores with prefix
+    offs_n = tl.arange(0, BLOCK_N)
+
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
+    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
+    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+
+    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_seq_len_prefix
+        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
+            cur_seq_prefix_start_in_loc + start_n + offs_n
+        )
+        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
+
+        # load k in transposed way
+        offs_buf_k = (
+            offs_kv_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q.to(k.dtype), k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_buf_v = (
+            offs_kv_loc[:, None] * stride_buf_vbs
+            + cur_kv_head * stride_buf_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    # stage 2: compute the trianlge part
+
+    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    for start_n in range(0, cur_block_m_end, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_block_m_end
+
+        # load k in transposed way
+        offs_k = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            + cur_kv_head * stride_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q, k, out_dtype=tl.float32)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+            start_n + offs_n[None, :]
+        )
+        mask_causual &= mask_m[:, None] & mask_n[None, :]
+        qk = tl.where(mask_causual, qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_v = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            + cur_kv_head * stride_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    offs_o = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_obs
+        + cur_head * stride_oh
+        + offs_dv[None, :]
+    )
+    tl.store(
+        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
+    )
+
+
+def extend_attention_fwd(
+    q_extend,
+    k_extend,
+    v_extend,
+    o_extend,
+    k_buffer,
+    v_buffer,
+    req_to_tokens,
+    b_req_idx,
+    b_seq_len,
+    b_seq_len_extend,
+    b_start_loc_extend,
+    max_len_extend,
+    sm_scale=None,
+    logit_cap=0.0,
+):
+    """
+    q_extend, k_extend, v_extend, o_extend: contiguous tensors
+
+    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
+    """
+    Lq, Lk, Lv = (
+        q_extend.shape[-1],
+        k_extend.shape[-1],
+        v_extend.shape[-1],
+    )
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    elif Lq == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = triton.next_power_of_2(Lq)
+        BLOCK_DPE = 0
+    BLOCK_DV = triton.next_power_of_2(Lv)
+
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
+    else:
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
+
+    sm_scale = sm_scale or 1.0 / (Lq**0.5)
+    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    kv_group_num = q_extend.shape[1] // k_extend.shape[1]
+
+    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
+    num_stages = 1
+
+    extra_kargs = {}
+    if is_hip_:
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
+    _fwd_kernel[grid](
+        q_extend,
+        k_extend,
+        v_extend,
+        o_extend,
+        k_buffer,
+        v_buffer,
+        req_to_tokens,
+        b_req_idx,
+        b_seq_len,
+        b_start_loc_extend,
+        b_seq_len_extend,
+        sm_scale,
+        kv_group_num,
+        q_extend.stride(0),
+        q_extend.stride(1),
+        k_extend.stride(0),
+        k_extend.stride(1),
+        v_extend.stride(0),
+        v_extend.stride(1),
+        o_extend.stride(0),
+        o_extend.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        req_to_tokens.stride(0),
+        logit_cap=logit_cap,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        Lq=Lq,
+        Lv=Lv,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        **extra_kargs,
+    )
diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
index b2654f1f78..6c9976931d 100644
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -46,11 +46,9 @@ def _fwd_kernel(
     O_Extend,
     K_Buffer,
     V_Buffer,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seq_Len,
-    B_Start_Loc_Extend,
-    B_Seq_Len_Extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -65,7 +63,6 @@ def _fwd_kernel(
     stride_buf_kh,
     stride_buf_vbs,
     stride_buf_vh,
-    stride_req_to_tokens_b,
     logit_cap: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
@@ -80,13 +77,10 @@ def _fwd_kernel(
     cur_block_m = tl.program_id(2)
     cur_kv_head = cur_head // kv_group_num

-    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
-    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
-    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
-
-    cur_seq_prefix_start_in_loc = 0
-    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
+    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
+    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
+    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx

     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
@@ -97,7 +91,7 @@ def _fwd_kernel(
     mask_dv = offs_dv < Lv

     offs_q = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
         + cur_head * stride_qh
         + offs_d[None, :]
@@ -109,7 +103,7 @@ def _fwd_kernel(
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
         offs_qpe = (
-            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
             * stride_qbs
             + cur_head * stride_qh
             + offs_dpe[None, :]
@@ -126,10 +120,9 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
-        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
-            cur_seq_prefix_start_in_loc + start_n + offs_n
+        offs_kv_loc = tl.load(
+            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
-        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)

         # load k in transposed way
         offs_buf_k = (
@@ -188,7 +181,7 @@ def _fwd_kernel(

         # load k in transposed way
         offs_k = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
             + cur_kv_head * stride_kh
             + offs_d[:, None]
         )
@@ -199,8 +192,7 @@ def _fwd_kernel(
         qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
             offs_kpe = (
-                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
-                * stride_kbs
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
                 + cur_kv_head * stride_kh
                 + offs_dpe[:, None]
             )
@@ -228,7 +220,7 @@ def _fwd_kernel(
         deno = deno * re_scale + tl.sum(p, 1)

         offs_v = (
-            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
             + offs_dv[None, :]
         )
@@ -241,7 +233,7 @@ def _fwd_kernel(
         e_max = n_e_max

     offs_o = (
-        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
         + offs_dv[None, :]
@@ -258,11 +250,9 @@ def extend_attention_fwd(
     o_extend,
     k_buffer,
     v_buffer,
-    req_to_tokens,
-    b_req_idx,
-    b_seq_len,
-    b_seq_len_extend,
-    b_start_loc_extend,
+    qo_indptr,
+    kv_indptr,
+    kv_indices,
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
@@ -315,7 +305,7 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8

     sm_scale = sm_scale or 1.0 / (Lq**0.5)
-    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]

     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
@@ -332,11 +322,9 @@ def extend_attention_fwd(
         o_extend,
         k_buffer,
         v_buffer,
-        req_to_tokens,
-        b_req_idx,
-        b_seq_len,
-        b_start_loc_extend,
-        b_seq_len_extend,
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -351,7 +339,6 @@ def extend_attention_fwd(
         k_buffer.stride(1),
         v_buffer.stride(0),
         v_buffer.stride(1),
-        req_to_tokens.stride(0),
         logit_cap=logit_cap,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
index 52a20771b3..3617e17be2 100644
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -45,16 +45,20 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
         max_len_in_batch = torch.max(b_seq_len, 0)[0].item()

         b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
-        req_to_tokens = torch.empty(
-            (B, max_len_in_batch), dtype=torch.int32, device="cuda"
-        )
         b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
         b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
         b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
         b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
+
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
+        kv_indices = torch.zeros(
+            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
+        )
+
         for i in range(B):
-            req_to_tokens[i, : b_seq_len[i]] = torch.arange(
-                b_start_loc[i], b_start_loc[i] + b_seq_len[i]
+            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
+                b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
             )

         total_token_num = torch.sum(b_seq_len).item()
@@ -90,9 +94,10 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
         )

         b_seq_len_extend = b_seq_len - b_seq_len_prefix
-        b_start_loc_extend = torch.zeros_like(b_seq_len)
-        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
         max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
+        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
+
         extend_attention_fwd(
             q_extend,
             k_extend,
@@ -100,11 +105,9 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D):
             o_extend,
             k_buffer,
             v_buffer,
-            req_to_tokens,
-            b_req_idx,
-            b_seq_len,
-            b_seq_len_extend,
-            b_start_loc_extend,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
             max_len_extend,
         )
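
For reference, the CSR-style layout that the backend and the test above now share can be reproduced on the CPU with plain PyTorch. This is a minimal sketch, not code from the patch: `req_to_token`, `req_pool_indices`, `prefix_lens`, and `extend_lens` are illustrative stand-ins for the corresponding ForwardBatch fields, and the per-request gather below is the work that `create_flashinfer_kv_indices_triton` performs on the GPU.

    import torch

    def build_extend_indices(req_to_token, req_pool_indices, prefix_lens, extend_lens):
        bs = len(req_pool_indices)

        # kv_indptr[i] : kv_indptr[i + 1] spans request i's prefix slots in kv_indices.
        kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
        kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)
        kv_indices = torch.empty(int(prefix_lens.sum()), dtype=torch.int64)
        for i in range(bs):
            # Gather the KV-cache slots of request i's prefix out of the req_to_token table.
            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[
                req_pool_indices[i], : prefix_lens[i]
            ]

        # qo_indptr partitions the packed extend (new) tokens per request.
        qo_indptr = torch.zeros(bs + 1, dtype=torch.int64)
        qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)
        return qo_indptr, kv_indptr, kv_indices

    # Example: two requests with prefixes of 3 and 2 tokens, extending by 2 and 1.
    req_to_token = torch.arange(20).reshape(2, 10)
    qo_indptr, kv_indptr, kv_indices = build_extend_indices(
        req_to_token,
        req_pool_indices=torch.tensor([0, 1]),
        prefix_lens=torch.tensor([3, 2]),
        extend_lens=torch.tensor([2, 1]),
    )
    # kv_indptr = [0, 3, 5], kv_indices = [0, 1, 2, 10, 11], qo_indptr = [0, 2, 3]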
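
Both loops in `_fwd_kernel` (the stage-1 pass over the prefix pointed to by `kv_indices` and the stage-2 causal pass over the extend tokens) rely on the same running-softmax recurrence: a running max `e_max`, a denominator `deno`, and an accumulator `acc`, all rescaled by `re_scale` whenever a new block raises the max. The following is a minimal CPU-only check of that recurrence for a single query, with illustrative sizes; it is a sketch for intuition, not part of the patch.

    import torch

    torch.manual_seed(0)
    d, n_prefix, n_extend = 16, 5, 4
    q = torch.randn(d)
    k = torch.randn(n_prefix + n_extend, d)   # prefix keys followed by extend keys
    v = torch.randn(n_prefix + n_extend, d)
    sm_scale = d ** -0.5

    # Blockwise accumulation, mirroring stage 1 (prefix) then stage 2 (causal part).
    e_max = torch.tensor(float("-inf"))
    deno = torch.tensor(0.0)
    acc = torch.zeros(d)
    q_pos = n_extend - 1                      # the last new token attends to everything
    for start, end in [(0, n_prefix), (n_prefix, n_prefix + q_pos + 1)]:
        qk = (k[start:end] @ q) * sm_scale
        n_e_max = torch.maximum(qk.max(), e_max)
        re_scale = torch.exp(e_max - n_e_max)
        p = torch.exp(qk - n_e_max)
        deno = deno * re_scale + p.sum()
        acc = acc * re_scale + p @ v[start:end]
        e_max = n_e_max
    out_blockwise = acc / deno

    # Direct reference over the same causal window.
    scores = (k[: n_prefix + q_pos + 1] @ q) * sm_scale
    out_direct = torch.softmax(scores, dim=0) @ v[: n_prefix + q_pos + 1]
    assert torch.allclose(out_blockwise, out_direct, atol=1e-5)

The kernel applies the same update once per BLOCK_M x BLOCK_N tile of scores, with the causal condition handled through `mask_causual` in stage 2.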