Commit 22cedad

xu-song and abetlen authored
fix: Fix memory allocation of ndarray (#1704)
* Fix memory allocation of ndarray
* Add basic LlamaState tests
* Improve LlamaState test and fix rng / seed

Co-authored-by: Andrei <[email protected]>
1 parent 9b64bb5 commit 22cedad
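
This change makes sampling reproducible across save_state() / load_state(): the seed now lives on the Llama object and is carried inside LlamaState, so restoring a saved state also restores the sampler's randomness. A minimal usage sketch in the spirit of the new test below, assuming a local GGUF model (the model path and generation parameters here are placeholders, not part of this commit):

    import llama_cpp

    model = llama_cpp.Llama(model_path="./models/model.gguf")  # placeholder path
    model.set_seed(1337)        # stored on the model and used by the sampler
    state = model.save_state()  # the LlamaState now carries the seed as well

    first = model.create_completion("Pick a number from 1 to 10?:\n", max_tokens=4)
    second = model.create_completion("Pick a number from 1 to 10?:\n", max_tokens=4)
    # first and second generally differ: each call made without an explicit seed
    # derives a fresh sampling seed from the stored one.

    model.load_state(state)     # restores scores, input_ids, n_tokens, and the seed
    third = model.create_completion("Pick a number from 1 to 10?:\n", max_tokens=4)
    # third matches first, because the restored seed replays the same sampling.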

File tree

2 files changed: +64 −15 lines changed


llama_cpp/llama.py

+19 −15
@@ -7,6 +7,7 @@
 import json
 import ctypes
 import typing
+import random
 import fnmatch
 import warnings
 import contextlib
@@ -301,9 +302,11 @@ def __init__(
         self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
         self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
 
+        # Used by the sampler
+        self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED
+
         # Context Params
         self.context_params = llama_cpp.llama_context_default_params()
-        self.context_params.seed = seed
         self.context_params.n_ctx = n_ctx
         self.context_params.n_batch = self.n_batch
         self.context_params.n_threads = self.n_threads
@@ -613,8 +616,7 @@ def set_seed(self, seed: int):
         Args:
             seed: The random seed.
         """
-        # TODO: Fix this
-        # llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed)
+        self._seed = seed
 
     def reset(self):
         """Reset the model state."""
@@ -672,7 +674,6 @@ def _init_sampler(
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ):
         sampler = internals.LlamaSampler()
 
@@ -715,22 +716,22 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
 
         if temp < 0.0:
             sampler.add_softmax()
-            sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+            sampler.add_dist(self._seed)
         elif temp == 0.0:
             sampler.add_greedy()
         else:
             if mirostat_mode == 1:
                 mirostat_m = 100
                 sampler.add_mirostat(
                     self._n_vocab,
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                     mirostat_m,
                 )
             elif mirostat_mode == 2:
                 sampler.add_mirostat_v2(
-                    seed or llama_cpp.LLAMA_DEFAULT_SEED,
+                    self._seed,
                     mirostat_tau,
                     mirostat_eta,
                 )
@@ -743,7 +744,7 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
                 sampler.add_top_p(top_p, min_keep)
                 sampler.add_min_p(min_p, min_keep)
                 sampler.add_temp(temp)
-                sampler.add_dist(seed or llama_cpp.LLAMA_DEFAULT_SEED)
+                sampler.add_dist(self._seed)
         return sampler
 
     def sample(
@@ -826,7 +827,6 @@ def generate(
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         grammar: Optional[LlamaGrammar] = None,
-        seed: Optional[int] = None,
     ) -> Generator[int, Optional[Sequence[int]], None]:
         """Create a generator of tokens from a prompt.
 
@@ -865,7 +865,6 @@ def generate(
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         )
 
         # Check for kv cache prefix match
@@ -1301,9 +1300,10 @@ def logit_bias_processor(
                 if self.verbose:
                     print("Llama._create_completion: cache miss", file=sys.stderr)
 
-        # TODO: Fix this
-        # if seed is not None:
-        #     self._ctx.set_rng_seed(seed)
+        if seed is not None:
+            self.set_seed(seed)
+        else:
+            self.set_seed(random.Random(self._seed).randint(0, 2 ** 32))
 
         finish_reason = "length"
         multibyte_fix = 0
@@ -1324,7 +1324,6 @@ def logit_bias_processor(
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
             grammar=grammar,
-            seed=seed,
         ):
             if llama_cpp.llama_token_is_eog(self._model.model, token):
                 text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
@@ -2136,14 +2135,17 @@ def save_state(self) -> LlamaState:
             n_tokens=self.n_tokens,
             llama_state=bytes(llama_state_compact),
             llama_state_size=n_bytes,
+            seed=self._seed,
         )
 
     def load_state(self, state: LlamaState) -> None:
         # Only filling in up to `n_tokens` and then zero-ing out the rest
         self.scores[: state.n_tokens, :] = state.scores.copy()
-        self.scores[state.n_tokens :, :] = 0.0
+        rest = self.scores[state.n_tokens :, :]
+        rest[rest > 0] = 0.0
         self.input_ids = state.input_ids.copy()
         self.n_tokens = state.n_tokens
+        self._seed = state.seed
         state_size = state.llama_state_size
         LLamaStateArrayType = ctypes.c_uint8 * state_size
         llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
@@ -2321,12 +2323,14 @@ def __init__(
         n_tokens: int,
         llama_state: bytes,
         llama_state_size: int,
+        seed: int,
     ):
         self.input_ids = input_ids
         self.scores = scores
         self.n_tokens = n_tokens
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
+        self.seed = seed
 
 
 LogitsProcessor = Callable[

tests/test_llama.py

+45
@@ -171,3 +171,48 @@ def logit_processor_func(input_ids, logits):
         logits_processor=logit_processors
     )
     assert output["choices"][0]["text"].lower().startswith("rot")
+
+    model.set_seed(1337)
+
+    state = model.save_state()
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_1 = output["choices"][0]["text"]
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_2 = output["choices"][0]["text"]
+
+    model.load_state(state)
+
+    output = model.create_completion(
+        "Pick a number from 1 to 10?:\n",
+        max_tokens=4,
+        top_k=50,
+        top_p=0.9,
+        temperature=0.8,
+        grammar=llama_cpp.LlamaGrammar.from_string("""
+root ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" | "10"
+""")
+    )
+    number_3 = output["choices"][0]["text"]
+
+    assert number_1 != number_2
+    assert number_1 == number_3
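
The final assertions depend on the new seed rollover in _create_completion: when no per-call seed is passed, the next sampling seed is derived deterministically from the stored one, so consecutive completions differ while load_state() makes the sequence replayable. A standalone sketch of that derivation, with illustrative variable names that are not part of the library:

    import random

    stored_seed = 1337  # what the model holds after set_seed(1337) or load_state(state)

    # Seed handed to the sampler for a create_completion() call made without
    # an explicit seed argument:
    derived_seed = random.Random(stored_seed).randint(0, 2 ** 32)

    # The derived value then becomes the stored seed, so the next call derives a
    # different one; restoring a saved state restores stored_seed and therefore
    # replays the same derived_seed (and the same completion).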
