2
2
#
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
import logging as python_logging
5
- import os
6
5
from datetime import datetime
7
- from typing import Any , Dict , List , Optional , Union
6
+ from typing import Any , Dict , List , Optional , Type , Union
8
7
9
8
from azure .core .credentials import AzureKeyCredential
10
9
from azure .core .exceptions import ClientAuthenticationError , HttpResponseError , ResourceNotFoundError
11
10
from azure .core .pipeline .policies import UserAgentPolicy
12
11
from azure .identity import DefaultAzureCredential
13
12
from azure .search .documents import SearchClient
14
13
from azure .search .documents .indexes import SearchIndexClient
14
+ from azure .search .documents .indexes ._generated ._serialization import Model
15
15
from azure .search .documents .indexes .models import (
16
16
CharFilter ,
17
17
CorsOptions ,
53
53
}
54
54
55
55
# Map of expected field names to their corresponding classes
56
- AZURE_CLASS_MAPPING = {
56
+ AZURE_CLASS_MAPPING : Dict [ str , Type [ Model ]] = {
57
57
"suggesters" : SearchSuggester ,
58
58
"analyzers" : LexicalAnalyzer ,
59
59
"tokenizers" : LexicalTokenizer ,
@@ -94,7 +94,7 @@ def __init__(
94
94
embedding_dimension : int = 768 ,
95
95
metadata_fields : Optional [Dict [str , Union [SearchField , type ]]] = None ,
96
96
vector_search_configuration : Optional [VectorSearch ] = None ,
97
- ** index_creation_kwargs ,
97
+ ** index_creation_kwargs : Any ,
98
98
):
99
99
"""
100
100
A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/)
@@ -133,16 +133,8 @@ def __init__(
133
133
134
134
For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
135
135
"""
136
-
137
- azure_endpoint = azure_endpoint or os .environ .get ("AZURE_AI_SEARCH_ENDPOINT" ) or None
138
- if not azure_endpoint :
139
- msg = "Please provide an Azure endpoint or set the environment variable AZURE_AI_SEARCH_ENDPOINT."
140
- raise ValueError (msg )
141
-
142
- api_key = api_key or os .environ .get ("AZURE_AI_SEARCH_API_KEY" ) or None
143
-
144
- self ._client = None
145
- self ._index_client = None
136
+ self ._client : Optional [SearchClient ] = None
137
+ self ._index_client : Optional [SearchIndexClient ] = None
146
138
self ._index_fields = [] # type: List[Any] # stores all fields in the final schema of index
147
139
self ._api_key = api_key
148
140
self ._azure_endpoint = azure_endpoint
@@ -155,11 +147,8 @@ def __init__(
155
147
156
148
@property
157
149
def client (self ) -> SearchClient :
158
- # resolve secrets for authentication
159
- resolved_endpoint = (
160
- self ._azure_endpoint .resolve_value () if isinstance (self ._azure_endpoint , Secret ) else self ._azure_endpoint
161
- )
162
- resolved_key = self ._api_key .resolve_value () if isinstance (self ._api_key , Secret ) else self ._api_key
150
+ resolved_endpoint = self ._azure_endpoint .resolve_value ()
151
+ resolved_key = self ._api_key .resolve_value ()
163
152
164
153
credential = AzureKeyCredential (resolved_key ) if resolved_key else DefaultAzureCredential ()
165
154
@@ -168,8 +157,9 @@ def client(self) -> SearchClient:
168
157
try :
169
158
if not self ._index_client :
170
159
self ._index_client = SearchIndexClient (
171
- resolved_endpoint ,
172
- credential ,
160
+ # resolve_value, with Secret.from_env_var (strict=True), returns a string or raises an error
161
+ endpoint = resolved_endpoint , # type: ignore[arg-type]
162
+ credential = credential ,
173
163
user_agent = ua_policy ,
174
164
)
175
165
if not self ._index_exists (self ._index_name ):
@@ -287,7 +277,7 @@ def _deserialize_index_creation_kwargs(cls, data: Dict[str, Any]) -> Any:
287
277
"""
288
278
Deserializes the index creation kwargs to the original classes.
289
279
"""
290
- result = {}
280
+ result : Dict [ str , Union [ List [ Model ], Model ]] = {}
291
281
for key , value in data .items ():
292
282
if key in AZURE_CLASS_MAPPING :
293
283
if isinstance (value , list ):
@@ -337,7 +327,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore":
337
327
else :
338
328
data ["init_parameters" ]["metadata_fields" ] = {}
339
329
340
- for key , _value in AZURE_CLASS_MAPPING . items () :
330
+ for key in AZURE_CLASS_MAPPING :
341
331
if key in data ["init_parameters" ]:
342
332
param_value = data ["init_parameters" ].get (key )
343
333
data ["init_parameters" ][key ] = cls ._deserialize_index_creation_kwargs ({key : param_value })
@@ -421,7 +411,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
421
411
if filters :
422
412
normalized_filters = _normalize_filters (filters )
423
413
result = self .client .search (filter = normalized_filters )
424
- return self ._convert_search_result_to_documents (result )
414
+ return self ._convert_search_result_to_documents (list ( result ) )
425
415
else :
426
416
return self .search_documents ()
427
417
@@ -465,7 +455,7 @@ def _index_exists(self, index_name: Optional[str]) -> bool:
465
455
msg = "Index name is required to check if the index exists."
466
456
raise ValueError (msg )
467
457
468
- def _get_raw_documents_by_id (self , document_ids : List [str ]):
458
+ def _get_raw_documents_by_id (self , document_ids : List [str ]) -> List [ Dict ] :
469
459
"""
470
460
Retrieves all Azure documents with a matching document_ids from the document store.
471
461
@@ -499,7 +489,7 @@ def _embedding_retrieval(
499
489
* ,
500
490
top_k : int = 10 ,
501
491
filters : Optional [str ] = None ,
502
- ** kwargs ,
492
+ ** kwargs : Any ,
503
493
) -> List [Document ]:
504
494
"""
505
495
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
@@ -533,7 +523,7 @@ def _bm25_retrieval(
533
523
query : str ,
534
524
top_k : int = 10 ,
535
525
filters : Optional [str ] = None ,
536
- ** kwargs ,
526
+ ** kwargs : Any ,
537
527
) -> List [Document ]:
538
528
"""
539
529
Retrieves documents that are most similar to `query`, using the BM25 algorithm.
@@ -566,7 +556,7 @@ def _hybrid_retrieval(
566
556
query_embedding : List [float ],
567
557
top_k : int = 10 ,
568
558
filters : Optional [str ] = None ,
569
- ** kwargs ,
559
+ ** kwargs : Any ,
570
560
) -> List [Document ]:
571
561
"""
572
562
Retrieves documents similar to query using the vector configuration in the document store and
0 commit comments