Skip to content

Commit 3ffc680

Browse files
committed
Move self._vocab to llama.py
1 parent 4dc2609 commit 3ffc680

File tree

3 files changed

+46
-56
lines changed

3 files changed

+46
-56
lines changed

llama_cpp/_internals.py

Lines changed: 39 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def __init__(
4343
self._exit_stack = ExitStack()
4444

4545
model = None
46-
vocab = None
4746

4847
if not os.path.exists(path_model):
4948
raise ValueError(f"Model path does not exist: {path_model}")
@@ -58,24 +57,12 @@ def __init__(
5857

5958
self.model = model
6059

61-
vocab = llama_cpp.llama_model_get_vocab(self.model)
62-
63-
if vocab is None:
64-
raise ValueError(f"Failed to load vocab from file: {path_model}")
65-
66-
self.vocab = vocab
67-
6860
def free_model():
6961
if self.model is None:
7062
return
7163
llama_cpp.llama_model_free(self.model)
7264
self.model = None
7365

74-
if self.vocab is None:
75-
return
76-
llama_cpp.llama_model_free(self.vocab)
77-
self.vocab = None
78-
7966
self._exit_stack.callback(free_model)
8067

8168
def close(self):
@@ -84,11 +71,11 @@ def close(self):
8471
def __del__(self):
8572
self.close()
8673

87-
def vocab_type(self) -> int:
88-
return llama_cpp.llama_vocab_type(self.vocab)
74+
def vocab_type(self, _vocab:llama_cpp.llama_vocab_p) -> int:
75+
return llama_cpp.llama_vocab_type(_vocab)
8976

90-
def n_vocab(self) -> int:
91-
return llama_cpp.llama_vocab_n_tokens(self.vocab)
77+
def n_vocab(self, _vocab:llama_cpp.llama_vocab_p) -> int:
78+
return llama_cpp.llama_vocab_n_tokens(_vocab)
9279

9380
def n_ctx_train(self) -> int:
9481
return llama_cpp.llama_model_n_ctx_train(self.model)
@@ -112,66 +99,66 @@ def n_params(self) -> int:
11299

113100
# Vocab
114101

115-
def token_get_text(self, token: int) -> str:
116-
return llama_cpp.llama_vocab_get_text(self.vocab, token).decode("utf-8")
102+
def token_get_text(self, _vocab:llama_cpp.llama_vocab_p, token: int) -> str:
103+
return llama_cpp.llama_vocab_get_text(_vocab, token).decode("utf-8")
117104

118-
def token_get_score(self, token: int) -> float:
119-
return llama_cpp.llama_vocab_get_score(self.vocab, token)
105+
def token_get_score(self, _vocab:llama_cpp.llama_vocab_p, token: int) -> float:
106+
return llama_cpp.llama_vocab_get_score(_vocab, token)
120107

121-
def token_get_attr(self, token: int) -> int:
122-
return llama_cpp.llama_vocab_get_attr(self.vocab, token)
108+
def token_get_attr(self, _vocab:llama_cpp.llama_vocab_p, token: int) -> int:
109+
return llama_cpp.llama_vocab_get_attr(_vocab, token)
123110

124111
# Special tokens
125112

126-
def token_bos(self) -> int:
127-
return llama_cpp.llama_vocab_bos(self.vocab)
113+
def token_bos(self, _vocab:llama_cpp.llama_vocab_p) -> int:
114+
return llama_cpp.llama_vocab_bos(_vocab)
128115

129-
def token_eos(self) -> int:
130-
return llama_cpp.llama_vocab_eos(self.vocab)
116+
def token_eos(self, _vocab:llama_cpp.llama_vocab_p) -> int:
117+
return llama_cpp.llama_vocab_eos(_vocab)
131118

132-
def token_eot(self) -> int:
133-
return llama_cpp.llama_vocab_eot(self.vocab)
119+
def token_eot(self, _vocab:llama_cpp.llama_vocab_p) -> int:
120+
return llama_cpp.llama_vocab_eot(_vocab)
134121

135-
def token_cls(self) -> int:
136-
return llama_cpp.llama_vocab_cls(self.vocab)
122+
def token_cls(self, _vocab:llama_cpp.llama_vocab_p) -> int:
123+
return llama_cpp.llama_vocab_cls(_vocab)
137124

138-
def token_sep(self) -> int:
139-
return llama_cpp.llama_vocab_sep(self.vocab)
125+
def token_sep(self, _vocab:llama_cpp.llama_vocab_p) -> int:
126+
return llama_cpp.llama_vocab_sep(_vocab)
140127

141-
def token_nl(self) -> int:
142-
return llama_cpp.llama_vocab_nl(self.vocab)
128+
def token_nl(self, _vocab:llama_cpp.llama_vocab_p) -> int:
129+
return llama_cpp.llama_vocab_nl(_vocab)
143130

144-
def token_pad(self) -> int:
145-
return llama_cpp.llama_vocab_pad(self.vocab)
131+
def token_pad(self, _vocab:llama_cpp.llama_vocab_p) -> int:
132+
return llama_cpp.llama_vocab_pad(_vocab)
146133

147-
def token_prefix(self) -> int:
148-
return llama_cpp.llama_vocab_fim_pre(self.vocab)
134+
def token_prefix(self, _vocab:llama_cpp.llama_vocab_p) -> int:
135+
return llama_cpp.llama_vocab_fim_pre(_vocab)
149136

150-
def token_middle(self) -> int:
151-
return llama_cpp.llama_vocab_fim_mid(self.vocab)
137+
def token_middle(self, _vocab:llama_cpp.llama_vocab_p) -> int:
138+
return llama_cpp.llama_vocab_fim_mid(_vocab)
152139

153-
def token_suffix(self) -> int:
154-
return llama_cpp.llama_vocab_fim_suf(self.vocab)
140+
def token_suffix(self, _vocab:llama_cpp.llama_vocab_p) -> int:
141+
return llama_cpp.llama_vocab_fim_suf(_vocab)
155142

156-
def add_bos_token(self) -> bool:
157-
return llama_cpp.llama_vocab_get_add_bos(self.vocab)
143+
def add_bos_token(self, _vocab:llama_cpp.llama_vocab_p) -> bool:
144+
return llama_cpp.llama_vocab_get_add_bos(_vocab)
158145

159-
def add_eos_token(self) -> bool:
160-
return llama_cpp.llama_vocab_get_add_eos(self.vocab)
146+
def add_eos_token(self, _vocab:llama_cpp.llama_vocab_p) -> bool:
147+
return llama_cpp.llama_vocab_get_add_eos(_vocab)
161148

162149
# Tokenization
163150

164-
def tokenize(self, text: bytes, add_bos: bool, special: bool):
151+
def tokenize(self, _vocab:llama_cpp.llama_vocab_p, text: bytes, add_bos: bool, special: bool):
165152
n_ctx = self.n_ctx_train()
166153
tokens = (llama_cpp.llama_token * n_ctx)()
167154
n_tokens = llama_cpp.llama_tokenize(
168-
self.vocab, text, len(text), tokens, n_ctx, add_bos, special
155+
_vocab, text, len(text), tokens, n_ctx, add_bos, special
169156
)
170157
if n_tokens < 0:
171158
n_tokens = abs(n_tokens)
172159
tokens = (llama_cpp.llama_token * n_tokens)()
173160
n_tokens = llama_cpp.llama_tokenize(
174-
self.vocab, text, len(text), tokens, n_tokens, add_bos, special
161+
_vocab, text, len(text), tokens, n_tokens, add_bos, special
175162
)
176163
if n_tokens < 0:
177164
raise RuntimeError(
@@ -618,10 +605,11 @@ def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
618605
def sample(
619606
self,
620607
ctx_main: LlamaContext,
608+
_vocab:llama_cpp.llama_vocab_p,
621609
idx: int = 0,
622610
logits_array: Optional[npt.NDArray[np.single]] = None,
623611
):
624-
n_vocab = ctx_main.model.n_vocab()
612+
n_vocab = ctx_main.model.n_vocab(_vocab)
625613
id: int = 0
626614

627615
if logits_array is None:

llama_cpp/llama.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ def __init__(
374374
)
375375
)
376376

377+
self._vocab = llama_cpp.llama_model_get_vocab(self._model.model)
378+
377379
# Override tokenizer
378380
self.tokenizer_ = tokenizer or LlamaTokenizer(self)
379381

@@ -2171,23 +2173,23 @@ def n_embd(self) -> int:
21712173

21722174
def n_vocab(self) -> int:
21732175
"""Return the vocabulary size."""
2174-
return self._model.n_vocab()
2176+
return self._model.n_vocab(self._vocab)
21752177

21762178
def tokenizer(self) -> LlamaTokenizer:
21772179
"""Return the llama tokenizer for this model."""
21782180
return LlamaTokenizer(self)
21792181

21802182
def token_eos(self) -> int:
21812183
"""Return the end-of-sequence token."""
2182-
return self._model.token_eos()
2184+
return self._model.token_eos(self._vocab)
21832185

21842186
def token_bos(self) -> int:
21852187
"""Return the beginning-of-sequence token."""
2186-
return self._model.token_bos()
2188+
return self._model.token_bos(self._vocab)
21872189

21882190
def token_nl(self) -> int:
21892191
"""Return the newline token."""
2190-
return self._model.token_nl()
2192+
return self._model.token_nl(self._vocab)
21912193

21922194
def pooling_type(self) -> str:
21932195
"""Return the pooling type."""

llama_cpp/llama_cpp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,7 @@ def llama_pooling_type(ctx: llama_context_p, /) -> int:
13021302
...
13031303

13041304
# LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
1305-
@ctypes_function("llama_model_get_vocab", [llama_model_p_ctypes], llama_vocab_p)
1305+
@ctypes_function("llama_model_get_vocab", [llama_model_p_ctypes], llama_vocab_p_ctypes)
13061306
def llama_model_get_vocab(model: llama_model_p, /) -> Optional[llama_vocab_p]:
13071307
...
13081308

0 commit comments

Comments
 (0)