 import json
 import ctypes
 import typing
+import random
 import fnmatch
 import warnings
 import contextlib
@@ -301,9 +302,11 @@ def __init__(
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()

+        # Used by the sampler
+        self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED
+
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
-        self.context_params.seed = seed
         self.context_params.n_ctx = n_ctx
         self.context_params.n_batch = self.n_batch
         self.context_params.n_threads = self.n_threads
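The seed moves off llama_cpp's context parameters and onto the wrapper instance, since the reworked llama.cpp sampling API attaches the seed to individual sampler stages rather than to the context. One subtlety of the `or` fallback: 0 is falsy in Python, so an explicit seed of 0 also resolves to the default. A minimal standalone sketch of that fallback, with `LLAMA_DEFAULT_SEED` stubbed to llama.cpp's value of 0xFFFFFFFF:

LLAMA_DEFAULT_SEED = 0xFFFFFFFF  # stub of llama.cpp's default-seed constant

def resolve_seed(seed):
    # Mirrors `seed or LLAMA_DEFAULT_SEED` from the diff.
    return seed or LLAMA_DEFAULT_SEED

assert resolve_seed(None) == LLAMA_DEFAULT_SEED
assert resolve_seed(42) == 42
assert resolve_seed(0) == LLAMA_DEFAULT_SEED  # 0 is falsy, so it can never be chosen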
@@ -613,8 +616,7 @@ def set_seed(self, seed: int):
         Args:
             seed: The random seed.
         """
-        # TODO: Fix this
-        # llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed)
+        self._seed = seed

     def reset(self):
         """Reset the model state."""
@@ -672,7 +674,6 @@ def _init_sampler(
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ):
         sampler = internals.LlamaSampler()

@@ -715,22 +716,22 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):

         if temp < 0.0:
             sampler.add_softmax()
-            sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+            sampler.add_dist(self._seed)
         elif temp == 0.0:
             sampler.add_greedy()
         else:
             if mirostat_mode == 1:
                 mirostat_m = 100
                 sampler.add_mirostat(
                     self._n_vocab,
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                     mirostat_m,
                 )
             elif mirostat_mode == 2:
                 sampler.add_mirostat_v2(
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                 )
@@ -743,7 +744,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             sampler.add_top_p(top_p, min_keep)
             sampler.add_min_p(min_p, min_keep)
             sampler.add_temp(temp)
-            sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+            sampler.add_dist(self._seed)
         return sampler

     def sample(
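Every stage that consumes randomness (`add_dist` and the two mirostat variants) now reads the instance-level `self._seed`, so a single attribute governs the whole sampler chain. Below is a toy stand-in for `internals.LlamaSampler` (an assumption for illustration, not the real class) showing why a seeded distribution stage makes draws repeatable:

import random

class ToySampler:
    # Toy stand-in for internals.LlamaSampler: a seeded categorical draw.
    def __init__(self):
        self._rng_seed = None

    def add_dist(self, seed):
        self._rng_seed = seed

    def sample(self, probs):
        rng = random.Random(self._rng_seed)
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]

s1, s2 = ToySampler(), ToySampler()
s1.add_dist(7)
s2.add_dist(7)
probs = [0.1, 0.6, 0.3]
assert s1.sample(probs) == s2.sample(probs)  # same seed, same token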
@@ -826,7 +827,6 @@ def generate(
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.

@@ -865,7 +865,6 @@ def generate(
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         )

         # Check for kv cache prefix match
@@ -1301,9 +1300,10 @@ def logit_bias_processor(
         if self.verbose:
             print("Llama._create_completion: cache miss", file=sys.stderr)

-        # TODO: Fix this
-        # if seed is not None:
-        #     self._ctx.set_rng_seed(seed)
+        if seed is not None:
+            self.set_seed(seed)
+        else:
+            self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))

         finish_reason = "length"
         multibyte_fix = 0
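This hunk is the behavioral core of the change: an explicit per-call seed wins, and otherwise a fresh seed is derived deterministically from the current one. Consecutive completions therefore differ from each other while the whole sequence remains reproducible from the initial seed. A standalone sketch of the derivation (note `randint` is inclusive on both ends, so the range here is actually [0, 2**32]):

import random

def next_seed(current_seed: int) -> int:
    # Same derivation as the diff: seed a throwaway RNG with the current
    # seed and draw the next one from it.
    return random.Random(current_seed).randint(0, 2 ** 32)

seed = 1234
first = next_seed(seed)    # seed used by completion #1
second = next_seed(first)  # seed used by completion #2 (almost surely different)
assert next_seed(seed) == first  # the chain replays exactly from the same start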
@@ -1324,7 +1324,6 @@ def logit_bias_processor(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         ):
             if llama_cpp.llama_token_is_eog(self._model.model, token):
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -2136,14 +2135,17 @@ def save_state(self) -> LlamaState:
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
+            seed=self._seed,
         )

     def load_state(self, state: LlamaState) -> None:
         # Only filling in up to `n_tokens` and then zero-ing out the rest
         self.scores[: state.n_tokens, :] = state.scores.copy()
-        self.scores[state.n_tokens :, :] = 0.0
+        rest = self.scores[state.n_tokens :, :]
+        rest[rest > 0] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
+        self._seed = state.seed
         state_size = state.llama_state_size
         LLamaStateArrayType = ctypes.c_uint8 * state_size
         llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
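Carrying the seed through `LlamaState` means a restored session recovers not only the KV cache, logits, and token count but also the sampling stream. A hedged round-trip sketch (placeholder model path; prompt and parameters are illustrative):

from llama_cpp import Llama

llm = Llama(model_path="./model.gguf", seed=1234)  # placeholder path

state = llm.save_state()
out_a = llm.create_completion("Hello", max_tokens=8, temperature=0.8)

llm.load_state(state)  # restores scores, input_ids, n_tokens, and _seed
out_b = llm.create_completion("Hello", max_tokens=8, temperature=0.8)
# Same restored seed, prompt, and parameters: outputs should match token-for-token.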
@@ -2321,12 +2323,14 @@ def __init__(
         n_tokens: int,
         llama_state: bytes,
         llama_state_size: int,
+        seed: int,
     ):
         self.input_ids = input_ids
         self.scores = scores
         self.n_tokens = n_tokens
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
+        self.seed = seed


 LogitsProcessor = Callable[