Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Min-P sampling and late temperature adjustment as a fused sampling layer #2643

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cpp/include/tensorrt_llm/executor/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,11 @@ class DecodingMode
return DecodingMode{kTopKTopP | kUsePenalties | kUseBanTokens | kStandardStopCriteria};
}

static auto constexpr MinP()
{
return DecodingMode{kMinP | (kUsePenalties & ~kUseTemperature) | kUseBanTokens | kStandardStopCriteria};
}

static auto constexpr BeamSearch()
{
return DecodingMode{kBeamSearch | kUsePenalties | kUseBanTokens | kStandardStopCriteria};
Expand Down Expand Up @@ -612,6 +617,12 @@ class DecodingMode
return *this;
}

auto constexpr useMinP()
{
mState = kMinP | (mState & ~kTopKTopP & ~kUseTemperature);
return *this;
}

[[nodiscard]] bool constexpr isAuto() const
{
return anyBitSet(kAuto);
Expand All @@ -637,6 +648,16 @@ class DecodingMode
return allBitSet(kTopKTopP);
}

[[nodiscard]] bool constexpr isMinP() const
{
return anyBitSet(kMinP);
}

[[nodiscard]] bool constexpr isTopKorTopPorMinP() const
{
return anyBitSet(kTopKTopPMinP);
}

[[nodiscard]] bool constexpr isBeamSearch() const
{
return anyBitSet(kBeamSearch);
Expand Down Expand Up @@ -783,6 +804,8 @@ class DecodingMode
static UnderlyingType constexpr kExternalDraftTokens{1u << (kNumFlags + 7)};
static UnderlyingType constexpr kEagle{1u << (kNumFlags + 8)};
static UnderlyingType constexpr kTopKTopP{kTopK | kTopP};
static UnderlyingType constexpr kMinP{1u << (kNumFlags + 9)};
static UnderlyingType constexpr kTopKTopPMinP{kTopK | kTopP | kMinP};

[[nodiscard]] bool constexpr anyBitSet(UnderlyingType bits) const
{
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/tensorrt_llm/layers/defaultDecodingParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ class DefaultDecodingParams
return 1.0e-6f;
}

[[nodiscard]] __host__ __device__ static constexpr float getMinP()
{
return 0.0f;
}

[[nodiscard]] __host__ __device__ static constexpr runtime::TokenIdType getTopPResetId()
{
return -1;
Expand Down
Loading