@@ -28,17 +28,7 @@ struct llama_grammar * llama_cached_parse_grammar(const char * grammar_str) {
28
28
}
29
29
30
30
struct llama_sampler_params llama_sampler_default_params () {
31
- struct llama_sampler_params result = {
32
- 0 .80f , // temp;
33
- 1 .10f , // repeat_penalty
34
- 64 , // last_n_repeat
35
- 0 .00f , // frequency_penalty
36
- 0 .00f , // presence_penalty
37
- 2 , // mirostat
38
- 5 .00f , // mirostat_tau
39
- 0 .10f , // mirostat_eta
40
- };
41
- return result;
31
+ return llama_sampler_params ();
42
32
}
43
33
44
34
llama_token llama_grammar_sample_token (struct llama_context * ctx,
@@ -66,8 +56,14 @@ llama_token llama_grammar_sample_token(struct llama_context * ctx,
66
56
const int mirostat = params.mirostat ;
67
57
const float mirostat_tau = params.mirostat_tau ;
68
58
const float mirostat_eta = params.mirostat_eta ;
59
+ const int32_t top_k = params.top_k <= 0 ? llama_n_vocab (llama_get_model (ctx)) : params.top_k ;
60
+ const float top_p = params.top_p ;
61
+ const float tfs_z = params.tfs_z ;
62
+ const float typical_p = params.typical_p ;
63
+ const int32_t n_probs = params.n_probs ;
64
+
69
65
70
- llama_token id = 0 ;
66
+ llama_token result = - 1 ;
71
67
72
68
// apply penalties
73
69
if (!last_tokens.empty ()) {
@@ -88,27 +84,37 @@ llama_token llama_grammar_sample_token(struct llama_context * ctx,
88
84
89
85
if (temp <= 0 ) {
90
86
// Greedy sampling
91
- id = llama_sample_token_greedy (ctx, cur_p);
87
+ result = llama_sample_token_greedy (ctx, cur_p);
92
88
} else {
93
89
if (mirostat == 1 ) {
94
90
static float mirostat_mu = 2 .0f * mirostat_tau;
95
91
const int mirostat_m = 100 ;
96
- llama_sample_temperature (ctx, cur_p, temp);
97
- id = llama_sample_token_mirostat (ctx, cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
92
+ llama_sample_temp (ctx, cur_p, temp);
93
+ result = llama_sample_token_mirostat (ctx, cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
98
94
} else if (mirostat == 2 ) {
99
95
static float mirostat_mu = 2 .0f * mirostat_tau;
100
- llama_sample_temperature (ctx, cur_p, temp);
101
- id = llama_sample_token_mirostat_v2 (ctx, cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
96
+ llama_sample_temp (ctx, cur_p, temp);
97
+ result = llama_sample_token_mirostat_v2 (ctx, cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
98
+ } else {
99
+ // Temperature sampling
100
+ size_t min_keep = std::max (1 , n_probs);
101
+ llama_sample_top_k (ctx, cur_p, top_k, min_keep);
102
+ llama_sample_tail_free (ctx, cur_p, tfs_z, min_keep);
103
+ llama_sample_typical (ctx, cur_p, typical_p, min_keep);
104
+ llama_sample_top_p (ctx, cur_p, top_p, min_keep);
105
+ llama_sample_temp (ctx, cur_p, temp);
106
+ result = llama_sample_token (ctx, cur_p);
102
107
}
103
108
}
109
+
104
110
// printf("`%d`", candidates_p.size);
105
111
106
112
if (grammar != NULL ) {
107
- llama_grammar_accept_token (ctx, grammar, id );
113
+ llama_grammar_accept_token (ctx, grammar, result );
108
114
}
109
115
110
116
last_tokens.erase (last_tokens.begin ());
111
- last_tokens.push_back (id );
117
+ last_tokens.push_back (result );
112
118
113
- return id ;
119
+ return result ;
114
120
}
0 commit comments