diff --git a/packages/search-core/src/grogbot_search/embeddings.py b/packages/search-core/src/grogbot_search/embeddings.py
index 008751a..9a71c5a 100644
--- a/packages/search-core/src/grogbot_search/embeddings.py
+++ b/packages/search-core/src/grogbot_search/embeddings.py
@@ -5,8 +5,6 @@
 
 from sentence_transformers import SentenceTransformer
 
-_EMBEDDING_BATCH_SIZE = 8
-
 
 @lru_cache(maxsize=1)
 def _load_model() -> SentenceTransformer:
@@ -19,8 +17,10 @@ def embed_texts(texts: Iterable[str], *, prompt: str) -> List[list[float]]:
         return []
 
     model = _load_model()
-    embeddings = []
-    for start in range(0, len(text_list), _EMBEDDING_BATCH_SIZE):
-        batch = text_list[start : start + _EMBEDDING_BATCH_SIZE]
-        embeddings.extend(model.encode(batch, normalize_embeddings=True, prompt=prompt))
+    embeddings = model.encode(
+        text_list,
+        batch_size=16,
+        normalize_embeddings=True,
+        prompt=prompt,
+    )
     return [embedding.tolist() for embedding in embeddings]
diff --git a/packages/search-core/tests/test_embeddings.py b/packages/search-core/tests/test_embeddings.py
index bb67b35..3851a3a 100644
--- a/packages/search-core/tests/test_embeddings.py
+++ b/packages/search-core/tests/test_embeddings.py
@@ -37,8 +37,8 @@ class FakeModel:
         def __init__(self):
             self.calls = []
 
-        def encode(self, texts, *, normalize_embeddings: bool, prompt: str):
-            self.calls.append((texts, normalize_embeddings, prompt))
+        def encode(self, texts, *, batch_size: int, normalize_embeddings: bool, prompt: str):
+            self.calls.append((texts, batch_size, normalize_embeddings, prompt))
             return [_FakeArray([1.0, 2.0]), _FakeArray([3.0, 4.0])]
 
     fake_model = FakeModel()
@@ -47,26 +47,4 @@ def encode(self, texts, *, normalize_embeddings: bool, prompt: str):
     result = embeddings.embed_texts(("first", "second"), prompt="search_query")
 
     assert result == [[1.0, 2.0], [3.0, 4.0]]
-    assert fake_model.calls == [(["first", "second"], True, "search_query")]
-
-
-def test_embed_texts_batches_requests_to_max_eight(monkeypatch):
-    class FakeModel:
-        def __init__(self):
-            self.calls = []
-
-        def encode(self, texts, *, normalize_embeddings: bool, prompt: str):
-            self.calls.append((list(texts), normalize_embeddings, prompt))
-            return [_FakeArray([float(int(text.removeprefix("chunk-")))]) for text in texts]
-
-    fake_model = FakeModel()
-    monkeypatch.setattr(embeddings, "_load_model", lambda: fake_model)
-
-    inputs = [f"chunk-{index}" for index in range(10)]
-    result = embeddings.embed_texts(inputs, prompt="search_document")
-
-    assert result == [[float(index)] for index in range(10)]
-    assert fake_model.calls == [
-        ([f"chunk-{index}" for index in range(8)], True, "search_document"),
-        (["chunk-8", "chunk-9"], True, "search_document"),
-    ]
+    assert fake_model.calls == [(["first", "second"], 16, True, "search_query")]