Skip to content

Commit c78545d

Browse files
nillebsilvanocerza
andauthored
feat(openai): be tolerant to exceptions (#8526)
* feat: be tolerant to exceptions if ever an error is raised by the OpenAI API, don't fail the entire processing * fix: missing import, string separator * Enhance error handling * Use batched from more_itertools for compatibility with older Python versions * Fix batching and add test --------- Co-authored-by: Silvano Cerza <[email protected]>
1 parent f085959 commit c78545d

File tree

3 files changed

+63
-32
lines changed

3 files changed

+63
-32
lines changed

Diff for: haystack/components/embedders/openai_document_embedder.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
import os
66
from typing import Any, Dict, List, Optional, Tuple
77

8-
from openai import OpenAI
8+
from more_itertools import batched
9+
from openai import APIError, OpenAI
910
from tqdm import tqdm
1011

11-
from haystack import Document, component, default_from_dict, default_to_dict
12+
from haystack import Document, component, default_from_dict, default_to_dict, logging
1213
from haystack.utils import Secret, deserialize_secrets_inplace
1314

15+
logger = logging.getLogger(__name__)
16+
1417

1518
@component
1619
class OpenAIDocumentEmbedder:
@@ -34,7 +37,7 @@ class OpenAIDocumentEmbedder:
3437
```
3538
"""
3639

37-
def __init__(
40+
def __init__( # pylint: disable=too-many-positional-arguments
3841
self,
3942
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
4043
model: str = "text-embedding-ada-002",
@@ -158,11 +161,11 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAIDocumentEmbedder":
158161
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
159162
return default_from_dict(cls, data)
160163

161-
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
164+
def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]:
162165
"""
163166
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
164167
"""
165-
texts_to_embed = []
168+
texts_to_embed = {}
166169
for doc in documents:
167170
meta_values_to_embed = [
168171
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
@@ -174,25 +177,32 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
174177

175178
# copied from OpenAI embedding_utils (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
176179
# replace newlines, which can negatively affect performance.
177-
text_to_embed = text_to_embed.replace("\n", " ")
178-
texts_to_embed.append(text_to_embed)
180+
texts_to_embed[doc.id] = text_to_embed.replace("\n", " ")
179181
return texts_to_embed
180182

181-
def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
183+
def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
182184
"""
183185
Embed a list of texts in batches.
184186
"""
185187

186188
all_embeddings = []
187189
meta: Dict[str, Any] = {}
188-
for i in tqdm(
189-
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
190+
for batch in tqdm(
191+
batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
190192
):
191-
batch = texts_to_embed[i : i + batch_size]
193+
args: Dict[str, Any] = {"model": self.model, "input": [b[1] for b in batch]}
194+
192195
if self.dimensions is not None:
193-
response = self.client.embeddings.create(model=self.model, dimensions=self.dimensions, input=batch)
194-
else:
195-
response = self.client.embeddings.create(model=self.model, input=batch)
196+
args["dimensions"] = self.dimensions
197+
198+
try:
199+
response = self.client.embeddings.create(**args)
200+
except APIError as exc:
201+
ids = ", ".join(b[0] for b in batch)
202+
msg = "Failed embedding of documents {ids} caused by {exc}"
203+
logger.exception(msg, ids=ids, exc=exc)
204+
continue
205+
196206
embeddings = [el.embedding for el in response.data]
197207
all_embeddings.extend(embeddings)
198208

Diff for: releasenotes/notes/patch-1-34479efe3bea0e4f.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
enhancements:
3+
- |
4+
Change `OpenAIDocumentEmbedder` to keep running if a batch fails embedding.
5+
Now OpenAI returns an error we log that error and keep processing following batches.

Diff for: test/components/embedders/test_openai_document_embedder.py

+34-18
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
import os
5+
import random
56
from typing import List
6-
from haystack.utils.auth import Secret
7+
from unittest.mock import Mock, patch
78

8-
import random
99
import pytest
10+
from openai import APIError
1011

1112
from haystack import Document
1213
from haystack.components.embedders.openai_document_embedder import OpenAIDocumentEmbedder
14+
from haystack.utils.auth import Secret
1315

1416

1517
def mock_openai_response(input: List[str], model: str = "text-embedding-ada-002", **kwargs) -> dict:
@@ -155,7 +157,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
155157

156158
def test_prepare_texts_to_embed_w_metadata(self):
157159
documents = [
158-
Document(content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
160+
Document(id=f"{i}", content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"})
161+
for i in range(5)
159162
]
160163

161164
embedder = OpenAIDocumentEmbedder(
@@ -165,30 +168,30 @@ def test_prepare_texts_to_embed_w_metadata(self):
165168
prepared_texts = embedder._prepare_texts_to_embed(documents)
166169

167170
# note that newline is replaced by space
168-
assert prepared_texts == [
169-
"meta_value 0 | document number 0: content",
170-
"meta_value 1 | document number 1: content",
171-
"meta_value 2 | document number 2: content",
172-
"meta_value 3 | document number 3: content",
173-
"meta_value 4 | document number 4: content",
174-
]
171+
assert prepared_texts == {
172+
"0": "meta_value 0 | document number 0: content",
173+
"1": "meta_value 1 | document number 1: content",
174+
"2": "meta_value 2 | document number 2: content",
175+
"3": "meta_value 3 | document number 3: content",
176+
"4": "meta_value 4 | document number 4: content",
177+
}
175178

176179
def test_prepare_texts_to_embed_w_suffix(self):
177-
documents = [Document(content=f"document number {i}") for i in range(5)]
180+
documents = [Document(id=f"{i}", content=f"document number {i}") for i in range(5)]
178181

179182
embedder = OpenAIDocumentEmbedder(
180183
api_key=Secret.from_token("fake-api-key"), prefix="my_prefix ", suffix=" my_suffix"
181184
)
182185

183186
prepared_texts = embedder._prepare_texts_to_embed(documents)
184187

185-
assert prepared_texts == [
186-
"my_prefix document number 0 my_suffix",
187-
"my_prefix document number 1 my_suffix",
188-
"my_prefix document number 2 my_suffix",
189-
"my_prefix document number 3 my_suffix",
190-
"my_prefix document number 4 my_suffix",
191-
]
188+
assert prepared_texts == {
189+
"0": "my_prefix document number 0 my_suffix",
190+
"1": "my_prefix document number 1 my_suffix",
191+
"2": "my_prefix document number 2 my_suffix",
192+
"3": "my_prefix document number 3 my_suffix",
193+
"4": "my_prefix document number 4 my_suffix",
194+
}
192195

193196
def test_run_wrong_input_format(self):
194197
embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake-api-key"))
@@ -212,6 +215,19 @@ def test_run_on_empty_list(self):
212215
assert result["documents"] is not None
213216
assert not result["documents"] # empty list
214217

218+
def test_embed_batch_handles_exceptions_gracefully(self, caplog):
219+
embedder = OpenAIDocumentEmbedder(api_key=Secret.from_token("fake_api_key"))
220+
fake_texts_to_embed = {"1": "text1", "2": "text2"}
221+
with patch.object(
222+
embedder.client.embeddings,
223+
"create",
224+
side_effect=APIError(message="Mocked error", request=Mock(), body=None),
225+
):
226+
embedder._embed_batch(texts_to_embed=fake_texts_to_embed, batch_size=2)
227+
228+
assert len(caplog.records) == 1
229+
assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog.records[0].msg
230+
215231
@pytest.mark.skipif(os.environ.get("OPENAI_API_KEY", "") == "", reason="OPENAI_API_KEY is not set")
216232
@pytest.mark.integration
217233
def test_run(self):

0 commit comments

Comments
 (0)