
Commit c192488

Named entity extractor private models (#8658)
* add `token` support to NamedEntityExtractor to enable using private models on HF backend
* fix existing error message format
* add release note
* add HF_API_TOKEN to e2e workflow
* add informative comment
* Updated to_dict / from_dict to handle `token` correctly; Added tests
* Fix lint
* Revert unwanted change
1 parent 286061f commit c192488
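
For orientation before the diffs: a minimal usage sketch of the new `token` parameter, assuming the import paths Haystack 2.x uses for this component; the model name is a stand-in for any private Hugging Face model the token can access.

from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
from haystack.utils.auth import Secret

# Per this commit, the token defaults to HF_API_TOKEN or HF_TOKEN from the
# environment (strict=False, so nothing fails when neither env var is set).
extractor = NamedEntityExtractor(
    backend=NamedEntityExtractorBackend.HUGGING_FACE,
    model="deepset/bert-base-NER",  # stand-in for a private model
    token=Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
)
extractor.warm_up()  # downloads tokenizer and model, authenticating with the token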

File tree

5 files changed: +106 -6 lines changed


Diff for: .github/workflows/e2e.yml

+1
@@ -19,6 +19,7 @@ env:
   PYTHON_VERSION: "3.9"
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   HATCH_VERSION: "1.13.0"
+  HF_API_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}

 jobs:
   run:

Diff for: e2e/pipelines/test_named_entity_extractor.py

+13
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0

+import os
 import pytest

 from haystack import Document, Pipeline
@@ -65,6 +66,18 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
     _extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)


+@pytest.mark.parametrize("batch_size", [1, 3])
+@pytest.mark.skipif(
+    not os.environ.get("HF_API_TOKEN", None),
+    reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+)
+def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size):
+    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
+    extractor.warm_up()
+
+    _extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)
+
+
 @pytest.mark.parametrize("batch_size", [1, 3])
 def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size):
     extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_trf")

Diff for: haystack/components/extractors/named_entity_extractor.py

+32 -5

@@ -10,7 +10,9 @@

 from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
 from haystack.lazy_imports import LazyImport
+from haystack.utils.auth import Secret, deserialize_secrets_inplace
 from haystack.utils.device import ComponentDevice
+from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs

 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
     from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
@@ -110,6 +112,7 @@ def __init__(
         model: str,
         pipeline_kwargs: Optional[Dict[str, Any]] = None,
         device: Optional[ComponentDevice] = None,
+        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
     ) -> None:
         """
         Create a Named Entity extractor component.
@@ -128,16 +131,28 @@ def __init__(
             device/device map is specified in `pipeline_kwargs`,
             it overrides this parameter (only applicable to the
             HuggingFace backend).
+        :param token:
+            The API token to download private models from Hugging Face.
         """

         if isinstance(backend, str):
             backend = NamedEntityExtractorBackend.from_str(backend)

         self._backend: _NerBackend
         self._warmed_up: bool = False
+        self.token = token
         device = ComponentDevice.resolve_device(device)

         if backend == NamedEntityExtractorBackend.HUGGING_FACE:
+            pipeline_kwargs = resolve_hf_pipeline_kwargs(
+                huggingface_pipeline_kwargs=pipeline_kwargs or {},
+                model=model,
+                task="ner",
+                supported_tasks=["ner"],
+                device=device,
+                token=token,
+            )
+
             self._backend = _HfBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
         elif backend == NamedEntityExtractorBackend.SPACY:
             self._backend = _SpacyBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
@@ -159,7 +174,7 @@ def warm_up(self):
             self._warmed_up = True
         except Exception as e:
             raise ComponentError(
-                f"Named entity extractor with backend '{self._backend.type} failed to initialize."
+                f"Named entity extractor with backend '{self._backend.type}' failed to initialize."
             ) from e

     @component.output_types(documents=List[Document])
@@ -201,14 +216,21 @@ def to_dict(self) -> Dict[str, Any]:
         :returns:
             Dictionary with serialized data.
         """
-        return default_to_dict(
+        serialization_dict = default_to_dict(
             self,
             backend=self._backend.type.name,
             model=self._backend.model_name,
             device=self._backend.device.to_dict(),
             pipeline_kwargs=self._backend._pipeline_kwargs,
+            token=self.token.to_dict() if self.token else None,
         )

+        hf_pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"]
+        hf_pipeline_kwargs.pop("token", None)
+
+        serialize_hf_model_kwargs(hf_pipeline_kwargs)
+        return serialization_dict
+
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
         """
@@ -220,10 +242,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
             Deserialized component.
         """
         try:
-            init_params = data["init_parameters"]
+            deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+            init_params = data.get("init_parameters", {})
             if init_params.get("device") is not None:
                 init_params["device"] = ComponentDevice.from_dict(init_params["device"])
             init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]
+
+            hf_pipeline_kwargs = init_params.get("pipeline_kwargs", {})
+            deserialize_hf_model_kwargs(hf_pipeline_kwargs)
             return default_from_dict(cls, data)
         except Exception as e:
             raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
@@ -352,8 +378,9 @@ def __init__(
         self.pipeline: Optional[HfPipeline] = None

     def initialize(self):
-        self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path)
-        self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path)
+        token = self._pipeline_kwargs.get("token", None)
+        self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path, token=token)
+        self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path, token=token)

         pipeline_params = {
             "task": "ner",
Diff for: release note (new file)

+4

@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Add `token` argument to `NamedEntityExtractor` to allow usage of private Hugging Face models.
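
Besides the env-var default, a literal value should also be usable for ad-hoc runs via `Secret.from_token` on the same `Secret` class; a sketch with a placeholder value. As with token-based secrets elsewhere in Haystack, such an instance would be expected to fail on `to_dict()`, since literal tokens are not serializable.

from haystack.components.extractors import NamedEntityExtractor
from haystack.utils.auth import Secret

extractor = NamedEntityExtractor(
    backend="hugging_face",
    model="deepset/bert-base-NER",      # stand-in for a private model
    token=Secret.from_token("hf_xxx"),  # placeholder, not a real token
)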

Diff for: test/components/extractors/test_named_entity_extractor.py

+56 -1

@@ -1,6 +1,7 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
 #
 # SPDX-License-Identifier: Apache-2.0
+from haystack.utils.auth import Secret
 import pytest

 from haystack import ComponentError, DeserializationError, Pipeline
@@ -11,6 +12,9 @@
 def test_named_entity_extractor_backend():
     _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")

+    # private model
+    _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
+
     _ = NamedEntityExtractor(backend="hugging_face", model="dslim/bert-base-NER")

     _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_sm")
@@ -40,7 +44,58 @@ def test_named_entity_extractor_serde():
     _ = NamedEntityExtractor.from_dict(serde_data)


-def test_named_entity_extractor_from_dict_no_default_parameters_hf():
+def test_to_dict_default(monkeypatch):
+    monkeypatch.delenv("HF_API_TOKEN", raising=False)
+
+    component = NamedEntityExtractor(
+        backend=NamedEntityExtractorBackend.HUGGING_FACE,
+        model="dslim/bert-base-NER",
+        device=ComponentDevice.from_str("mps"),
+    )
+    data = component.to_dict()
+
+    assert data == {
+        "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
+        "init_parameters": {
+            "backend": "HUGGING_FACE",
+            "model": "dslim/bert-base-NER",
+            "device": {"type": "single", "device": "mps"},
+            "pipeline_kwargs": {"model": "dslim/bert-base-NER", "device": "mps", "task": "ner"},
+            "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
+        },
+    }
+
+
+def test_to_dict_with_parameters():
+    component = NamedEntityExtractor(
+        backend=NamedEntityExtractorBackend.HUGGING_FACE,
+        model="dslim/bert-base-NER",
+        device=ComponentDevice.from_str("mps"),
+        pipeline_kwargs={"model_kwargs": {"load_in_4bit": True}},
+        token=Secret.from_env_var("ENV_VAR", strict=False),
+    )
+    data = component.to_dict()
+
+    assert data == {
+        "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
+        "init_parameters": {
+            "backend": "HUGGING_FACE",
+            "model": "dslim/bert-base-NER",
+            "device": {"type": "single", "device": "mps"},
+            "pipeline_kwargs": {
+                "model": "dslim/bert-base-NER",
+                "device": "mps",
+                "task": "ner",
+                "model_kwargs": {"load_in_4bit": True},
+            },
+            "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
+        },
+    }
+
+
+def test_named_entity_extractor_from_dict_no_default_parameters_hf(monkeypatch):
+    monkeypatch.delenv("HF_API_TOKEN", raising=False)
+
     data = {
         "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
         "init_parameters": {"backend": "HUGGING_FACE", "model": "dslim/bert-base-NER"},
