2 files changed: +18 −2 lines changed
@@ -1041,7 +1041,7 @@ def embed(
         data: Union[List[List[float]], List[List[List[float]]]] = []

         def decode_batch(seq_sizes: List[int]):
-            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+            llama_cpp.llama_kv_self_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()

@@ -1112,7 +1112,7 @@ def decode_batch(seq_sizes: List[int]):

         output = data[0] if isinstance(input, str) else data

-        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+        llama_cpp.llama_kv_self_clear(self._ctx.ctx)
         self.reset()

         if return_count:
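Both hunks make the same substitution: the deprecated `llama_kv_cache_clear` binding is replaced with `llama_kv_self_clear`, tracking the upstream llama.cpp rename of the KV-cache API. A minimal compatibility sketch, assuming only that the installed `llama_cpp` bindings export one of the two symbols shown in this diff (the shim itself is illustrative, not part of the PR):

import llama_cpp

# Resolve whichever KV-clear symbol the installed bindings export.
# Newer builds expose llama_kv_self_clear; older ones expose
# llama_kv_cache_clear. Both names appear in this diff.
_kv_clear = getattr(llama_cpp, "llama_kv_self_clear", None) or getattr(
    llama_cpp, "llama_kv_cache_clear"
)

def clear_kv_cache(ctx):
    """Clear the KV cache on a llama context, whichever API name exists."""
    _kv_clear(ctx)

A shim like this lets downstream code run against both old and new llama.cpp builds during the transition, at the cost of one attribute lookup at import time.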
@@ -216,3 +216,19 @@ def logit_processor_func(input_ids, logits):

     assert number_1 != number_2
     assert number_1 == number_3
+
+
+def test_real_llama_embeddings(llama_cpp_model_path):
+    model = llama_cpp.Llama(
+        llama_cpp_model_path,
+        n_ctx=32,
+        n_batch=32,
+        n_ubatch=32,
+        n_threads=multiprocessing.cpu_count(),
+        n_threads_batch=multiprocessing.cpu_count(),
+        logits_all=False,
+        flash_attn=True,
+        embedding=True
+    )
+    # Smoke test for now
+    model.embed("Hello World")
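The new test exercises the embedding path end to end but, as its comment says, is only a smoke test: it never inspects the result. A hedged sketch of a follow-up assertion, assuming `Llama.n_embd()` is available as an accessor on the model object (illustrative only, not part of the PR):

# Illustrative extension of the smoke test. Per the Union type in the
# first hunk, embed() on a single string returns either a flat list of
# floats (pooled) or a list of per-token vectors (unpooled).
emb = model.embed("Hello World")
assert isinstance(emb, list) and len(emb) > 0
if isinstance(emb[0], float):
    # Pooled: one float per embedding dimension.
    assert len(emb) == model.n_embd()  # assumes the n_embd() accessor
else:
    # Unpooled: each token vector has n_embd() dimensions.
    assert all(len(tok) == model.n_embd() for tok in emb)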