10
10
11
11
from haystack import ComponentError , DeserializationError , Document , component , default_from_dict , default_to_dict
12
12
from haystack .lazy_imports import LazyImport
13
+ from haystack .utils .auth import Secret , deserialize_secrets_inplace
13
14
from haystack .utils .device import ComponentDevice
15
+ from haystack .utils .hf import deserialize_hf_model_kwargs , resolve_hf_pipeline_kwargs , serialize_hf_model_kwargs
14
16
15
17
with LazyImport (message = "Run 'pip install \" transformers[torch]\" '" ) as transformers_import :
16
18
from transformers import AutoModelForTokenClassification , AutoTokenizer , pipeline
@@ -110,6 +112,7 @@ def __init__(
110
112
model : str ,
111
113
pipeline_kwargs : Optional [Dict [str , Any ]] = None ,
112
114
device : Optional [ComponentDevice ] = None ,
115
+ token : Optional [Secret ] = Secret .from_env_var (["HF_API_TOKEN" , "HF_TOKEN" ], strict = False ),
113
116
) -> None :
114
117
"""
115
118
Create a Named Entity extractor component.
@@ -128,16 +131,28 @@ def __init__(
128
131
device/device map is specified in `pipeline_kwargs`,
129
132
it overrides this parameter (only applicable to the
130
133
HuggingFace backend).
134
+ :param token:
135
+ The API token to download private models from Hugging Face.
131
136
"""
132
137
133
138
if isinstance (backend , str ):
134
139
backend = NamedEntityExtractorBackend .from_str (backend )
135
140
136
141
self ._backend : _NerBackend
137
142
self ._warmed_up : bool = False
143
+ self .token = token
138
144
device = ComponentDevice .resolve_device (device )
139
145
140
146
if backend == NamedEntityExtractorBackend .HUGGING_FACE :
147
+ pipeline_kwargs = resolve_hf_pipeline_kwargs (
148
+ huggingface_pipeline_kwargs = pipeline_kwargs or {},
149
+ model = model ,
150
+ task = "ner" ,
151
+ supported_tasks = ["ner" ],
152
+ device = device ,
153
+ token = token ,
154
+ )
155
+
141
156
self ._backend = _HfBackend (model_name_or_path = model , device = device , pipeline_kwargs = pipeline_kwargs )
142
157
elif backend == NamedEntityExtractorBackend .SPACY :
143
158
self ._backend = _SpacyBackend (model_name_or_path = model , device = device , pipeline_kwargs = pipeline_kwargs )
@@ -159,7 +174,7 @@ def warm_up(self):
159
174
self ._warmed_up = True
160
175
except Exception as e :
161
176
raise ComponentError (
162
- f"Named entity extractor with backend '{ self ._backend .type } failed to initialize."
177
+ f"Named entity extractor with backend '{ self ._backend .type } ' failed to initialize."
163
178
) from e
164
179
165
180
@component .output_types (documents = List [Document ])
@@ -201,14 +216,21 @@ def to_dict(self) -> Dict[str, Any]:
201
216
:returns:
202
217
Dictionary with serialized data.
203
218
"""
204
- return default_to_dict (
219
+ serialization_dict = default_to_dict (
205
220
self ,
206
221
backend = self ._backend .type .name ,
207
222
model = self ._backend .model_name ,
208
223
device = self ._backend .device .to_dict (),
209
224
pipeline_kwargs = self ._backend ._pipeline_kwargs ,
225
+ token = self .token .to_dict () if self .token else None ,
210
226
)
211
227
228
+ hf_pipeline_kwargs = serialization_dict ["init_parameters" ]["pipeline_kwargs" ]
229
+ hf_pipeline_kwargs .pop ("token" , None )
230
+
231
+ serialize_hf_model_kwargs (hf_pipeline_kwargs )
232
+ return serialization_dict
233
+
212
234
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
    """
    Deserialize the component from a dictionary.

    :param data:
        Dictionary to deserialize from. Its `init_parameters` mapping is
        modified in place: the `token`, `device`, `backend` and
        `pipeline_kwargs` entries are replaced by their restored values
        before the component is constructed.
    :returns:
        Deserialized component.
    :raises DeserializationError:
        If `init_parameters` is missing or any of its entries cannot be
        restored. The original exception is attached as the cause.
    """
    try:
        # Look the mapping up exactly once. A missing key raises KeyError here
        # and is reported as a DeserializationError by the handler below.
        init_params = data["init_parameters"]
        deserialize_secrets_inplace(init_params, keys=["token"])

        if init_params.get("device") is not None:
            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
        init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]

        # `pipeline_kwargs` may be absent, null or empty (e.g. a hand-written
        # definition); in all of those cases there is nothing to restore, so
        # the Hugging Face helper is only invoked on a non-empty mapping.
        hf_pipeline_kwargs = init_params.get("pipeline_kwargs")
        if hf_pipeline_kwargs:
            deserialize_hf_model_kwargs(hf_pipeline_kwargs)

        return default_from_dict(cls, data)
    except Exception as e:
        raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
@@ -352,8 +378,9 @@ def __init__(
352
378
self .pipeline : Optional [HfPipeline ] = None
353
379
354
380
def initialize (self ):
355
- self .tokenizer = AutoTokenizer .from_pretrained (self ._model_name_or_path )
356
- self .model = AutoModelForTokenClassification .from_pretrained (self ._model_name_or_path )
381
+ token = self ._pipeline_kwargs .get ("token" , None )
382
+ self .tokenizer = AutoTokenizer .from_pretrained (self ._model_name_or_path , token = token )
383
+ self .model = AutoModelForTokenClassification .from_pretrained (self ._model_name_or_path , token = token )
357
384
358
385
pipeline_params = {
359
386
"task" : "ner" ,
0 commit comments