From 93682abbe531ea9faf1795f98a8052205187c80e Mon Sep 17 00:00:00 2001 From: aikitoria Date: Tue, 31 Dec 2024 00:23:21 +0100 Subject: [PATCH 1/5] Min-P prototype --- cpp/include/tensorrt_llm/executor/types.h | 23 ++ .../layers/defaultDecodingParams.h | 5 + .../kernels/samplingMinPKernels.cu | 294 ++++++++++++++++++ .../kernels/samplingMinPKernels.h | 138 ++++++++ cpp/tensorrt_llm/layers/decodingLayer.cpp | 14 +- cpp/tensorrt_llm/layers/decodingParams.h | 4 + cpp/tensorrt_llm/layers/minPSamplingLayer.cpp | 191 ++++++++++++ cpp/tensorrt_llm/layers/minPSamplingLayer.h | 57 ++++ cpp/tensorrt_llm/layers/samplingLayer.cpp | 17 +- cpp/tensorrt_llm/runtime/gptDecoder.cpp | 12 +- .../runtime/gptDecoderBatched.cpp | 10 +- 11 files changed, 752 insertions(+), 13 deletions(-) create mode 100644 cpp/tensorrt_llm/kernels/samplingMinPKernels.cu create mode 100644 cpp/tensorrt_llm/kernels/samplingMinPKernels.h create mode 100644 cpp/tensorrt_llm/layers/minPSamplingLayer.cpp create mode 100644 cpp/tensorrt_llm/layers/minPSamplingLayer.h diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h index 205d6e567..a83c432e1 100644 --- a/cpp/include/tensorrt_llm/executor/types.h +++ b/cpp/include/tensorrt_llm/executor/types.h @@ -510,6 +510,11 @@ class DecodingMode return DecodingMode{kTopKTopP | kUsePenalties | kUseBanTokens | kStandardStopCriteria}; } + static auto constexpr MinP() + { + return DecodingMode{kMinP | (kUsePenalties & ~kUseTemperature) | kUseBanTokens | kStandardStopCriteria}; + } + static auto constexpr BeamSearch() { return DecodingMode{kBeamSearch | kUsePenalties | kUseBanTokens | kStandardStopCriteria}; @@ -612,6 +617,12 @@ class DecodingMode return *this; } + auto constexpr useMinP() + { + mState = kMinP | (mState & ~kTopKTopP & ~kUseTemperature); + return *this; + } + [[nodiscard]] bool constexpr isAuto() const { return anyBitSet(kAuto); @@ -637,6 +648,16 @@ class DecodingMode return allBitSet(kTopKTopP); } + [[nodiscard]] bool constexpr isMinP() const + { + return anyBitSet(kMinP); + } + + [[nodiscard]] bool constexpr isTopKorTopPorMinP() const + { + return anyBitSet(kTopKTopPMinP); + } + [[nodiscard]] bool constexpr isBeamSearch() const { return anyBitSet(kBeamSearch); @@ -783,6 +804,8 @@ class DecodingMode static UnderlyingType constexpr kExternalDraftTokens{1u << (kNumFlags + 7)}; static UnderlyingType constexpr kEagle{1u << (kNumFlags + 8)}; static UnderlyingType constexpr kTopKTopP{kTopK | kTopP}; + static UnderlyingType constexpr kMinP{1u << (kNumFlags + 9)}; + static UnderlyingType constexpr kTopKTopPMinP{kTopK | kTopP | kMinP}; [[nodiscard]] bool constexpr anyBitSet(UnderlyingType bits) const { diff --git a/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h b/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h index 7b2e05629..114ea4453 100644 --- a/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h +++ b/cpp/include/tensorrt_llm/layers/defaultDecodingParams.h @@ -80,6 +80,11 @@ class DefaultDecodingParams return 1.0e-6f; } + [[nodiscard]] __host__ __device__ static constexpr float getMinP() + { + return 0.0f; + } + [[nodiscard]] __host__ __device__ static constexpr runtime::TokenIdType getTopPResetId() { return -1; diff --git a/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu b/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu new file mode 100644 index 000000000..08e6c4ac9 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu @@ -0,0 +1,294 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); 
+ * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! +#elif (CUDART_VERSION >= 11050) +#include +#else +#include "3rdparty/cub/cub.cuh" +#endif + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/memoryUtils.h" +#include "tensorrt_llm/common/reduceKernelUtils.cuh" +#include "tensorrt_llm/kernels/samplingMinPKernels.h" + +using namespace tensorrt_llm::common; +using namespace tensorrt_llm::runtime; + +namespace tensorrt_llm::kernels +{ +template +__global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType* outputIds, + TokenIdType** outputIdsPtrs, SizeType32* sequenceLengths, FinishedState const* finishedInput, + FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, SizeType32 vocabSize, + curandState_t* curandState, float const* randomVals, float const* minPs, float const* temperatures, + TokenIdType const* endIds, SizeType32 maxBatchSize, SizeType32 const* batchSlots, bool returnAllSelectedTokens, + SizeType32 maxSeqLen, TokenIdType* outputIdCurrentStep, bool const* skipOutputIdCurrentStep) +{ + auto const tid = static_cast(threadIdx.x); + auto const batchId = static_cast(blockIdx.x); + auto const batchSlot = batchSlots[batchId]; + + // Skip kernel if this sampling method is not chosen + FinishedState const finishState = finishedInput != nullptr ? finishedInput[batchSlot] : FinishedState::empty(); + if (finishState.isSkipDecoding()) + { + return; + } + + // Exit early if sequence has finished + if (finishState.isFinished()) + { + if (tid == 0) + { + if (finishedOutput != nullptr) + { + finishedOutput[batchSlot] = finishState; + } + } + return; + } + + // Each thread computes local maximum across its assigned probabilities + float threadMax = -FLT_MAX; + const int probsBeginIdx = batchId * vocabSize; + const int probsEndIdx = (batchId + 1) * vocabSize; + + #pragma unroll + for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) + { + float prob = static_cast(probs[idx]); + threadMax = max(threadMax, prob); + } + + // Find global maximum probability across all threads in block + threadMax = blockReduceMax(threadMax); + __shared__ float sMaxP; + __shared__ float sCutoffP; + + if (tid == 0) + { + sMaxP = threadMax; + sCutoffP = sMaxP * (minPs != nullptr ? minPs[batchSlot] : 0.0f); + } + __syncthreads(); + + // Adjust the probabilities and cache them + float threadAdjustedProbsSum = 0.0f; + float invTemp = 1.0f / (temperatures != nullptr ? temperatures[batchSlot] : 1.0f); + + #pragma unroll + for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) + { + float prob = static_cast(probs[idx]); + prob = (prob < sCutoffP) ? 
0.0f : powf(prob, invTemp); + adjustedProbs[idx] = static_cast(prob); + threadAdjustedProbsSum += prob; + } + + // Find global sum of adjusted probabilities and determine quantization scale factor + threadAdjustedProbsSum = blockReduceSum(threadAdjustedProbsSum); + __shared__ float sAdjustedProbsSum; + __shared__ float sQuantizeScaleFactor; + + if (tid == 0) + { + sAdjustedProbsSum = threadAdjustedProbsSum; + sQuantizeScaleFactor = UINT32_MAX / threadAdjustedProbsSum; + } + __syncthreads(); + + // We will now quantize the probabilities to integers to avoid numerical errors + // when trying to find the selected point in the prefix sum of the probabilities. + // We map the adjusted distribution between [0, UINT32_MAX] to avoid overflow. + + // Compute the sum of the quantized probabilities for each thread + uint32_t threadQuantProbsSum = 0; + + #pragma unroll + for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) + { + float prob = static_cast(adjustedProbs[idx]); + threadQuantProbsSum += static_cast(prob * sQuantizeScaleFactor); + } + + // Compute a global prefix sum of the quantized probabilities + uint32_t threadQuantProbsPrefix; + uint32_t totalQuantProbsSum; + + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage tempStorage; + + BlockScan(tempStorage).ExclusiveSum(threadQuantProbsSum, threadQuantProbsPrefix, totalQuantProbsSum); + + // Select a random point in the distribution + __shared__ uint32_t sRandomPoint; + + if (tid == 0) + { + // Rescale uniform random val to be within the sum of quantized probabilities + float randomVal = randomVals != nullptr ? randomVals[batchSlot] : curand_uniform(&curandState[batchSlot]); + sRandomPoint = static_cast(randomVal * totalQuantProbsSum); + } + __syncthreads(); + + // All but one warps will terminate on this condition + if (sRandomPoint < threadQuantProbsPrefix || sRandomPoint >= threadQuantProbsPrefix + threadQuantProbsSum) + { + return; + } + + // Find the selected token id and write it to the output buffer + threadQuantProbsSum = threadQuantProbsPrefix; + + for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) + { + float prob = static_cast(adjustedProbs[idx]); + uint32_t scaledProb = static_cast(prob * sQuantizeScaleFactor); + + if (sRandomPoint >= threadQuantProbsSum && sRandomPoint < threadQuantProbsSum + scaledProb) + { + auto const selectedTokenIdx = idx - probsBeginIdx; + auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot]; + auto* outPtr = outputIdsPtrs == nullptr ? 
outputIds + batchSlot * maxSeqLen : outputIdsPtrs[batchSlot]; + outPtr[curSeqLen] = selectedTokenIdx; + + if (!returnAllSelectedTokens && sequenceLengths != nullptr && finishedOutput != nullptr && endIds != nullptr) + { + if (selectedTokenIdx == endIds[batchSlot]) + { + // This request has finished + finishedOutput[batchSlot].setFinishedEOS(); + } + else + { + // This request must generate more tokens + sequenceLengths[batchSlot] += 1; + } + } + return; + } + + threadQuantProbsSum += scaledProb; + } +} + +template +std::vector getMinPWorkspaceSizes(SizeType32 batchSize, SizeType32 vocabSize) +{ + auto const adjustedProbBufSize = sizeof(T) * batchSize * vocabSize; + + return {adjustedProbBufSize}; +} + +template std::vector getMinPWorkspaceSizes(SizeType32 batchSize, SizeType32 vocabSize); +template std::vector getMinPWorkspaceSizes(SizeType32 batchSize, SizeType32 vocabSize); + +template +std::vector getMinPInitWorkspaceSizes(SizeType32 batchSize) +{ + auto const tempMinPsBufSize = batchSize * sizeof(float); + auto const tempTemperaturesBufSize = batchSize * sizeof(float); + + return {tempMinPsBufSize, tempTemperaturesBufSize}; +} + +template std::vector getMinPInitWorkspaceSizes(SizeType32 batchSize); +template std::vector getMinPInitWorkspaceSizes(SizeType32 batchSize); + +template +size_t getMinPWorkspaceSize(SizeType32 batchSize, SizeType32 vocabSizePadded) +{ + auto const workspaceSizes = getMinPWorkspaceSizes(batchSize, vocabSizePadded); + auto const initWorkspaceSizes = getMinPInitWorkspaceSizes(batchSize); + + return std::max(tensorrt_llm::common::calcAlignedSize(workspaceSizes, 256), + tensorrt_llm::common::calcAlignedSize(initWorkspaceSizes, 256)); +} + +template size_t getMinPWorkspaceSize(SizeType32 batchSize, SizeType32 vocabSizePadded); +template size_t getMinPWorkspaceSize(SizeType32 batchSize, SizeType32 vocabSizePadded); + +template +void invokeBatchMinPSampling(MinPSamplingKernelParams const& params, cudaStream_t stream) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + params.checkParams(); + auto const workspaceSizes = getMinPWorkspaceSizes(params.batchSize, params.vocabSizePadded); + + std::vector alignedPointers; + calcAlignedPointers(alignedPointers, params.workspace, workspaceSizes); + + auto adjustedProbs = static_cast(alignedPointers[0]); + + // Sample with Min P filter and late temperature in single pass + SizeType32 constexpr SAMPLING_BLOCK_SIZE = 1024; + dim3 grid(params.batchSize); + fusedMinPSsampling<<>>(params.probs, adjustedProbs, + params.outputIds, params.outputIdsPtrs, params.sequenceLength, params.finishedInput, params.finishedOutput, + params.cumLogProbs, params.outputLogProbs, params.vocabSizePadded, params.curandState, params.randomVals, + params.minPs, params.temperatures, params.endIds, params.maxBatchSize, params.batchSlots, params.returnAllSelectedTokens, + params.maxSeqLen, params.outputIdCurrentStep, params.skipOutputIdCurrentStep); + + sync_check_cuda_error(); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template void invokeBatchMinPSampling(MinPSamplingKernelParams const& params, cudaStream_t stream); + +template void invokeBatchMinPSampling(MinPSamplingKernelParams const& params, cudaStream_t stream); + +__device__ __host__ inline void setupMinPRuntimeArg(runtime::SizeType32 batchIndex, + ScatterDecodingParamEntry minP, ScatterDecodingParamEntry temperature, + runtime::SizeType32 const* batchSlots) +{ + auto const batchSlot = batchSlots[batchIndex]; + auto const p = minP.mVector == nullptr ? 
minP.mScalar : minP.mVector[batchIndex]; + auto const t = temperature.mVector == nullptr ? temperature.mScalar : temperature.mVector[batchIndex]; + + if (minP.mTarget != nullptr) + { + minP.mTarget[batchSlot] = p; + } + + if (temperature.mTarget != nullptr) + { + temperature.mTarget[batchSlot] = t; + } +} + +__global__ void setMinPRuntimeArgs(SizeType32 batchSize, ScatterDecodingParamEntry minP, + ScatterDecodingParamEntry temperature, SizeType32 const* batchSlotsPtr) +{ + auto index = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + for (SizeType32 bi = index; bi < batchSize; bi += static_cast(gridDim.x * blockDim.x)) + { + setupMinPRuntimeArg(bi, minP, temperature, batchSlotsPtr); + } +} + +void invokeSetMinPRuntimeArgs(SizeType32 batchSize, ScatterDecodingParamEntry minP, + ScatterDecodingParamEntry temperature, SizeType32 const* batchSlotsPtr, + cudaStream_t stream) +{ + dim3 block(std::min(static_cast(batchSize), 256u)); + dim3 grid(divUp(static_cast(batchSize), block.x)); + setMinPRuntimeArgs<<>>( + batchSize, minP, temperature, batchSlotsPtr); +} + +} // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/kernels/samplingMinPKernels.h b/cpp/tensorrt_llm/kernels/samplingMinPKernels.h new file mode 100644 index 000000000..854669460 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/samplingMinPKernels.h @@ -0,0 +1,138 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/kernels/decodingCommon.h" +#include "tensorrt_llm/runtime/common.h" +#include + +namespace tensorrt_llm::kernels +{ +template +struct MinPSamplingKernelParams +{ + //! input buffer [batchSize, vocabSizePadded], required. Probabilities of each token in the vocab. + T const* probs{nullptr}; + + //! output buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with output tokens per request. + //! If nullptr, outputIds must be provided. + runtime::TokenIdType** outputIdsPtrs{nullptr}; + + //! output buffer [maxBatchSize, maxSeqLen], optional. Tensor to store output tokens. + //! Not used if outputIdsPtrs != nullptr + runtime::TokenIdType* outputIds{nullptr}; + + //! pointer to the workspace. Has to be pre-allocated by caller. + //! Function does not take ownership of the buffer. + void* workspace{nullptr}; + + //! input buffer [maxBatchSize]. P for MinP sampling per request. Supported P is in range [0.0; 1.0]. + //! 1.0 will always select the token with the highest probability. + //! 0.0 will disable the MinP filter and sample from all tokens. + //! If nullptr, MinP of 0.0 is used for all requests. + float const* minPs{nullptr}; + + //! input buffer [maxBatchSize]. Temperature per request for late temperature adjustment. + //! If nullptr, temperature of 1.0 is used for all requests. + float const* temperatures{nullptr}; + + //! input/output buffer [maxBatchSize], required. Current sequence length of the request up to, but excluding endId + //! token. + runtime::SizeType32* sequenceLength{nullptr}; + //! 
input buffer [maxBatchSize], optional. EOS token ids per request + runtime::TokenIdType const* endIds{nullptr}; + //! input buffer[batchSize], optional. Indices of rows of data in memory pool. + runtime::SizeType32 const* batchSlots{nullptr}; + + //! input buffer [maxBatchSize], optional. Exit early if true. + FinishedState const* finishedInput{nullptr}; + //! output buffer [maxBatchSize], optional. Set flag if sequence has finished (if finished || outputId == endId). + FinishedState* finishedOutput{nullptr}; + //! input buffer [maxBatchSize], optional. Flags whether to skip decoding per request + bool const* skipDecode{nullptr}; + + //! input/output buffer [maxBatchSize], optional. Cumulative log probability of selected tokens. Ignored if nullptr. + float* cumLogProbs{nullptr}; + //! output buffer [maxBatchSize], optional. Log probs is the probability induced by the MinP sampling. + //! I.e., log_prob = log P(i | i is in vocab). + float* outputLogProbs{nullptr}; + //! input buffer [maxBatchSize], optional. Curand states properly initialized using + //! invokeCurandInitialize per request. Either curandState or randomVals should be specified. + curandState_t* curandState{nullptr}; + //! input buffer [maxBatchSize], optional. Precomputed random values per request. + //! Either curandState or randomVals should be specified. + float const* randomVals{nullptr}; + + runtime::SizeType32 batchSize{-1}; + runtime::SizeType32 maxBatchSize{-1}; + runtime::SizeType32 vocabSizePadded{-1}; + runtime::SizeType32 maxSeqLen{-1}; + + bool returnAllSelectedTokens{false}; + + //! output buffer [maxBatchSize], optional. + //! Store the multinomial sampled target token id in TopK/MinP sampled tokens when returnAllSelectedTokens==True. + //! Only return when skipOutputIdCurrentStep != nullptr && skipOutputIdCurrentStep == False + runtime::TokenIdType* outputIdCurrentStep{nullptr}; + //! input buffer [maxBatchSize]. Determine if multinomial sampling is required when returnAllSelectedTokens==True. + bool const* skipOutputIdCurrentStep{nullptr}; + + void checkParams() const + { + TLLM_CHECK(batchSize > 0); + TLLM_CHECK(maxBatchSize > 0); + TLLM_CHECK(maxBatchSize >= batchSize); + TLLM_CHECK(vocabSizePadded > 0); + TLLM_CHECK(probs); + TLLM_CHECK(outputIds || outputIdsPtrs); + TLLM_CHECK(workspace); + TLLM_CHECK((curandState != nullptr) || (randomVals != nullptr)); + TLLM_CHECK(((curandState != nullptr) & (randomVals != nullptr)) == 0); + TLLM_CHECK(minPs); + TLLM_CHECK(temperatures); + + if (outputIds) + { + TLLM_CHECK(maxSeqLen > 0); + } + + TLLM_CHECK(((finishedOutput == nullptr) ^ (endIds == nullptr)) == 0); + TLLM_CHECK((skipOutputIdCurrentStep && outputIdCurrentStep && returnAllSelectedTokens) + || (skipOutputIdCurrentStep == nullptr && outputIdCurrentStep == nullptr)); + } +}; + +//! \brief Returns workspace size in bytes needed for sampling MinP computation +//! \param batchSize batch size +//! \param vocabSizePadded size of padded vocab +template +[[nodiscard]] size_t getMinPWorkspaceSize(runtime::SizeType32 batchSize, runtime::SizeType32 vocabSizePadded); + +//! \brief Returns workspace size in bytes needed for initialization of sampling MinP +//! \param batchSize batch size +template +[[nodiscard]] std::vector getMinPInitWorkspaceSizes(runtime::SizeType32 batchSize); + +//! \brief Given probs, performs Min P sampling. Fills sampled tokens to outputIds. +//! Updates sequenceLength, finished state, cumLogProbs inplace. +//! Sampling per request can be controlled using MinPs parameter. 
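+//! Temperature, when provided, is applied after the Min P cutoff as a late temperature, i.e. prob^(1/T).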
+template +void invokeBatchMinPSampling(MinPSamplingKernelParams const& params, cudaStream_t stream); + +void invokeSetMinPRuntimeArgs(runtime::SizeType32 batchSize, ScatterDecodingParamEntry minP, + ScatterDecodingParamEntry temperature, runtime::SizeType32 const* batchSlotsPtr, + cudaStream_t stream = nullptr); + +} // namespace tensorrt_llm::kernels diff --git a/cpp/tensorrt_llm/layers/decodingLayer.cpp b/cpp/tensorrt_llm/layers/decodingLayer.cpp index d35987b5c..38d7544f9 100644 --- a/cpp/tensorrt_llm/layers/decodingLayer.cpp +++ b/cpp/tensorrt_llm/layers/decodingLayer.cpp @@ -41,7 +41,7 @@ DecodingLayer::DecodingLayer(executor::DecodingMode const& mode, DecoderDomai { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - if (mDecodingMode.isTopKorTopP()) + if (mDecodingMode.isTopKorTopPorMinP()) { mDecodingLayer = std::make_unique>(mDecodingMode, decoderDomain, mBufferManager); } @@ -90,10 +90,10 @@ void DecodingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorC TLLM_CHECK_WITH_INFO(setupParams->decodingParams, "decodingParams for setup is not set"); - if (mDecodingMode.isTopKorTopP()) + if (mDecodingMode.isTopKorTopPorMinP()) { // sampling layers TLLM_CHECK_WITH_INFO( - beamWidth == 1, "Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", beamWidth); + beamWidth == 1, "Decoding mode is TopK and/or TopP or MinP, but beamWidth != 1 (%d != 1)", beamWidth); mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams, workspace); } else if (mDecodingMode.isBeamSearch()) @@ -131,7 +131,7 @@ void DecodingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorC else { TLLM_CHECK_WITH_INFO(false, - "Decoding mode is none of the supported {TopK, TopP, TopKTopP, BeamSearch, Medusa, Lookahead, " + "Decoding mode is none of the supported {TopK, TopP, TopKTopP, MinP, BeamSearch, Medusa, Lookahead, " "ExplicitDraftTokens, ExternalDraftTokens, Eagle}"); } @@ -186,14 +186,14 @@ std::tuple, std::shared_ptrite; auto const step = params->step; auto const localBatchSize = static_cast(params->localBatchSize); TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() == 1, - "Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth()); + "Decoding mode is TopK and/or TopP or MinP, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth()); // In sampling, we have supported batch sampling. So, we always compute all // sentences once. @@ -239,7 +239,7 @@ std::tuple, std::shared_ptr(externalDraftTokenParams->localBatchSize); TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() == 1, - "Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth()); + "Decoding mode is TopK and/or TopP or MinP, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth()); // In sampling, we have supported batch sampling. So, we always compute all // sentences once. 
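
The per-request math that the fused kernel implements can be summarized by the host-side sketch below (illustrative only; the function and variable names are not part of the patch, and the scalar loops stand in for the strided thread-block loops of the kernel): every probability below maxProb * minP is dropped, the survivors are rescaled with late temperature as p^(1/T), and one token is drawn in proportion to the rescaled values.

#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

// Reference implementation of Min P filtering with late temperature for a single request.
int sampleMinPReference(std::vector<float> const& probs, float minP, float temperature, std::mt19937& rng)
{
    float const maxProb = *std::max_element(probs.begin(), probs.end());
    float const cutoff = maxProb * minP;      // probs below this value are ignored
    float const invTemp = 1.0f / temperature; // late temperature exponent

    std::vector<float> scaled(probs.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < probs.size(); ++i)
    {
        scaled[i] = probs[i] < cutoff ? 0.0f : std::pow(probs[i], invTemp);
        sum += scaled[i];
    }

    // Draw a point in [0, sum) and return the token whose prefix interval contains it.
    std::uniform_real_distribution<float> dist(0.0f, sum);
    float const point = dist(rng);
    float prefix = 0.0f;
    for (std::size_t i = 0; i < scaled.size(); ++i)
    {
        prefix += scaled[i];
        if (point < prefix)
        {
            return static_cast<int>(i);
        }
    }
    return static_cast<int>(scaled.size()) - 1; // fallback for floating-point edge cases
}
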
diff --git a/cpp/tensorrt_llm/layers/decodingParams.h b/cpp/tensorrt_llm/layers/decodingParams.h index 54900676e..603784207 100644 --- a/cpp/tensorrt_llm/layers/decodingParams.h +++ b/cpp/tensorrt_llm/layers/decodingParams.h @@ -151,12 +151,16 @@ class SamplingSetupParams : public DecodingSetupParams // baseSamplingLayer std::optional> runtimeTopK; // [1] or [setupBatchSize] on cpu std::optional> runtimeTopP; // [1] or [setupBatchSize] on cpu + std::optional> runtimeMinP; // [1] or [setupBatchSize] on cpu // topPSamplingLayer std::optional> topPDecay; // [setupBatchSize], must between [0, 1] std::optional> topPMin; // [setupBatchSize], must between [0, 1] std::optional> topPResetIds; // [setupBatchSize] std::optional normalizeLogProbs; + + // minPSamplingLayer needs access to temperature + std::shared_ptr penaltyParams; }; class BeamSearchSetupParams : public DecodingSetupParams diff --git a/cpp/tensorrt_llm/layers/minPSamplingLayer.cpp b/cpp/tensorrt_llm/layers/minPSamplingLayer.cpp new file mode 100644 index 000000000..848322a6d --- /dev/null +++ b/cpp/tensorrt_llm/layers/minPSamplingLayer.cpp @@ -0,0 +1,191 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minPSamplingLayer.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/memoryUtils.h" +#include "tensorrt_llm/common/nvtxUtils.h" +#include "tensorrt_llm/kernels/decodingCommon.h" +#include "tensorrt_llm/kernels/samplingMinPKernels.h" +#include "tensorrt_llm/layers/defaultDecodingParams.h" +#include "tensorrt_llm/layers/layerUtils.h" + +#include +#include + +using namespace tensorrt_llm::common; +using namespace tensorrt_llm::kernels; +using namespace tensorrt_llm::runtime; + +namespace tensorrt_llm::layers +{ + +template +MinPSamplingLayer::MinPSamplingLayer(DecoderDomain const& decoderDomain, + std::shared_ptr bufferManager) + : BaseLayer(decoderDomain, bufferManager) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + allocateBuffer(mDecoderDomain.getBatchSize()); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void MinPSamplingLayer::allocateBuffer(SizeType32 batchSize) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + mWorkspaceSize = getMinPWorkspaceSize(batchSize, mDecoderDomain.getVocabSizePadded()); + + auto const batchSizeShape = ITensor::makeShape({batchSize}); + mRuntimeMinPDevice = mBufferManager->gpu(batchSizeShape, TRTDataType::value); + mTemperatureDevice = mBufferManager->gpu(batchSizeShape, TRTDataType::value); + + mSetupWorkspaceSize = std::max({ + mRuntimeMinPDevice->getSizeInBytes(), + mTemperatureDevice->getSizeInBytes() + }); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void MinPSamplingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorConstPtr batchSlots, + std::shared_ptr const& baseSetupParams, + std::shared_ptr const& workspace) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + auto setupParams = std::dynamic_pointer_cast(baseSetupParams); + + auto defaultMinP = 
DefaultDecodingParams::getMinP(); + auto defaultTemperature = DefaultDecodingParams::getTemperature(); + + auto runtimeMinP = setupParams->runtimeMinP.value_or(std::vector{defaultMinP}); + auto temperature = setupParams->penaltyParams->temperature.value_or(std::vector{defaultTemperature}); + + auto const paramsSize = expandMatchElements(batchSize, runtimeMinP, temperature); + TLLM_CHECK_WITH_INFO(paramsSize != 0, + fmtstr("MinPSamplingLayer got parameter with unexpected size, want 1 or batchSize(%d), got" + "runtimeMinP.size() = %zu, " + "temperature.size() = %zu", + batchSize, runtimeMinP.size(), temperature.size())); + + for (size_t i = 0; i < paramsSize; ++i) + { + auto& currMinP = runtimeMinP[i]; + auto& currTemperature = temperature[i]; + + if (currMinP <= 0.f) + { + TLLM_LOG_WARNING( + "Min (%f) is out of range ((0.0, inf]). Change to default (%f).", currMinP, defaultMinP); + + currMinP = defaultMinP; + } + + if (currTemperature <= 0.f) + { + TLLM_LOG_WARNING( + "Temperature (%f) is out of range ((0.0, inf]). Change to default (%f).", currTemperature, defaultTemperature); + + currTemperature = defaultTemperature; + } + } + + float* MinPsPtr = nullptr; + float* TemperaturesPtr = nullptr; + + if (paramsSize > 1) + { + auto initWorkspaceSizes = getMinPInitWorkspaceSizes(batchSize); + std::vector alignedPointers; + calcAlignedPointers(workspace->getRawWorkspaceDevicePtr(), initWorkspaceSizes)(MinPsPtr); + + DecodingLayerWorkspace::copyToWorkspace( + *mBufferManager, runtimeMinP, IBuffer::wrap(MinPsPtr, initWorkspaceSizes[0] / sizeof(*MinPsPtr))); + + DecodingLayerWorkspace::copyToWorkspace( + *mBufferManager, temperature, IBuffer::wrap(TemperaturesPtr, initWorkspaceSizes[1] / sizeof(*TemperaturesPtr))); + } + + auto const* batchSlotsDevicePtr = workspace->getDeviceBatchSlotsPtr(); + + invokeSetMinPRuntimeArgs(batchSize, + {MinPsPtr, runtimeMinP.front(), bufferCast(*mRuntimeMinPDevice)}, + {TemperaturesPtr, temperature.front(), bufferCast(*mTemperatureDevice)}, + batchSlotsDevicePtr, getStream()); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +void MinPSamplingLayer::forwardAsync(std::shared_ptr const& outputs, + std::shared_ptr const& baseInputs, + std::shared_ptr const& workspace) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + NVTX3_SCOPED_RANGE(MinPSamplingLayer_forwardAsync); + + auto inputs = std::dynamic_pointer_cast(baseInputs); + auto const batchSize = inputs->logits.value()->getDimension<0>(); + + // Probabilities must be already computed instead of logits + auto probs = bufferCastOrNull(inputs->logits); + auto const* endIds = bufferCastOrNull(inputs->endIds); + + auto const* finishedInput = (inputs->finished) + ? reinterpret_cast( + bufferCastOrNull(inputs->finished.value())) + : nullptr; + + auto* finishedOutput = (outputs->finished) + ? 
reinterpret_cast( + bufferCastOrNull(outputs->finished.value())) + : nullptr; + + MinPSamplingKernelParams params{}; + params.probs = probs; + params.outputIdsPtrs = bufferCastOrNull(outputs->outputIdsPtr); + params.workspace = workspace->getRawWorkspaceDevicePtr(); + params.minPs = bufferCastOrNull(mRuntimeMinPDevice); + params.temperatures = bufferCastOrNull(mTemperatureDevice); + params.sequenceLength = bufferCastOrNull(outputs->sequenceLength); + params.endIds = endIds; + params.batchSlots = workspace->getDeviceBatchSlotsPtr(); + params.finishedInput = finishedInput; + params.finishedOutput = finishedOutput; + params.cumLogProbs = bufferCastOrNull(outputs->cumLogProbs); + params.outputLogProbs = bufferCastOrNull(outputs->outputLogProbsTiled); + params.curandState = inputs->curandStates; + params.batchSize = batchSize; + params.maxBatchSize = mDecoderDomain.getBatchSize(); + params.vocabSizePadded = mDecoderDomain.getVocabSizePadded(); + invokeBatchMinPSampling(params, getStream()); + + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +template +size_t MinPSamplingLayer::getWorkspaceSize() const noexcept +{ + return std::max(mSetupWorkspaceSize, mWorkspaceSize); +} + +template class MinPSamplingLayer; +template class MinPSamplingLayer; + +} // namespace tensorrt_llm::layers diff --git a/cpp/tensorrt_llm/layers/minPSamplingLayer.h b/cpp/tensorrt_llm/layers/minPSamplingLayer.h new file mode 100644 index 000000000..a0e17e550 --- /dev/null +++ b/cpp/tensorrt_llm/layers/minPSamplingLayer.h @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/layers/baseLayer.h" +#include "tensorrt_llm/runtime/common.h" + +namespace tensorrt_llm::layers +{ + +//! \brief Layer to randomly sample tokens from MinP logits. +//! Layer expects probs precomputed in "logits" tensor +template +class MinPSamplingLayer : public BaseLayer +{ + using Base = BaseLayer; + +public: + MinPSamplingLayer(DecoderDomain const& decoderDomain, std::shared_ptr bufferManager); + + void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, TensorConstPtr batchSlots, + std::shared_ptr const& setupParams, + std::shared_ptr const& workspace) override; + + void forwardAsync(std::shared_ptr const& outputs, + std::shared_ptr const& inputs, + std::shared_ptr const& workspace) override; + + //! 
@returns workspace needed for this layer in bytes + [[nodiscard]] size_t getWorkspaceSize() const noexcept override; + +protected: + TensorPtr mRuntimeMinPDevice; + TensorPtr mTemperatureDevice; + + size_t mWorkspaceSize{0}; + size_t mSetupWorkspaceSize{0}; + + using Base::mDecoderDomain; + +private: + void allocateBuffer(runtime::SizeType32 batchSize); +}; + +} // namespace tensorrt_llm::layers diff --git a/cpp/tensorrt_llm/layers/samplingLayer.cpp b/cpp/tensorrt_llm/layers/samplingLayer.cpp index 17c055875..c3ab13250 100644 --- a/cpp/tensorrt_llm/layers/samplingLayer.cpp +++ b/cpp/tensorrt_llm/layers/samplingLayer.cpp @@ -20,6 +20,7 @@ #include "tensorrt_llm/kernels/decodingCommon.h" #include "tensorrt_llm/layers/topKSamplingLayer.h" #include "tensorrt_llm/layers/topPSamplingLayer.h" +#include "tensorrt_llm/layers/minPSamplingLayer.h" #include "samplingLayer.h" #include @@ -40,18 +41,27 @@ SamplingLayer::SamplingLayer(executor::DecodingMode const& mode, DecoderDomai TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK_WITH_INFO(!mDecodingMode.isBeamSearch(), "SamplingLayer does not support Beam search mode"); - TLLM_CHECK_WITH_INFO(mDecodingMode.isTopKorTopP(), "SamplingLayer requires TopK or TopP mode"); + TLLM_CHECK_WITH_INFO(mDecodingMode.isTopKorTopPorMinP(), "SamplingLayer requires TopK or TopP or MinP mode"); if (mDecodingMode.isTopK()) { + TLLM_LOG_INFO("TopK sampling layer active"); mSamplingLayers.emplace_back(std::make_unique>(decoderDomain, mBufferManager)); } if (mDecodingMode.isTopP()) { + TLLM_LOG_INFO("TopP sampling layer active"); mSamplingLayers.emplace_back( std::make_unique>(decoderDomain, mBufferManager, /* deterministic */ true)); } + if (mDecodingMode.isMinP()) + { + TLLM_LOG_INFO("MinP sampling layer active"); + TLLM_CHECK_WITH_INFO(!mDecodingMode.isUseTemperature(), "MinP sampling is already fused with late temperature"); + mSamplingLayers.emplace_back(std::make_unique>(decoderDomain, mBufferManager)); + } + allocateBuffer(decoderDomain.getBatchSize()); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); @@ -136,9 +146,10 @@ void SamplingLayer::forwardAsync(std::shared_ptr const& : nullptr; auto const skipTopP = !mDecodingMode.isTopP(); + auto const skipMinP = !mDecodingMode.isMinP(); - // Compute probabilities either for TopP or if cumLogProbs or outputLogProbs are specified - bool const skipSoftMax = skipTopP && !mOutputLogProbs && !mCumLogProbs; + // Compute probabilities either for TopP or MinP or if cumLogProbs or outputLogProbs are specified + bool const skipSoftMax = skipTopP && skipMinP && !mOutputLogProbs && !mCumLogProbs; inputs->curandStates = reinterpret_cast(bufferCast(*mCurandStatesDevice)); inputs->probsComputed = !skipSoftMax; diff --git a/cpp/tensorrt_llm/runtime/gptDecoder.cpp b/cpp/tensorrt_llm/runtime/gptDecoder.cpp index c6572de40..0ce4a28b4 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoder.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoder.cpp @@ -145,7 +145,7 @@ void GptDecoder::setup(SamplingConfig const& samplingConfig, size_t batchSize setupParams->banWordsParams = std::move(banWordsParams); - if (mDecodingMode.isTopKorTopP()) + if (mDecodingMode.isTopKorTopPorMinP()) { auto samplingParams = std::make_shared(); samplingParams->normalizeLogProbs = mSamplingConfig.normalizeLogProbs; @@ -163,6 +163,14 @@ void GptDecoder::setup(SamplingConfig const& samplingConfig, size_t batchSize samplingParams->outputLogProbs = mSamplingConfig.outputLogProbs; samplingParams->cumLogProbs = mSamplingConfig.cumLogProbs; + if (mDecodingMode.isMinP()) + { + // 
TODO: This should have its own parameter! + // Also it needs shared access to the temperature settings. + samplingParams->runtimeMinP = mSamplingConfig.topP; + samplingParams->penaltyParams = setupParams->penaltyParams; + } + setupParams->decodingParams = std::move(samplingParams); } else if (mDecodingMode.isBeamSearch()) @@ -426,7 +434,7 @@ std::shared_ptr prepareInputs( TLLM_CHECK_WITH_INFO(input.batchSlots != nullptr, "Batch slots are mandatory to call the decoder."); std::shared_ptr forwardParams; - if (decodingMode.isTopKorTopP()) + if (decodingMode.isTopKorTopPorMinP()) { forwardParams = std::make_shared(input.endIds, input.batchSlots, input.step, ite, input.batchSize); diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp index 81a51f890..5541e7c38 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp @@ -419,7 +419,15 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max mDecoderStream = std::make_shared(); TLLM_CHECK(mDecoderStream->getDevice() == device); - mDecoder = IGptDecoder::create(mode, dtype, maxBatchSize, maxBeamWidth, mVocabSize, mVocabSizePadded, + // TODO: This should be passed from the executor, but I can't modify it! + // So we'll hack this to turn all TopP requests into MinP ones for now. + executor::DecodingMode modeCopy = mDecodingMode; + if (modeCopy.isTopKorTopP()) + { + modeCopy.useMinP(); + } + + mDecoder = IGptDecoder::create(modeCopy, dtype, maxBatchSize, maxBeamWidth, mVocabSize, mVocabSizePadded, mMaxSequenceLength, mDecoderStream, speculativeDecodingModulePtr); mNbSteps.clear(); From d068b0b5d84528eba00b98dc364b6d3be633b071 Mon Sep 17 00:00:00 2001 From: aikitoria Date: Tue, 31 Dec 2024 03:04:39 +0100 Subject: [PATCH 2/5] Fix numerical errors --- .../kernels/samplingMinPKernels.cu | 52 +++++++++++++++---- 1 file changed, 43 insertions(+), 9 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu b/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu index 08e6c4ac9..262b1b6cf 100644 --- a/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu @@ -27,6 +27,8 @@ using namespace tensorrt_llm::common; using namespace tensorrt_llm::runtime; +#define DEBUG_MINP 0 + namespace tensorrt_llm::kernels { template @@ -41,6 +43,13 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType auto const batchId = static_cast(blockIdx.x); auto const batchSlot = batchSlots[batchId]; +#if DEBUG_MINP + if (tid == 0) + { + printf("Begin batch slot %d sequence length %d\n", batchSlot, sequenceLengths[batchSlot]); + } +#endif + // Skip kernel if this sampling method is not chosen FinishedState const finishState = finishedInput != nullptr ? finishedInput[batchSlot] : FinishedState::empty(); if (finishState.isSkipDecoding()) @@ -57,6 +66,10 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType { finishedOutput[batchSlot] = finishState; } + +#if DEBUG_MINP + printf("Batch slot %d already finished\n", batchSlot); +#endif } return; } @@ -75,13 +88,15 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType // Find global maximum probability across all threads in block threadMax = blockReduceMax(threadMax); - __shared__ float sMaxP; __shared__ float sCutoffP; if (tid == 0) { - sMaxP = threadMax; - sCutoffP = sMaxP * (minPs != nullptr ? 
minPs[batchSlot] : 0.0f); + sCutoffP = threadMax * (minPs != nullptr ? minPs[batchSlot] : 0.0f); + +#if DEBUG_MINP + printf("Batch slot %d maxP %f cutoffP %f\n", batchSlot, threadMax, sCutoffP); +#endif } __syncthreads(); @@ -106,7 +121,14 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType if (tid == 0) { sAdjustedProbsSum = threadAdjustedProbsSum; - sQuantizeScaleFactor = UINT32_MAX / threadAdjustedProbsSum; + + // Do division with doubles and round down to avoid special cases like + // 4294967295 / 32768 giving us 131072 rather than the desired 131071 + sQuantizeScaleFactor = __double2float_rd((double)(UINT32_MAX - vocabSize) / (double)threadAdjustedProbsSum); + +#if DEBUG_MINP + printf("Batch slot %d adjustedProbsSum %f quantizeScaleFactor %f\n", batchSlot, threadAdjustedProbsSum, sQuantizeScaleFactor); +#endif } __syncthreads(); @@ -121,7 +143,7 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) { float prob = static_cast(adjustedProbs[idx]); - threadQuantProbsSum += static_cast(prob * sQuantizeScaleFactor); + threadQuantProbsSum += __float2uint_rd(prob * sQuantizeScaleFactor); } // Compute a global prefix sum of the quantized probabilities @@ -140,7 +162,11 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType { // Rescale uniform random val to be within the sum of quantized probabilities float randomVal = randomVals != nullptr ? randomVals[batchSlot] : curand_uniform(&curandState[batchSlot]); - sRandomPoint = static_cast(randomVal * totalQuantProbsSum); + sRandomPoint = min(__float2uint_rd(randomVal * totalQuantProbsSum), totalQuantProbsSum - 1); + +#if DEBUG_MINP + printf("Batch slot %d totalQuantProbsSum %u randomPoint %u\n", batchSlot, totalQuantProbsSum, sRandomPoint); +#endif } __syncthreads(); @@ -156,15 +182,23 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) { float prob = static_cast(adjustedProbs[idx]); - uint32_t scaledProb = static_cast(prob * sQuantizeScaleFactor); + uint32_t quantProb = __float2uint_rd(prob * sQuantizeScaleFactor); - if (sRandomPoint >= threadQuantProbsSum && sRandomPoint < threadQuantProbsSum + scaledProb) + if (sRandomPoint < threadQuantProbsSum + quantProb) { auto const selectedTokenIdx = idx - probsBeginIdx; auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot]; auto* outPtr = outputIdsPtrs == nullptr ? 
outputIds + batchSlot * maxSeqLen : outputIdsPtrs[batchSlot]; outPtr[curSeqLen] = selectedTokenIdx; +#if DEBUG_MINP + printf("Batch slot %d selected token %d original prob %f adjusted prob %f normalized %f\n", + batchSlot, selectedTokenIdx, static_cast(probs[idx]), prob, prob / sAdjustedProbsSum); + + printf("Batch slot %d thread index %d prefix %d sum %u quant prob %u\n", + batchSlot, tid, threadQuantProbsPrefix, threadQuantProbsSum, quantProb); +#endif + if (!returnAllSelectedTokens && sequenceLengths != nullptr && finishedOutput != nullptr && endIds != nullptr) { if (selectedTokenIdx == endIds[batchSlot]) @@ -181,7 +215,7 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType return; } - threadQuantProbsSum += scaledProb; + threadQuantProbsSum += quantProb; } } From 2eeb6038a7602501882ea3dab9549993ad7cab4e Mon Sep 17 00:00:00 2001 From: aikitoria Date: Tue, 31 Dec 2024 03:25:16 +0100 Subject: [PATCH 3/5] Improve debug --- .../kernels/samplingMinPKernels.cu | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu b/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu index 262b1b6cf..d2abc3fb7 100644 --- a/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu @@ -100,6 +100,28 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType } __syncthreads(); +#if DEBUG_MINP + // Print how many probabilities are above the cutoff + int threadNumProbsAboveCutoff = 0; + + #pragma unroll + for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) + { + float prob = static_cast(probs[idx]); + if (prob >= sCutoffP) + { + threadNumProbsAboveCutoff++; + } + } + + threadNumProbsAboveCutoff = blockReduceSum(threadNumProbsAboveCutoff); + + if (tid == 0) + { + printf("Batch slot %d numProbsAboveCutoff %d\n", batchSlot, threadNumProbsAboveCutoff); + } +#endif + // Adjust the probabilities and cache them float threadAdjustedProbsSum = 0.0f; float invTemp = 1.0f / (temperatures != nullptr ? temperatures[batchSlot] : 1.0f); @@ -195,7 +217,7 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType printf("Batch slot %d selected token %d original prob %f adjusted prob %f normalized %f\n", batchSlot, selectedTokenIdx, static_cast(probs[idx]), prob, prob / sAdjustedProbsSum); - printf("Batch slot %d thread index %d prefix %d sum %u quant prob %u\n", + printf("Batch slot %d thread index %d prefix %u sum %u quant prob %u\n", batchSlot, tid, threadQuantProbsPrefix, threadQuantProbsSum, quantProb); #endif From ce175b5c3184050d7054645b1101b79f2894f731 Mon Sep 17 00:00:00 2001 From: aikitoria Date: Fri, 3 Jan 2025 01:03:55 +0100 Subject: [PATCH 4/5] Remove quantized prefix sum in favor of BlockShuffle, remove cached scaled probs --- .../kernels/samplingMinPKernels.cu | 171 +++++++++--------- 1 file changed, 82 insertions(+), 89 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu b/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu index d2abc3fb7..f616886c8 100644 --- a/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu +++ b/cpp/tensorrt_llm/kernels/samplingMinPKernels.cu @@ -15,8 +15,10 @@ #error CUDART_VERSION Undefined! #elif (CUDART_VERSION >= 11050) #include +#include // Why is it not in the monolithic cub.cuh? #else #include "3rdparty/cub/cub.cuh" +#include "3rdparty/cub/block/block_shuffle.cuh" // Why is it not in the monolithic cub.cuh? 
#endif #include "tensorrt_llm/common/cudaUtils.h" @@ -31,12 +33,22 @@ using namespace tensorrt_llm::runtime; namespace tensorrt_llm::kernels { +// Shared state for MinP sampling kernels template -__global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType* outputIds, - TokenIdType** outputIdsPtrs, SizeType32* sequenceLengths, FinishedState const* finishedInput, - FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs, SizeType32 vocabSize, - curandState_t* curandState, float const* randomVals, float const* minPs, float const* temperatures, - TokenIdType const* endIds, SizeType32 maxBatchSize, SizeType32 const* batchSlots, bool returnAllSelectedTokens, +struct BlockScanShuffleStorage +{ + union { + typename cub::BlockScan::TempStorage scan; + typename cub::BlockShuffle::TempStorage shuffle; + }; +}; + +template +__global__ void fusedMinPSsampling(T const* probs, TokenIdType* outputIds, TokenIdType** outputIdsPtrs, + SizeType32* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, + float* cumLogProbs, float* outputLogProbs, SizeType32 vocabSize, curandState_t* curandState, + float const* randomVals, float const* minPs, float const* temperatures, TokenIdType const* endIds, + SizeType32 maxBatchSize, SizeType32 const* batchSlots, bool returnAllSelectedTokens, SizeType32 maxSeqLen, TokenIdType* outputIdCurrentStep, bool const* skipOutputIdCurrentStep) { auto const tid = static_cast(threadIdx.x); @@ -74,28 +86,35 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType return; } - // Each thread computes local maximum across its assigned probabilities - float threadMax = -FLT_MAX; + // Common stride for all arrays const int probsBeginIdx = batchId * vocabSize; const int probsEndIdx = (batchId + 1) * vocabSize; + // Each thread computes local maximum across its assigned probabilities + float threadMaxProb = -FLT_MAX; + #pragma unroll for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) { - float prob = static_cast(probs[idx]); - threadMax = max(threadMax, prob); + auto const prob = static_cast(probs[idx]); + threadMaxProb = max(threadMaxProb, prob); } // Find global maximum probability across all threads in block - threadMax = blockReduceMax(threadMax); + threadMaxProb = blockReduceMax(threadMaxProb); __shared__ float sCutoffP; + __shared__ float sInvTemp; if (tid == 0) { - sCutoffP = threadMax * (minPs != nullptr ? minPs[batchSlot] : 0.0f); + // Probs below this value will be ignored + sCutoffP = threadMaxProb * (minPs != nullptr ? minPs[batchSlot] : 0.0f); + + // Inverse temperature for scaling probabilities + sInvTemp = 1.0f / (temperatures != nullptr ? temperatures[batchSlot] : 1.0f); #if DEBUG_MINP - printf("Batch slot %d maxP %f cutoffP %f\n", batchSlot, threadMax, sCutoffP); + printf("Batch slot %d maxP %f cutoffP %f\n", batchSlot, threadMaxProb, sCutoffP); #endif } __syncthreads(); @@ -107,8 +126,7 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType #pragma unroll for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) { - float prob = static_cast(probs[idx]); - if (prob >= sCutoffP) + if (static_cast(probs[idx]) >= sCutoffP) { threadNumProbsAboveCutoff++; } @@ -122,103 +140,85 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType } #endif - // Adjust the probabilities and cache them - float threadAdjustedProbsSum = 0.0f; - float invTemp = 1.0f / (temperatures != nullptr ? 
temperatures[batchSlot] : 1.0f); - - #pragma unroll - for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) - { - float prob = static_cast(probs[idx]); - prob = (prob < sCutoffP) ? 0.0f : powf(prob, invTemp); - adjustedProbs[idx] = static_cast(prob); - threadAdjustedProbsSum += prob; - } - - // Find global sum of adjusted probabilities and determine quantization scale factor - threadAdjustedProbsSum = blockReduceSum(threadAdjustedProbsSum); - __shared__ float sAdjustedProbsSum; - __shared__ float sQuantizeScaleFactor; - - if (tid == 0) - { - sAdjustedProbsSum = threadAdjustedProbsSum; - - // Do division with doubles and round down to avoid special cases like - // 4294967295 / 32768 giving us 131072 rather than the desired 131071 - sQuantizeScaleFactor = __double2float_rd((double)(UINT32_MAX - vocabSize) / (double)threadAdjustedProbsSum); - -#if DEBUG_MINP - printf("Batch slot %d adjustedProbsSum %f quantizeScaleFactor %f\n", batchSlot, threadAdjustedProbsSum, sQuantizeScaleFactor); -#endif - } - __syncthreads(); - - // We will now quantize the probabilities to integers to avoid numerical errors - // when trying to find the selected point in the prefix sum of the probabilities. - // We map the adjusted distribution between [0, UINT32_MAX] to avoid overflow. - - // Compute the sum of the quantized probabilities for each thread - uint32_t threadQuantProbsSum = 0; + // Adjust the probabilities and sum the ones passing the cutoff + float threadScaledProbsSum = 0.f; #pragma unroll for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) { - float prob = static_cast(adjustedProbs[idx]); - threadQuantProbsSum += __float2uint_rd(prob * sQuantizeScaleFactor); + auto const prob = static_cast(probs[idx]); + auto const scaledProb = (prob < sCutoffP) ? 0.0f : powf(prob, sInvTemp); + threadScaledProbsSum += scaledProb; } - // Compute a global prefix sum of the quantized probabilities - uint32_t threadQuantProbsPrefix; - uint32_t totalQuantProbsSum; + // Find global sum and prefix sum of adjusted probabilities + using BlockScan = cub::BlockScan; + using BlockShuffle = cub::BlockShuffle; + __shared__ BlockScanShuffleStorage tempStorage; - using BlockScan = cub::BlockScan; - __shared__ typename BlockScan::TempStorage tempStorage; + float threadScaledProbsIncl = 0.f; + float threadScaledProbsExcl = 0.f; + float scaledProbsSum = 0.f; - BlockScan(tempStorage).ExclusiveSum(threadQuantProbsSum, threadQuantProbsPrefix, totalQuantProbsSum); + BlockScan(tempStorage.scan).InclusiveSum(threadScaledProbsSum, threadScaledProbsIncl, scaledProbsSum); + __syncthreads(); // We are aliasing the shared memory + BlockShuffle(tempStorage.shuffle).Offset(threadScaledProbsIncl, threadScaledProbsExcl, -1); // Select a random point in the distribution - __shared__ uint32_t sRandomPoint; + __shared__ float sRandomPoint; if (tid == 0) { - // Rescale uniform random val to be within the sum of quantized probabilities + // Rescale uniform random val to be within the sum of included adjusted probabilities float randomVal = randomVals != nullptr ? 
randomVals[batchSlot] : curand_uniform(&curandState[batchSlot]); - sRandomPoint = min(__float2uint_rd(randomVal * totalQuantProbsSum), totalQuantProbsSum - 1); + sRandomPoint = randomVal * scaledProbsSum; #if DEBUG_MINP - printf("Batch slot %d totalQuantProbsSum %u randomPoint %u\n", batchSlot, totalQuantProbsSum, sRandomPoint); + printf("Batch slot %d scaledProbsSum %f randomPoint %f\n", batchSlot, scaledProbsSum, sRandomPoint); #endif } __syncthreads(); - // All but one warps will terminate on this condition - if (sRandomPoint < threadQuantProbsPrefix || sRandomPoint >= threadQuantProbsPrefix + threadQuantProbsSum) + // All but one warps will reliably terminate on this condition + if (sRandomPoint < threadScaledProbsExcl || sRandomPoint >= threadScaledProbsIncl) { return; } + // Convert global random point to local range of current thread + float randomLocalOffset = sRandomPoint - threadScaledProbsExcl; + float randomLocalScalar = randomLocalOffset / (threadScaledProbsIncl - threadScaledProbsExcl); + float randomLocalPoint = randomLocalScalar * threadScaledProbsSum; + +#if DEBUG_MINP + printf("Batch slot %d threadScaledProbsExcl %f threadScaledProbsIncl %f threadScaledProbsSum %f\n", + batchSlot, threadScaledProbsExcl, threadScaledProbsIncl, threadScaledProbsSum); + + printf("Batch slot %d randomLocalOffset %f randomLocalScalar %f randomLocalPoint %f\n", + batchSlot, randomLocalOffset, randomLocalScalar, randomLocalPoint); +#endif + // Find the selected token id and write it to the output buffer - threadQuantProbsSum = threadQuantProbsPrefix; + threadScaledProbsSum = 0.f; + + auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot]; + auto* outPtr = outputIdsPtrs == nullptr ? outputIds + batchSlot * maxSeqLen : outputIdsPtrs[batchSlot]; for (int idx = probsBeginIdx + tid; idx < probsEndIdx; idx += THREADBLOCK_SIZE) { - float prob = static_cast(adjustedProbs[idx]); - uint32_t quantProb = __float2uint_rd(prob * sQuantizeScaleFactor); + auto const prob = static_cast(probs[idx]); + auto const scaledProb = (prob < sCutoffP) ? 0.0f : powf(prob, sInvTemp); + threadScaledProbsSum += scaledProb; - if (sRandomPoint < threadQuantProbsSum + quantProb) + // We are summing again in the same order, so this is guaranteed to be entered + if (randomLocalPoint < threadScaledProbsSum) { auto const selectedTokenIdx = idx - probsBeginIdx; - auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot]; - auto* outPtr = outputIdsPtrs == nullptr ? 
outputIds + batchSlot * maxSeqLen : outputIdsPtrs[batchSlot]; outPtr[curSeqLen] = selectedTokenIdx; #if DEBUG_MINP - printf("Batch slot %d selected token %d original prob %f adjusted prob %f normalized %f\n", - batchSlot, selectedTokenIdx, static_cast(probs[idx]), prob, prob / sAdjustedProbsSum); - - printf("Batch slot %d thread index %d prefix %u sum %u quant prob %u\n", - batchSlot, tid, threadQuantProbsPrefix, threadQuantProbsSum, quantProb); + printf("Batch slot %d selected token %d original prob %f scaled prob %f normalized %f\n", + batchSlot, selectedTokenIdx, prob, scaledProb, scaledProb / scaledProbsSum); #endif if (!returnAllSelectedTokens && sequenceLengths != nullptr && finishedOutput != nullptr && endIds != nullptr) @@ -236,17 +236,16 @@ __global__ void fusedMinPSsampling(T const* probs, T* adjustedProbs, TokenIdType } return; } - - threadQuantProbsSum += quantProb; } + + // This should never be reached + outPtr[curSeqLen] = vocabSize - 1; } template std::vector getMinPWorkspaceSizes(SizeType32 batchSize, SizeType32 vocabSize) { - auto const adjustedProbBufSize = sizeof(T) * batchSize * vocabSize; - - return {adjustedProbBufSize}; + return {}; } template std::vector getMinPWorkspaceSizes(SizeType32 batchSize, SizeType32 vocabSize); @@ -283,17 +282,11 @@ void invokeBatchMinPSampling(MinPSamplingKernelParams const& params, cudaStre TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); params.checkParams(); - auto const workspaceSizes = getMinPWorkspaceSizes(params.batchSize, params.vocabSizePadded); - - std::vector alignedPointers; - calcAlignedPointers(alignedPointers, params.workspace, workspaceSizes); - - auto adjustedProbs = static_cast(alignedPointers[0]); // Sample with Min P filter and late temperature in single pass SizeType32 constexpr SAMPLING_BLOCK_SIZE = 1024; dim3 grid(params.batchSize); - fusedMinPSsampling<<>>(params.probs, adjustedProbs, + fusedMinPSsampling<<>>(params.probs, params.outputIds, params.outputIdsPtrs, params.sequenceLength, params.finishedInput, params.finishedOutput, params.cumLogProbs, params.outputLogProbs, params.vocabSizePadded, params.curandState, params.randomVals, params.minPs, params.temperatures, params.endIds, params.maxBatchSize, params.batchSlots, params.returnAllSelectedTokens, From 19945946ce94b4ca9582fb29b676cc80f5e5a978 Mon Sep 17 00:00:00 2001 From: aikitoria Date: Mon, 13 Jan 2025 20:02:37 +0100 Subject: [PATCH 5/5] Fix batch input --- cpp/tensorrt_llm/layers/minPSamplingLayer.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/layers/minPSamplingLayer.cpp b/cpp/tensorrt_llm/layers/minPSamplingLayer.cpp index 848322a6d..977e366ee 100644 --- a/cpp/tensorrt_llm/layers/minPSamplingLayer.cpp +++ b/cpp/tensorrt_llm/layers/minPSamplingLayer.cpp @@ -113,7 +113,7 @@ void MinPSamplingLayer::setup(SizeType32 batchSize, SizeType32 beamWidth, Ten { auto initWorkspaceSizes = getMinPInitWorkspaceSizes(batchSize); std::vector alignedPointers; - calcAlignedPointers(workspace->getRawWorkspaceDevicePtr(), initWorkspaceSizes)(MinPsPtr); + calcAlignedPointers(workspace->getRawWorkspaceDevicePtr(), initWorkspaceSizes)(MinPsPtr, TemperaturesPtr); DecodingLayerWorkspace::copyToWorkspace( *mBufferManager, runtimeMinP, IBuffer::wrap(MinPsPtr, initWorkspaceSizes[0] / sizeof(*MinPsPtr)));
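
The selection strategy introduced in "Remove quantized prefix sum in favor of BlockShuffle" can be read as a two-level search: each thread sums the filtered probabilities of its strided slice of the vocabulary, an inclusive BlockScan followed by a BlockShuffle Offset(-1) gives every thread its [exclusive, inclusive) interval of the block-wide prefix sum, the single thread whose interval contains the random point re-walks its slice, and because it accumulates the same values in the same order the walk is guaranteed to terminate on a token. A host-side sketch of that logic follows (illustrative only; the function name and explicit thread count are not part of the patch):

#include <vector>

// Two-level search over filtered probabilities, assuming randomPoint lies in [0, total sum).
int selectTokenTwoLevel(std::vector<float> const& scaledProbs, float randomPoint, int numThreads)
{
    int const vocab = static_cast<int>(scaledProbs.size());

    // Level 1: per-thread sums over strided slices (thread t owns tokens t, t + numThreads, ...).
    std::vector<float> threadSum(numThreads, 0.0f);
    for (int t = 0; t < numThreads; ++t)
    {
        for (int idx = t; idx < vocab; idx += numThreads)
        {
            threadSum[t] += scaledProbs[idx];
        }
    }

    // Level 2: scan the per-thread sums; the owning thread re-walks its slice.
    // The kernel additionally rescales the local offset to guard against scan rounding.
    float exclusive = 0.0f;
    for (int t = 0; t < numThreads; ++t)
    {
        float const inclusive = exclusive + threadSum[t];
        if (randomPoint >= exclusive && randomPoint < inclusive)
        {
            float const localPoint = randomPoint - exclusive;
            float runningSum = 0.0f;
            for (int idx = t; idx < vocab; idx += numThreads)
            {
                runningSum += scaledProbs[idx];
                if (localPoint < runningSum)
                {
                    return idx;
                }
            }
            return vocab - 1; // matches the kernel's "should never be reached" fallback
        }
        exclusive = inclusive;
    }
    return vocab - 1;
}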