
Commit 877f826

refactor: HF API Embedders - use InferenceClient.feature_extraction instead of InferenceClient.post (#8794)

* HF API Embedders: refactoring
* rename variables
* rm leftovers
* rm pin
* rm unused import
* relnote
* warning with truncate/normalize and serverless inference API
* test that warnings are raised

1 parent f165212 · commit 877f826

6 files changed: +139 -47 lines
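The gist of the change, sketched outside the diff (an editor's illustration, not part of the commit; the model name and token are placeholders borrowed from the tests below):

from huggingface_hub import InferenceClient

client = InferenceClient(model="BAAI/bge-small-en-v1.5", token="hf_...")  # placeholders

# Before: the embedders issued a raw POST and decoded the byte response by hand:
#     response = client.post(
#         json={"inputs": ["Hello world"], "truncate": True, "normalize": False},
#         task="feature-extraction",
#     )
#     embeddings = json.loads(response.decode())

# After: the typed helper returns a numpy array directly, no manual JSON decoding.
np_embeddings = client.feature_extraction(text="Hello world")
print(np_embeddings.shape)  # e.g. (1, 384) or (384,), depending on the backend

Because feature_extraction is a documented part of the InferenceClient API, the commit can presumably also drop the strict upper pin on huggingface_hub (see the pyproject.toml diff below).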

Diff for: haystack/components/embedders/hugging_face_api_document_embedder.py (+26 -10)

@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import json
+import warnings
 from typing import Any, Dict, List, Optional, Union
 
 from tqdm import tqdm
@@ -96,8 +96,8 @@ def __init__(
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
         prefix: str = "",
         suffix: str = "",
-        truncate: bool = True,
-        normalize: bool = False,
+        truncate: Optional[bool] = True,
+        normalize: Optional[bool] = False,
         batch_size: int = 32,
         progress_bar: bool = True,
         meta_fields_to_embed: Optional[List[str]] = None,
@@ -124,13 +124,11 @@ def __init__(
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `True` and cannot be changed.
         :param normalize:
             Normalizes the embeddings to unit length.
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `False` and cannot be changed.
         :param batch_size:
             Number of documents to process at once.
         :param progress_bar:
@@ -239,18 +237,36 @@ def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]:
         """
         Embed a list of texts in batches.
         """
+        truncate = self.truncate
+        normalize = self.normalize
+
+        if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
+            if truncate is not None:
+                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                truncate = None
+            if normalize is not None:
+                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                normalize = None
 
         all_embeddings = []
         for i in tqdm(
             range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
         ):
             batch = texts_to_embed[i : i + batch_size]
-            response = self._client.post(
-                json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
-                task="feature-extraction",
+
+            np_embeddings = self._client.feature_extraction(
+                # this method does not officially support list of strings, but works as expected
+                text=batch,  # type: ignore[arg-type]
+                truncate=truncate,
+                normalize=normalize,
             )
-            embeddings = json.loads(response.decode())
-            all_embeddings.extend(embeddings)
+
+            if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch):
+                raise ValueError(f"Expected embedding shape ({batch_size}, embedding_dim), got {np_embeddings.shape}")
+
+            all_embeddings.extend(np_embeddings.tolist())
 
         return all_embeddings
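A note on the new shape check: for a batch of N input texts, feature_extraction is expected to return a 2-D array with N rows, and the guard above makes that contract explicit before converting back to plain Python lists. The same check in isolation (an editor's sketch with dummy data, not part of the commit):

import numpy as np

batch = ["text 1", "text 2"]
np_embeddings = np.random.rand(len(batch), 384)  # stand-in for the client's response

# One embedding row per input text, or the response is malformed.
if np_embeddings.ndim != 2 or np_embeddings.shape[0] != len(batch):
    raise ValueError(f"Expected embedding shape ({len(batch)}, embedding_dim), got {np_embeddings.shape}")

all_embeddings = np_embeddings.tolist()  # plain lists of floats, ready to attach to Documents

One subtlety in the diff: the error message interpolates batch_size while the condition compares against len(batch), so for a final, smaller batch the reported "expected" shape can differ from the actual batch length.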

Diff for: haystack/components/embedders/hugging_face_api_text_embedder.py (+25 -10)

@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-import json
+import warnings
 from typing import Any, Dict, List, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict, logging
@@ -80,8 +80,8 @@ def __init__(
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
         prefix: str = "",
         suffix: str = "",
-        truncate: bool = True,
-        normalize: bool = False,
+        truncate: Optional[bool] = True,
+        normalize: Optional[bool] = False,
     ):  # pylint: disable=too-many-positional-arguments
         """
         Creates a HuggingFaceAPITextEmbedder component.
@@ -104,13 +104,11 @@ def __init__(
             Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `True` and cannot be changed.
         :param normalize:
             Normalizes the embeddings to unit length.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
             if the backend uses Text Embeddings Inference.
             If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
-            It is always set to `False` and cannot be changed.
         """
         huggingface_hub_import.check()
 
@@ -198,12 +196,29 @@ def run(self, text: str):
                 "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
             )
 
+        truncate = self.truncate
+        normalize = self.normalize
+
+        if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
+            if truncate is not None:
+                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                truncate = None
+            if normalize is not None:
+                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
+                warnings.warn(msg)
+                normalize = None
+
         text_to_embed = self.prefix + text + self.suffix
 
-        response = self._client.post(
-            json={"inputs": [text_to_embed], "truncate": self.truncate, "normalize": self.normalize},
-            task="feature-extraction",
-        )
-        embedding = json.loads(response.decode())[0]
+        np_embedding = self._client.feature_extraction(text=text_to_embed, truncate=truncate, normalize=normalize)
+
+        error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
+        if np_embedding.ndim > 2:
+            raise ValueError(error_msg)
+        if np_embedding.ndim == 2 and np_embedding.shape[0] != 1:
+            raise ValueError(error_msg)
+
+        embedding = np_embedding.flatten().tolist()
 
         return {"embedding": embedding}
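For a single text, the response shape can vary by backend: either a flat (embedding_dim,) vector or a (1, embedding_dim) matrix, which is why run accepts both and flattens. The same logic in isolation (an editor's sketch with dummy data, not part of the commit):

import numpy as np

np_embedding = np.random.rand(1, 384)  # could equally be np.random.rand(384)

error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
if np_embedding.ndim > 2 or (np_embedding.ndim == 2 and np_embedding.shape[0] != 1):
    raise ValueError(error_msg)

embedding = np_embedding.flatten().tolist()  # always a flat list of floats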

Diff for: pyproject.toml (+1 -1)

@@ -87,7 +87,7 @@ extra-dependencies = [
   "numba>=0.54.0", # This pin helps uv resolve the dependency tree. See https://github.com/astral-sh/uv/issues/7881
 
   "transformers[torch,sentencepiece]==4.47.1", # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
-  "huggingface_hub>=0.27.0, <0.28.0", # Hugging Face API Generators and Embedders
+  "huggingface_hub>=0.27.0", # Hugging Face API Generators and Embedders
   "sentence-transformers>=3.0.0", # SentenceTransformersTextEmbedder and SentenceTransformersDocumentEmbedder
   "langdetect", # TextLanguageRouter and DocumentLanguageClassifier
   "openai-whisper>=20231106", # LocalWhisperTranscriber
Diff for: new release note (file path not shown in this view)

@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    In the Hugging Face API embedders, the `InferenceClient.feature_extraction` method is now used instead of
+    `InferenceClient.post` to compute embeddings. This ensures a more robust and future-proof implementation.

Diff for: test/components/embedders/test_hugging_face_api_document_embedder.py (+49 -15)

@@ -8,6 +8,8 @@
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
 
+from numpy import array
+
 from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
 from haystack.dataclasses import Document
 from haystack.utils.auth import Secret
@@ -23,8 +25,8 @@ def mock_check_valid_model():
     yield mock
 
 
-def mock_embedding_generation(json, **kwargs):
-    response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode()
+def mock_embedding_generation(text, **kwargs):
+    response = array([[random.random() for _ in range(384)] for _ in range(len(text))])
     return response
 
 
@@ -201,10 +203,10 @@ def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):
             "my_prefix document number 4 my_suffix",
         ]
 
-    def test_embed_batch(self, mock_check_valid_model):
+    def test_embed_batch(self, mock_check_valid_model, recwarn):
         texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
 
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation
 
             embedder = HuggingFaceAPIDocumentEmbedder(
@@ -223,6 +225,40 @@ def test_embed_batch(self, mock_check_valid_model):
             assert len(embedding) == 384
             assert all(isinstance(x, float) for x in embedding)
 
+        # Check that warnings about ignoring truncate and normalize are raised
+        assert len(recwarn) == 2
+        assert "truncate" in str(recwarn[0].message)
+        assert "normalize" in str(recwarn[1].message)
+
+    def test_embed_batch_wrong_embedding_shape(self, mock_check_valid_model):
+        texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
+
+        # embedding ndim != 2
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([0.1, 0.2, 0.3])
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+            )
+
+            with pytest.raises(ValueError):
+                embedder._embed_batch(texts_to_embed=texts, batch_size=2)
+
+        # embedding ndim == 2 but shape[0] != len(batch)
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]])
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+            )
+
+            with pytest.raises(ValueError):
+                embedder._embed_batch(texts_to_embed=texts, batch_size=2)
+
     def test_run_wrong_input_format(self, mock_check_valid_model):
         embedder = HuggingFaceAPIDocumentEmbedder(
             api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
@@ -252,7 +288,7 @@ def test_run(self, mock_check_valid_model):
             Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
         ]
 
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
             mock_embedding_patch.side_effect = mock_embedding_generation
 
             embedder = HuggingFaceAPIDocumentEmbedder(
@@ -268,16 +304,14 @@ def test_run(self, mock_check_valid_model):
             result = embedder.run(documents=docs)
 
             mock_embedding_patch.assert_called_once_with(
-                json={
-                    "inputs": [
-                        "prefix Cuisine | I love cheese suffix",
-                        "prefix ML | A transformer is a deep learning architecture suffix",
-                    ],
-                    "truncate": True,
-                    "normalize": False,
-                },
-                task="feature-extraction",
+                text=[
+                    "prefix Cuisine | I love cheese suffix",
+                    "prefix ML | A transformer is a deep learning architecture suffix",
+                ],
+                truncate=None,
+                normalize=None,
             )
+
             documents_with_embeddings = result["documents"]
 
             assert isinstance(documents_with_embeddings, list)
@@ -294,7 +328,7 @@ def test_run_custom_batch_size(self, mock_check_valid_model):
             Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
         ]
 
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
            mock_embedding_patch.side_effect = mock_embedding_generation
 
             embedder = HuggingFaceAPIDocumentEmbedder(
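These tests all follow one pattern: patch InferenceClient.feature_extraction to return a numpy array of a chosen shape, then assert on the result, the raised ValueError, or the recorded warnings. The recwarn argument is pytest's built-in warning-recorder fixture; a self-contained illustration of how it behaves (an editor's sketch, independent of Haystack):

import warnings

def emit_serverless_warnings(truncate=True, normalize=False):
    # Mimics the branch added in this commit: both parameters are ignored
    # on the Serverless Inference API, and a warning is raised for each.
    if truncate is not None:
        warnings.warn("`truncate` parameter is not supported for Serverless Inference API. It will be ignored.")
    if normalize is not None:
        warnings.warn("`normalize` parameter is not supported for Serverless Inference API. It will be ignored.")

def test_warnings_are_recorded(recwarn):
    emit_serverless_warnings()
    assert len(recwarn) == 2
    assert "truncate" in str(recwarn[0].message)
    assert "normalize" in str(recwarn[1].message)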

Diff for: test/components/embedders/test_hugging_face_api_text_embedder.py (+33 -11)

@@ -7,7 +7,7 @@
 import random
 import pytest
 from huggingface_hub.utils import RepositoryNotFoundError
-
+from numpy import array
 from haystack.components.embedders import HuggingFaceAPITextEmbedder
 from haystack.utils.auth import Secret
 from haystack.utils.hf import HFEmbeddingAPIType
@@ -21,11 +21,6 @@ def mock_check_valid_model():
     yield mock
 
 
-def mock_embedding_generation(json, **kwargs):
-    response = str([[random.random() for _ in range(384)] for _ in range(len(json["inputs"]))]).encode()
-    return response
-
-
 class TestHuggingFaceAPITextEmbedder:
     def test_init_invalid_api_type(self):
         with pytest.raises(ValueError):
@@ -141,9 +136,9 @@ def test_run_wrong_input_format(self, mock_check_valid_model):
         with pytest.raises(TypeError):
             embedder.run(text=list_integers_input)
 
-    def test_run(self, mock_check_valid_model):
-        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
-            mock_embedding_patch.side_effect = mock_embedding_generation
+    def test_run(self, mock_check_valid_model, recwarn):
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[random.random() for _ in range(384)]])
 
             embedder = HuggingFaceAPITextEmbedder(
                 api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
@@ -156,13 +151,40 @@ def test_run(self, mock_check_valid_model):
             result = embedder.run(text="The food was delicious")
 
             mock_embedding_patch.assert_called_once_with(
-                json={"inputs": ["prefix The food was delicious suffix"], "truncate": True, "normalize": False},
-                task="feature-extraction",
+                text="prefix The food was delicious suffix", truncate=None, normalize=None
             )
 
         assert len(result["embedding"]) == 384
         assert all(isinstance(x, float) for x in result["embedding"])
 
+        # Check that warnings about ignoring truncate and normalize are raised
+        assert len(recwarn) == 2
+        assert "truncate" in str(recwarn[0].message)
+        assert "normalize" in str(recwarn[1].message)
+
+    def test_run_wrong_embedding_shape(self, mock_check_valid_model):
+        # embedding ndim > 2
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]])
+
+            embedder = HuggingFaceAPITextEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
+            )
+
+            with pytest.raises(ValueError):
+                embedder.run(text="The food was delicious")
+
+        # embedding ndim == 2 but shape[0] != 1
+        with patch("huggingface_hub.InferenceClient.feature_extraction") as mock_embedding_patch:
+            mock_embedding_patch.return_value = array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
+
+            embedder = HuggingFaceAPITextEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
+            )
+
+            with pytest.raises(ValueError):
+                embedder.run(text="The food was delicious")
+
     @pytest.mark.flaky(reruns=5, reruns_delay=5)
     @pytest.mark.integration
     @pytest.mark.skipif(
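The integration test at the tail of this hunk is gated by the markers shown above; roughly (an editor's sketch; the actual skipif condition is cut off in this view and assumed here):

import os
import pytest

@pytest.mark.flaky(reruns=5, reruns_delay=5)  # retries transient API failures; needs pytest-rerunfailures
@pytest.mark.integration  # custom marker, so CI can select or deselect live-API tests
@pytest.mark.skipif(not os.environ.get("HF_API_TOKEN"), reason="HF_API_TOKEN not set")  # assumed condition
def test_live_embedding():
    ...  # real request against the Serverless Inference API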
