Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions packages/search-core/src/grogbot_search/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

from sentence_transformers import SentenceTransformer

_EMBEDDING_BATCH_SIZE = 8


@lru_cache(maxsize=1)
def _load_model() -> SentenceTransformer:
Expand All @@ -19,8 +17,10 @@ def embed_texts(texts: Iterable[str], *, prompt: str) -> List[list[float]]:
return []

model = _load_model()
embeddings = []
for start in range(0, len(text_list), _EMBEDDING_BATCH_SIZE):
batch = text_list[start : start + _EMBEDDING_BATCH_SIZE]
embeddings.extend(model.encode(batch, normalize_embeddings=True, prompt=prompt))
embeddings = model.encode(
text_list,
batch_size=16,
normalize_embeddings=True,
prompt=prompt,
)
return [embedding.tolist() for embedding in embeddings]
28 changes: 3 additions & 25 deletions packages/search-core/tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ class FakeModel:
def __init__(self):
self.calls = []

def encode(self, texts, *, normalize_embeddings: bool, prompt: str):
self.calls.append((texts, normalize_embeddings, prompt))
def encode(self, texts, *, batch_size: int, normalize_embeddings: bool, prompt: str):
self.calls.append((texts, batch_size, normalize_embeddings, prompt))
return [_FakeArray([1.0, 2.0]), _FakeArray([3.0, 4.0])]

fake_model = FakeModel()
Expand All @@ -47,26 +47,4 @@ def encode(self, texts, *, normalize_embeddings: bool, prompt: str):
result = embeddings.embed_texts(("first", "second"), prompt="search_query")

assert result == [[1.0, 2.0], [3.0, 4.0]]
assert fake_model.calls == [(["first", "second"], True, "search_query")]


def test_embed_texts_batches_requests_to_max_eight(monkeypatch):
class FakeModel:
def __init__(self):
self.calls = []

def encode(self, texts, *, normalize_embeddings: bool, prompt: str):
self.calls.append((list(texts), normalize_embeddings, prompt))
return [_FakeArray([float(int(text.removeprefix("chunk-")))]) for text in texts]

fake_model = FakeModel()
monkeypatch.setattr(embeddings, "_load_model", lambda: fake_model)

inputs = [f"chunk-{index}" for index in range(10)]
result = embeddings.embed_texts(inputs, prompt="search_document")

assert result == [[float(index)] for index in range(10)]
assert fake_model.calls == [
([f"chunk-{index}" for index in range(8)], True, "search_document"),
(["chunk-8", "chunk-9"], True, "search_document"),
]
assert fake_model.calls == [(["first", "second"], 16, True, "search_query")]