
Commit 9e5a4ea

fix: Update reference to llama_kv_self_clear in Llama.embed. Closes #2037

1 parent 9770b84 commit 9e5a4ea

File tree

llama_cpp/llama.py
tests/test_llama.py

2 files changed: +18 -2 lines changed

llama_cpp/llama.py

Lines changed: 2 additions & 2 deletions
@@ -1041,7 +1041,7 @@ def embed(
         data: Union[List[List[float]], List[List[List[float]]]] = []

         def decode_batch(seq_sizes: List[int]):
-            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+            llama_cpp.llama_kv_self_clear(self._ctx.ctx)
             self._ctx.decode(self._batch)
             self._batch.reset()

@@ -1112,7 +1112,7 @@ def decode_batch(seq_sizes: List[int]):

         output = data[0] if isinstance(input, str) else data

-        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+        llama_cpp.llama_kv_self_clear(self._ctx.ctx)
         self.reset()

         if return_count:
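
For context, the change tracks an upstream llama.cpp rename: llama_kv_cache_clear became llama_kv_self_clear. Code that has to run against both older and newer builds of the bindings could resolve the symbol defensively. A minimal sketch, assuming only the two names shown in the diff (the clear_kv helper is illustrative, not part of this commit):

import llama_cpp

# Prefer the post-rename symbol; fall back to the pre-rename one if this
# build of the bindings still exposes it.
_kv_clear = getattr(
    llama_cpp,
    "llama_kv_self_clear",
    getattr(llama_cpp, "llama_kv_cache_clear", None),
)

def clear_kv(ctx) -> None:
    # Clear the KV cache of a llama context, whichever symbol exists.
    if _kv_clear is None:
        raise RuntimeError("no KV-cache clear symbol found in llama_cpp")
    _kv_clear(ctx)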

tests/test_llama.py

Lines changed: 16 additions & 0 deletions
@@ -216,3 +216,19 @@ def logit_processor_func(input_ids, logits):

     assert number_1 != number_2
     assert number_1 == number_3
+
+
+def test_real_llama_embeddings(llama_cpp_model_path):
+    model = llama_cpp.Llama(
+        llama_cpp_model_path,
+        n_ctx=32,
+        n_batch=32,
+        n_ubatch=32,
+        n_threads=multiprocessing.cpu_count(),
+        n_threads_batch=multiprocessing.cpu_count(),
+        logits_all=False,
+        flash_attn=True,
+        embedding=True
+    )
+    # Smoke test for now
+    model.embed("Hello World")
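
Outside the test suite, Llama.embed accepts either a single string or a list of strings; the isinstance(input, str) branch in the diff above returns a single vector for the former and a list of vectors for the latter. A minimal usage sketch, assuming a local GGUF model at ./model.gguf (the path and n_ctx value are illustrative, not from the commit):

import llama_cpp

# embedding=True is required for Llama.embed; the model path is a placeholder.
llm = llama_cpp.Llama("./model.gguf", embedding=True, n_ctx=32)

vec = llm.embed("Hello World")           # str -> one embedding vector
vecs = llm.embed(["first", "second"])    # list[str] -> list of vectors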
