Skip to content

Commit 371ecd1

Browse files
authored
Fix: per-prediction seed (#198)
1 parent d8c8547 commit 371ecd1

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

binding.cpp

+10-4
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings) {
4444
params.seed = time(NULL);
4545
}
4646

47-
std::mt19937 rng(params.seed);
47+
// no need for a rng
48+
// std::mt19937 rng(params.seed);
4849

4950
int n_past = 0;
5051

@@ -127,7 +128,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
127128
params.seed = time(NULL);
128129
}
129130

130-
std::mt19937 rng(params.seed);
131+
// no need for a rng
132+
// std::mt19937 rng(params.seed);
131133

132134
if (params.rope_freq_base != 10000.0) {
133135
fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
@@ -171,7 +173,8 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
171173
return 1;
172174
}
173175
session_tokens.resize(n_token_count_out);
174-
llama_set_rng_seed(ctx, params.seed);
176+
// no need to set the seed here --- we'll always set it later
177+
// llama_set_rng_seed(ctx, params.seed);
175178
if (debug) {
176179
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
177180
}
@@ -311,6 +314,9 @@ int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug) {
311314
llama_reset_timings(ctx);
312315
}
313316

317+
// set the seed before actually predicting
318+
llama_set_rng_seed(ctx, params.seed);
319+
314320
while (n_remain != 0) {
315321
// predict
316322
if (embd.size() > 0) {
@@ -878,4 +884,4 @@ void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f
878884
state->model= model;
879885
return state;
880886
}
881-
*/
887+
*/

examples/main.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ var (
1616
threads = 4
1717
tokens = 128
1818
gpulayers = 0
19+
seed = -1
1920
)
2021

2122
func main() {
@@ -26,6 +27,7 @@ func main() {
2627
flags.IntVar(&gpulayers, "ngl", 0, "Number of GPU layers to use")
2728
flags.IntVar(&threads, "t", runtime.NumCPU(), "number of threads to use during computation")
2829
flags.IntVar(&tokens, "n", 512, "number of tokens to predict")
30+
flags.IntVar(&seed, "s", -1, "predict RNG seed, -1 for random seed")
2931

3032
err := flags.Parse(os.Args[1:])
3133
if err != nil {
@@ -47,7 +49,7 @@ func main() {
4749
_, err := l.Predict(text, llama.Debug, llama.SetTokenCallback(func(token string) bool {
4850
fmt.Print(token)
4951
return true
50-
}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
52+
}), llama.SetTokens(tokens), llama.SetThreads(threads), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"), llama.SetSeed(seed))
5153
if err != nil {
5254
panic(err)
5355
}

options.go

+2
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ var DefaultOptions PredictOptions = PredictOptions{
9191
MirostatTAU: 5.0,
9292
MirostatETA: 0.1,
9393
MMap: true,
94+
RopeFreqBase: 10000,
95+
RopeFreqScale: 1.0,
9496
}
9597

9698
func SetMulMatQ(b bool) ModelOption {

0 commit comments

Comments
 (0)