@@ -43,7 +43,6 @@ def __init__(
43
43
self ._exit_stack = ExitStack ()
44
44
45
45
model = None
46
- vocab = None
47
46
48
47
if not os .path .exists (path_model ):
49
48
raise ValueError (f"Model path does not exist: { path_model } " )
@@ -58,24 +57,12 @@ def __init__(
58
57
59
58
self .model = model
60
59
61
- vocab = llama_cpp .llama_model_get_vocab (self .model )
62
-
63
- if vocab is None :
64
- raise ValueError (f"Failed to load vocab from file: { path_model } " )
65
-
66
- self .vocab = vocab
67
-
68
60
def free_model ():
69
61
if self .model is None :
70
62
return
71
63
llama_cpp .llama_model_free (self .model )
72
64
self .model = None
73
65
74
- if self .vocab is None :
75
- return
76
- llama_cpp .llama_model_free (self .vocab )
77
- self .vocab = None
78
-
79
66
self ._exit_stack .callback (free_model )
80
67
81
68
def close (self ):
@@ -84,11 +71,11 @@ def close(self):
84
71
def __del__ (self ):
85
72
self .close ()
86
73
87
- def vocab_type (self ) -> int :
88
- return llama_cpp .llama_vocab_type (self . vocab )
74
+ def vocab_type (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
75
+ return llama_cpp .llama_vocab_type (_vocab )
89
76
90
- def n_vocab (self ) -> int :
91
- return llama_cpp .llama_vocab_n_tokens (self . vocab )
77
+ def n_vocab (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
78
+ return llama_cpp .llama_vocab_n_tokens (_vocab )
92
79
93
80
def n_ctx_train (self ) -> int :
94
81
return llama_cpp .llama_model_n_ctx_train (self .model )
@@ -112,66 +99,66 @@ def n_params(self) -> int:
112
99
113
100
# Vocab
114
101
115
- def token_get_text (self , token : int ) -> str :
116
- return llama_cpp .llama_vocab_get_text (self . vocab , token ).decode ("utf-8" )
102
+ def token_get_text (self , _vocab : llama_cpp . llama_vocab_p , token : int ) -> str :
103
+ return llama_cpp .llama_vocab_get_text (_vocab , token ).decode ("utf-8" )
117
104
118
- def token_get_score (self , token : int ) -> float :
119
- return llama_cpp .llama_vocab_get_score (self . vocab , token )
105
+ def token_get_score (self , _vocab : llama_cpp . llama_vocab_p , token : int ) -> float :
106
+ return llama_cpp .llama_vocab_get_score (_vocab , token )
120
107
121
- def token_get_attr (self , token : int ) -> int :
122
- return llama_cpp .llama_vocab_get_attr (self . vocab , token )
108
+ def token_get_attr (self , _vocab : llama_cpp . llama_vocab_p , token : int ) -> int :
109
+ return llama_cpp .llama_vocab_get_attr (_vocab , token )
123
110
124
111
# Special tokens
125
112
126
- def token_bos (self ) -> int :
127
- return llama_cpp .llama_vocab_bos (self . vocab )
113
+ def token_bos (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
114
+ return llama_cpp .llama_vocab_bos (_vocab )
128
115
129
- def token_eos (self ) -> int :
130
- return llama_cpp .llama_vocab_eos (self . vocab )
116
+ def token_eos (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
117
+ return llama_cpp .llama_vocab_eos (_vocab )
131
118
132
- def token_eot (self ) -> int :
133
- return llama_cpp .llama_vocab_eot (self . vocab )
119
+ def token_eot (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
120
+ return llama_cpp .llama_vocab_eot (_vocab )
134
121
135
- def token_cls (self ) -> int :
136
- return llama_cpp .llama_vocab_cls (self . vocab )
122
+ def token_cls (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
123
+ return llama_cpp .llama_vocab_cls (_vocab )
137
124
138
- def token_sep (self ) -> int :
139
- return llama_cpp .llama_vocab_sep (self . vocab )
125
+ def token_sep (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
126
+ return llama_cpp .llama_vocab_sep (_vocab )
140
127
141
- def token_nl (self ) -> int :
142
- return llama_cpp .llama_vocab_nl (self . vocab )
128
+ def token_nl (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
129
+ return llama_cpp .llama_vocab_nl (_vocab )
143
130
144
- def token_pad (self ) -> int :
145
- return llama_cpp .llama_vocab_pad (self . vocab )
131
+ def token_pad (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
132
+ return llama_cpp .llama_vocab_pad (_vocab )
146
133
147
- def token_prefix (self ) -> int :
148
- return llama_cpp .llama_vocab_fim_pre (self . vocab )
134
+ def token_prefix (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
135
+ return llama_cpp .llama_vocab_fim_pre (_vocab )
149
136
150
- def token_middle (self ) -> int :
151
- return llama_cpp .llama_vocab_fim_mid (self . vocab )
137
+ def token_middle (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
138
+ return llama_cpp .llama_vocab_fim_mid (_vocab )
152
139
153
- def token_suffix (self ) -> int :
154
- return llama_cpp .llama_vocab_fim_suf (self . vocab )
140
+ def token_suffix (self , _vocab : llama_cpp . llama_vocab_p ) -> int :
141
+ return llama_cpp .llama_vocab_fim_suf (_vocab )
155
142
156
- def add_bos_token (self ) -> bool :
157
- return llama_cpp .llama_vocab_get_add_bos (self . vocab )
143
+ def add_bos_token (self , _vocab : llama_cpp . llama_vocab_p ) -> bool :
144
+ return llama_cpp .llama_vocab_get_add_bos (_vocab )
158
145
159
- def add_eos_token (self ) -> bool :
160
- return llama_cpp .llama_vocab_get_add_eos (self . vocab )
146
+ def add_eos_token (self , _vocab : llama_cpp . llama_vocab_p ) -> bool :
147
+ return llama_cpp .llama_vocab_get_add_eos (_vocab )
161
148
162
149
# Tokenization
163
150
164
- def tokenize (self , text : bytes , add_bos : bool , special : bool ):
151
+ def tokenize (self , _vocab : llama_cpp . llama_vocab_p , text : bytes , add_bos : bool , special : bool ):
165
152
n_ctx = self .n_ctx_train ()
166
153
tokens = (llama_cpp .llama_token * n_ctx )()
167
154
n_tokens = llama_cpp .llama_tokenize (
168
- self . vocab , text , len (text ), tokens , n_ctx , add_bos , special
155
+ _vocab , text , len (text ), tokens , n_ctx , add_bos , special
169
156
)
170
157
if n_tokens < 0 :
171
158
n_tokens = abs (n_tokens )
172
159
tokens = (llama_cpp .llama_token * n_tokens )()
173
160
n_tokens = llama_cpp .llama_tokenize (
174
- self . vocab , text , len (text ), tokens , n_tokens , add_bos , special
161
+ _vocab , text , len (text ), tokens , n_tokens , add_bos , special
175
162
)
176
163
if n_tokens < 0 :
177
164
raise RuntimeError (
@@ -618,10 +605,11 @@ def prev_str(self, ctx_main: LlamaContext, n: int) -> str:
618
605
def sample (
619
606
self ,
620
607
ctx_main : LlamaContext ,
608
+ _vocab :llama_cpp .llama_vocab_p ,
621
609
idx : int = 0 ,
622
610
logits_array : Optional [npt .NDArray [np .single ]] = None ,
623
611
):
624
- n_vocab = ctx_main .model .n_vocab ()
612
+ n_vocab = ctx_main .model .n_vocab (_vocab )
625
613
id : int = 0
626
614
627
615
if logits_array is None :
0 commit comments