Skip to content

Commit 2563024

Browse files
feat: add asynchronous embedding methods for GoogleGenAIDocumentEmbedder and GoogleGenAITextEmbedder (#1983)
* feat: Add GoogleAITextEmbedder and GoogleAIDocumentEmbedder components
* fix: Improve error messages for input type validation in GoogleAITextEmbedder and GoogleAIDocumentEmbedder
* feat: add Google GenAI embedder components for document and text embeddings
* feat: add unit tests for GoogleAIDocumentEmbedder and GoogleAITextEmbedder
* refactor: clean up imports and improve list handling in GoogleAIDocumentEmbedder and GoogleAITextEmbedder tests
* refactor: Rename classes and update imports for Google GenAI components
* feat: Add additional modules for Google GenAI embedders in config
* chore: add 'more-itertools' to lint environment dependencies
* refactor: update GoogleGenAIDocumentEmbedder and GoogleGenAITextEmbedder to use private attributes for initialization
* refactor: update _prepare_texts_to_embed to return a list instead of a dictionary
* refactor: format code for better readability and consistency in document embedder
* refactor: improve code formatting for consistency and readability in document embedder and tests
* refactor: update _prepare_texts_to_embed to return a list instead of a dictionary
* feat: add new author to project metadata in pyproject.toml
* feat: add asynchronous embedding methods for GoogleGenAIDocumentEmbedder and GoogleGenAITextEmbedder
* fix: ensure consistent formatting for pylint
* fix: update return type annotation for run_async method in GoogleGenAIDocumentEmbedder
* fix: update return type annotation for run_async method in GoogleGenAITextEmbedder
* fix: update return type annotation and handle None values in _embed_batch_async method
* fix: remove unnecessary blank line in _embed_batch_async method

---------

Co-authored-by: David S. Batista <[email protected]>
1 parent c45d294 commit 2563024

File tree

4 files changed

+135
-0
lines changed

4 files changed

+135
-0
lines changed

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/document_embedder.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,37 @@ def _embed_batch(
171171

172172
return all_embeddings, meta
173173

174+
async def _embed_batch_async(
    self, texts_to_embed: List[str], batch_size: int
) -> Tuple[List[Optional[List[float]]], Dict[str, Any]]:
    """
    Embed a list of texts in batches asynchronously.

    :param texts_to_embed:
        The texts to embed, one string per entry.
    :param batch_size:
        Maximum number of texts sent in a single embedding request.

    :returns:
        A tuple `(embeddings, meta)`. For every batch, `embeddings` receives one entry per embedding in the
        response (`None` where an embedding has no values), or one `None` per text of the batch when the
        response carries no embeddings at all. `meta` holds the model name under `model` once at least one
        batch has been processed and is empty otherwise.
    """
    # The request config does not depend on the batch, so build it once instead of once per request.
    config = types.EmbedContentConfig(**self._config) if self._config else None

    all_embeddings: List[Optional[List[float]]] = []
    meta: Dict[str, Any] = {}
    for batch in tqdm(
        batched(texts_to_embed, batch_size), disable=not self._progress_bar, desc="Calculating embeddings"
    ):
        # Every item of the batch is already the text to embed. Indexing it (`b[1]`) would send only
        # the second character of each text and fail on texts shorter than two characters.
        args: Dict[str, Any] = {"model": self._model, "contents": list(batch)}
        if config is not None:
            args["config"] = config

        response = await self._client.aio.models.embed_content(**args)

        if response.embeddings:
            all_embeddings.extend(el.values or None for el in response.embeddings)
        else:
            # No embeddings came back for this batch: pad with None so later entries stay aligned.
            all_embeddings.extend([None] * len(batch))

        meta.setdefault("model", self._model)

    return all_embeddings, meta
204+
174205
@component.output_types(documents=List[Document], meta=Dict[str, Any])
175206
def run(self, documents: List[Document]) -> Union[Dict[str, List[Document]], Dict[str, Any]]:
176207
"""
@@ -200,3 +231,32 @@ def run(self, documents: List[Document]) -> Union[Dict[str, List[Document]], Dic
200231
doc.embedding = emb
201232

202233
return {"documents": documents, "meta": meta}
234+
235+
@component.output_types(documents=List[Document], meta=Dict[str, Any])
236+
async def run_async(self, documents: List[Document]) -> Union[Dict[str, List[Document]], Dict[str, Any]]:
    """
    Asynchronously compute embeddings for a list of documents.

    :param documents:
        The documents to embed.

    :returns:
        A dictionary with the following keys:
        - `documents`: The input documents, each with its `embedding` attribute set.
        - `meta`: Information about the usage of the model.

    :raises TypeError:
        If `documents` is not a list, or if its first element is not a `Document`.
    """
    # Only the first element is inspected; an empty list is accepted as valid input.
    is_document_list = isinstance(documents, list) and (not documents or isinstance(documents[0], Document))
    if not is_document_list:
        raise TypeError(
            "GoogleGenAIDocumentEmbedder expects a list of Documents as input. "
            "In case you want to embed a string, please use the GoogleGenAITextEmbedder."
        )

    prepared_texts = self._prepare_texts_to_embed(documents=documents)
    vectors, meta = await self._embed_batch_async(texts_to_embed=prepared_texts, batch_size=self._batch_size)

    # The embeddings are written onto the caller's Document objects in place.
    for document, vector in zip(documents, vectors):
        document.embedding = vector

    return {"documents": documents, "meta": meta}

integrations/google_genai/src/haystack_integrations/components/embedders/google_genai/text_embedder.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,23 @@ def run(self, text: str) -> Union[Dict[str, List[float]], Dict[str, Any]]:
138138
create_kwargs = self._prepare_input(text=text)
139139
response = self._client.models.embed_content(**create_kwargs)
140140
return self._prepare_output(result=response)
141+
142+
@component.output_types(embedding=List[float], meta=Dict[str, Any])
143+
async def run_async(self, text: str) -> Union[Dict[str, List[float]], Dict[str, Any]]:
    """
    Embed a single string asynchronously.

    Async counterpart of the `run` method: it accepts the same parameters and returns the same values,
    but has to be awaited from async code.

    :param text:
        Text to embed.

    :returns:
        A dictionary with the following keys:
        - `embedding`: The embedding of the input text.
        - `meta`: Information about the usage of the model.
    """
    request_kwargs = self._prepare_input(text=text)
    api_response = await self._client.aio.models.embed_content(**request_kwargs)
    return self._prepare_output(result=api_response)

integrations/google_genai/tests/test_document_embedder.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,39 @@ def test_run(self):
233233
assert len(doc.embedding) == 768
234234
assert all(isinstance(x, float) for x in doc.embedding)
235235

236+
assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], (
237+
"The model name does not contain 'text' and '004'"
238+
)
239+
240+
@pytest.mark.asyncio
241+
@pytest.mark.skipif(
242+
not os.environ.get("GOOGLE_API_KEY", None),
243+
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
244+
)
245+
@pytest.mark.integration
246+
async def test_run_async(self):
    """Run the document embedder asynchronously against the live API and check the result."""
    model = "text-embedding-004"
    docs = [
        Document(content="I love cheese", meta={"topic": "Cuisine"}),
        Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
    ]
    embedder = GoogleGenAIDocumentEmbedder(model=model, meta_fields_to_embed=["topic"], embedding_separator=" | ")

    result = await embedder.run_async(documents=docs)

    embedded_docs = result["documents"]
    assert isinstance(embedded_docs, list)
    assert len(embedded_docs) == len(docs)
    for embedded in embedded_docs:
        assert isinstance(embedded, Document)
        assert isinstance(embedded.embedding, list)
        assert len(embedded.embedding) == 768
        assert all(isinstance(value, float) for value in embedded.embedding)

    reported_model = result["meta"]["model"]
    assert "text" in reported_model and "004" in reported_model, (
        "The model name does not contain 'text' and '004'"
    )
    # Document metadata must come back untouched.
    assert result["documents"][0].meta == {"topic": "Cuisine"}
    assert result["documents"][1].meta == {"topic": "ML"}
    assert result["meta"] == {"model": model}

integrations/google_genai/tests/test_text_embedder.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,26 @@ def test_run(self):
160160
assert len(result["embedding"]) == 768
161161
assert all(isinstance(x, float) for x in result["embedding"])
162162

163+
assert "text" in result["meta"]["model"] and "004" in result["meta"]["model"], (
164+
"The model name does not contain 'text' and '004'"
165+
)
166+
167+
@pytest.mark.asyncio
168+
@pytest.mark.skipif(
169+
not os.environ.get("GOOGLE_API_KEY", None),
170+
reason="Export an env var called GOOGLE_API_KEY containing the Google API key to run this test.",
171+
)
172+
@pytest.mark.integration
173+
async def test_run_async(self):
    """Run the text embedder asynchronously against the live API and check the result."""
    model = "text-embedding-004"
    embedder = GoogleGenAITextEmbedder(model=model)

    result = await embedder.run_async(text="The food was delicious")

    embedding = result["embedding"]
    assert len(embedding) == 768
    assert all(isinstance(value, float) for value in embedding)

    reported_model = result["meta"]["model"]
    assert "text" in reported_model and "004" in reported_model, (
        "The model name does not contain 'text' and '004'"
    )
    assert result["meta"] == {"model": model}

0 commit comments

Comments
 (0)