2
2
#
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
import os
5
+ import random
5
6
from typing import List
6
- from haystack . utils . auth import Secret
7
+ from unittest . mock import Mock , patch
7
8
8
- import random
9
9
import pytest
10
+ from openai import APIError
10
11
11
12
from haystack import Document
12
13
from haystack .components .embedders .openai_document_embedder import OpenAIDocumentEmbedder
14
+ from haystack .utils .auth import Secret
13
15
14
16
15
17
def mock_openai_response (input : List [str ], model : str = "text-embedding-ada-002" , ** kwargs ) -> dict :
@@ -155,7 +157,8 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch):
155
157
156
158
def test_prepare_texts_to_embed_w_metadata (self ):
157
159
documents = [
158
- Document (content = f"document number { i } :\n content" , meta = {"meta_field" : f"meta_value { i } " }) for i in range (5 )
160
+ Document (id = f"{ i } " , content = f"document number { i } :\n content" , meta = {"meta_field" : f"meta_value { i } " })
161
+ for i in range (5 )
159
162
]
160
163
161
164
embedder = OpenAIDocumentEmbedder (
@@ -165,30 +168,30 @@ def test_prepare_texts_to_embed_w_metadata(self):
165
168
prepared_texts = embedder ._prepare_texts_to_embed (documents )
166
169
167
170
# note that newline is replaced by space
168
- assert prepared_texts == [
169
- "meta_value 0 | document number 0: content" ,
170
- "meta_value 1 | document number 1: content" ,
171
- "meta_value 2 | document number 2: content" ,
172
- "meta_value 3 | document number 3: content" ,
173
- "meta_value 4 | document number 4: content" ,
174
- ]
171
+ assert prepared_texts == {
172
+ "0" : " meta_value 0 | document number 0: content" ,
173
+ "1" : " meta_value 1 | document number 1: content" ,
174
+ "2" : " meta_value 2 | document number 2: content" ,
175
+ "3" : " meta_value 3 | document number 3: content" ,
176
+ "4" : " meta_value 4 | document number 4: content" ,
177
+ }
175
178
176
179
def test_prepare_texts_to_embed_w_suffix (self ):
177
- documents = [Document (content = f"document number { i } " ) for i in range (5 )]
180
+ documents = [Document (id = f" { i } " , content = f"document number { i } " ) for i in range (5 )]
178
181
179
182
embedder = OpenAIDocumentEmbedder (
180
183
api_key = Secret .from_token ("fake-api-key" ), prefix = "my_prefix " , suffix = " my_suffix"
181
184
)
182
185
183
186
prepared_texts = embedder ._prepare_texts_to_embed (documents )
184
187
185
- assert prepared_texts == [
186
- "my_prefix document number 0 my_suffix" ,
187
- "my_prefix document number 1 my_suffix" ,
188
- "my_prefix document number 2 my_suffix" ,
189
- "my_prefix document number 3 my_suffix" ,
190
- "my_prefix document number 4 my_suffix" ,
191
- ]
188
+ assert prepared_texts == {
189
+ "0" : " my_prefix document number 0 my_suffix" ,
190
+ "1" : " my_prefix document number 1 my_suffix" ,
191
+ "2" : " my_prefix document number 2 my_suffix" ,
192
+ "3" : " my_prefix document number 3 my_suffix" ,
193
+ "4" : " my_prefix document number 4 my_suffix" ,
194
+ }
192
195
193
196
def test_run_wrong_input_format (self ):
194
197
embedder = OpenAIDocumentEmbedder (api_key = Secret .from_token ("fake-api-key" ))
@@ -212,6 +215,19 @@ def test_run_on_empty_list(self):
212
215
assert result ["documents" ] is not None
213
216
assert not result ["documents" ] # empty list
214
217
218
+ def test_embed_batch_handles_exceptions_gracefully (self , caplog ):
219
+ embedder = OpenAIDocumentEmbedder (api_key = Secret .from_token ("fake_api_key" ))
220
+ fake_texts_to_embed = {"1" : "text1" , "2" : "text2" }
221
+ with patch .object (
222
+ embedder .client .embeddings ,
223
+ "create" ,
224
+ side_effect = APIError (message = "Mocked error" , request = Mock (), body = None ),
225
+ ):
226
+ embedder ._embed_batch (texts_to_embed = fake_texts_to_embed , batch_size = 2 )
227
+
228
+ assert len (caplog .records ) == 1
229
+ assert "Failed embedding of documents 1, 2 caused by Mocked error" in caplog .records [0 ].msg
230
+
215
231
@pytest .mark .skipif (os .environ .get ("OPENAI_API_KEY" , "" ) == "" , reason = "OPENAI_API_KEY is not set" )
216
232
@pytest .mark .integration
217
233
def test_run (self ):
0 commit comments