 import json
 import ctypes
 import typing
+import random
 import fnmatch
 import warnings
 import contextlib
@@ -301,9 +302,11 @@ def __init__(
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
 
+        # Used by the sampler
+        self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED
+
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
-        self.context_params.seed = seed
         self.context_params.n_ctx = n_ctx
         self.context_params.n_batch = self.n_batch
         self.context_params.n_threads = self.n_threads
@@ -613,8 +616,7 @@ def set_seed(self, seed: int):
         Args:
             seed: The random seed.
         """
-        # TODO: Fix this
-        # llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed)
+        self._seed = seed
 
     def reset(self):
         """Reset the model state."""
@@ -672,7 +674,6 @@ def _init_sampler(
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ):
         sampler = internals.LlamaSampler()
 
@@ -715,22 +716,22 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
 
         if temp < 0.0:
             sampler.add_softmax()
-            sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+            sampler.add_dist(self._seed)
         elif temp == 0.0:
             sampler.add_greedy()
         else:
             if mirostat_mode == 1:
                 mirostat_m = 100
                 sampler.add_mirostat(
                     self._n_vocab,
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                     mirostat_m,
                 )
             elif mirostat_mode == 2:
                 sampler.add_mirostat_v2(
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                 )
@@ -743,7 +744,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
                 sampler.add_top_p(top_p, min_keep)
                 sampler.add_min_p(min_p, min_keep)
                 sampler.add_temp(temp)
-                sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+                sampler.add_dist(self._seed)
         return sampler
 
     def sample(
@@ -826,7 +827,6 @@ def generate(
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -865,7 +865,6 @@ def generate(
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         )
 
         # Check for kv cache prefix match
@@ -1301,9 +1300,10 @@ def logit_bias_processor(
                 if self.verbose:
                     print("Llama._create_completion: cache miss", file=sys.stderr)
 
-        # TODO: Fix this
-        # if seed is not None:
-        #     self._ctx.set_rng_seed(seed)
+        if seed is not None:
+            self.set_seed(seed)
+        else:
+            self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
 
         finish_reason = "length"
         multibyte_fix = 0
@@ -1324,7 +1324,6 @@ def logit_bias_processor(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         ):
             if llama_cpp.llama_token_is_eog(self._model.model, token):
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -2136,14 +2135,17 @@ def save_state(self) -> LlamaState:
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
+            seed=self._seed,
         )
 
     def load_state(self, state: LlamaState) -> None:
         # Only filling in up to `n_tokens` and then zero-ing out the rest
        self.scores[: state.n_tokens, :] = state.scores.copy()
-        self.scores[state.n_tokens :, :] = 0.0
+        rest = self.scores[state.n_tokens :, :]
+        rest[rest > 0] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
+        self._seed = state.seed
         state_size = state.llama_state_size
         LLamaStateArrayType = ctypes.c_uint8 * state_size
         llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
@@ -2321,12 +2323,14 @@ def __init__(
         n_tokens: int,
         llama_state: bytes,
         llama_state_size: int,
+        seed: int,
     ):
         self.input_ids = input_ids
         self.scores = scores
         self.n_tokens = n_tokens
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
+        self.seed = seed
 
 
 LogitsProcessor = Callable[
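
Not part of the diff: a minimal usage sketch of how the reworked seed handling could be exercised through the public llama_cpp.Llama API; the model path, prompt, and seed values below are placeholders.

from llama_cpp import Llama

# The construction-time seed is now kept in self._seed rather than context_params.seed.
llm = Llama(model_path="./models/model.gguf", seed=1234)

# An explicit per-call seed reseeds the sampler deterministically for that call,
# so two identical calls should normally produce the same completion text.
a = llm.create_completion("Q: Name a color. A:", max_tokens=8, temperature=0.8, seed=42)
b = llm.create_completion("Q: Name a color. A:", max_tokens=8, temperature=0.8, seed=42)
print(a["choices"][0]["text"] == b["choices"][0]["text"])

# Without a per-call seed, a fresh seed is derived from self._seed, so repeated
# calls may differ; set_seed() now simply updates self._seed.
llm.set_seed(1234)

# The seed also round-trips through save_state()/load_state() via the new LlamaState.seed field.
state = llm.save_state()
llm.load_state(state)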