Skip to content

Commit e5c4193

Browse files
committed
Better sampler params
1 parent 27d6181 commit e5c4193

File tree

2 files changed: +46 additions, -28 deletions (lines changed)

examples/grammar/grammar.cpp

Lines changed: 26 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -28,17 +28,7 @@ struct llama_grammar * llama_cached_parse_grammar(const char * grammar_str) {
2828
}
2929

3030
struct llama_sampler_params llama_sampler_default_params() {
31-
struct llama_sampler_params result = {
32-
0.80f, // temp;
33-
1.10f, // repeat_penalty
34-
64, // last_n_repeat
35-
0.00f, // frequency_penalty
36-
0.00f, // presence_penalty
37-
2, // mirostat
38-
5.00f, // mirostat_tau
39-
0.10f, // mirostat_eta
40-
};
41-
return result;
31+
return llama_sampler_params();
4232
}
4333

4434
llama_token llama_grammar_sample_token(struct llama_context * ctx,
@@ -66,8 +56,14 @@ llama_token llama_grammar_sample_token(struct llama_context * ctx,
6656
const int mirostat = params.mirostat;
6757
const float mirostat_tau = params.mirostat_tau;
6858
const float mirostat_eta = params.mirostat_eta;
59+
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k;
60+
const float top_p = params.top_p;
61+
const float tfs_z = params.tfs_z;
62+
const float typical_p = params.typical_p;
63+
const int32_t n_probs = params.n_probs;
64+
6965

70-
llama_token id = 0;
66+
llama_token result = -1;
7167

7268
// apply penalties
7369
if (!last_tokens.empty()) {
@@ -88,27 +84,37 @@ llama_token llama_grammar_sample_token(struct llama_context * ctx,
8884

8985
if (temp <= 0) {
9086
// Greedy sampling
91-
id = llama_sample_token_greedy(ctx, cur_p);
87+
result = llama_sample_token_greedy(ctx, cur_p);
9288
} else {
9389
if (mirostat == 1) {
9490
static float mirostat_mu = 2.0f * mirostat_tau;
9591
const int mirostat_m = 100;
96-
llama_sample_temperature(ctx, cur_p, temp);
97-
id = llama_sample_token_mirostat(ctx, cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
92+
llama_sample_temp(ctx, cur_p, temp);
93+
result = llama_sample_token_mirostat(ctx, cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
9894
} else if (mirostat == 2) {
9995
static float mirostat_mu = 2.0f * mirostat_tau;
100-
llama_sample_temperature(ctx, cur_p, temp);
101-
id = llama_sample_token_mirostat_v2(ctx, cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
96+
llama_sample_temp(ctx, cur_p, temp);
97+
result = llama_sample_token_mirostat_v2(ctx, cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
98+
} else {
99+
// Temperature sampling
100+
size_t min_keep = std::max(1, n_probs);
101+
llama_sample_top_k(ctx, cur_p, top_k, min_keep);
102+
llama_sample_tail_free(ctx, cur_p, tfs_z, min_keep);
103+
llama_sample_typical(ctx, cur_p, typical_p, min_keep);
104+
llama_sample_top_p(ctx, cur_p, top_p, min_keep);
105+
llama_sample_temp(ctx, cur_p, temp);
106+
result = llama_sample_token(ctx, cur_p);
102107
}
103108
}
109+
104110
// printf("`%d`", candidates_p.size);
105111

106112
if (grammar != NULL) {
107-
llama_grammar_accept_token(ctx, grammar, id);
113+
llama_grammar_accept_token(ctx, grammar, result);
108114
}
109115

110116
last_tokens.erase(last_tokens.begin());
111-
last_tokens.push_back(id);
117+
last_tokens.push_back(result);
112118

113-
return id;
119+
return result;
114120
}

examples/grammar/grammar.h

Lines changed: 20 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3,26 +3,38 @@
33

44
#include <string>
55
#include <vector>
6+
#include <cstdint>
67
#include <unordered_map>
78
#include <stddef.h>
89
#include <stdint.h>
910
#include <stdbool.h>
1011

12+
1113
#include "llama.h"
1214
#include "grammar-parser.h"
1315

1416
#ifdef __cplusplus
1517
extern "C" {
1618
#endif
19+
// llama_sampler.h
20+
21+
#pragma once
22+
23+
1724
struct llama_sampler_params {
18-
float temp;
19-
float repeat_penalty;
20-
int32_t repeat_last_n;
21-
float frequency_penalty;
22-
float presence_penalty;
23-
int32_t mirostat;
24-
float mirostat_tau;
25-
float mirostat_eta;
25+
float temp = 0.80f; // Temperature
26+
float repeat_penalty = 1.10f; // Penalty for repeated tokens
27+
int32_t repeat_last_n = 64; // Number of tokens to consider for repeat penalty
28+
float frequency_penalty = 0.00f; // Penalty for frequent tokens
29+
float presence_penalty = 0.00f; // Penalty for present tokens
30+
int32_t mirostat = 2; // Mirostat version (0 = disabled, 1 = mirostat, 2 = mirostat 2.0)
31+
float mirostat_tau = 5.00f; // Mirostat target entropy
32+
float mirostat_eta = 0.10f; // Mirostat learning rate
33+
int32_t top_k = 40; // Top-K for sampling
34+
float top_p = 0.95f; // Top-P for sampling
35+
float tfs_z = 1.0f; // TFS-Z value
36+
float typical_p = 1.0f; // Typical-P value
37+
int32_t n_probs = 0; // Number of probabilities to output (0 for no output)
2638
};
2739

2840
llama_sampler_params llama_sampler_default_params();

0 commit comments

Comments
 (0)