Skip to content

Commit 7c17004

Browse files
fix: fix Azure AI types + add py.typed (#2003)
* fix: fix Azure AI types + add py.typed
* fix

---------

Co-authored-by: David S. Batista <[email protected]>
1 parent 1646ff6 commit 7c17004

File tree

9 files changed: +50 −67 lines changed (50 additions, 67 deletions)

.github/workflows/azure_ai_search.yml

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -47,11 +47,9 @@ jobs:
4747
- name: Install Hatch
4848
run: pip install --upgrade hatch
4949

50-
# TODO: Once this integration is properly typed, use hatch run test:types
51-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
5250
- name: Lint
5351
if: matrix.python-version == '3.9' && runner.os == 'Linux'
54-
run: hatch run fmt-check && hatch run lint:typing
52+
run: hatch run fmt-check && hatch run test:types
5553

5654
- name: Generate docs
5755
if: matrix.python-version == '3.9' && runner.os == 'Linux'

integrations/azure_ai_search/pyproject.toml

Lines changed: 9 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -66,17 +66,14 @@ integration = 'pytest -m "integration" {args:tests}'
6666
all = 'pytest {args:tests}'
6767
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
6868

69-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
69+
types = """mypy -p haystack_integrations.document_stores.azure_ai_search \
70+
-p haystack_integrations.components.retrievers.azure_ai_search {args}"""
7071

71-
# TODO: remove lint environment once this integration is properly typed
72-
# test environment should be used instead
73-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
74-
[tool.hatch.envs.lint]
75-
detached = true
76-
dependencies = ["mypy>=1.0.0", "ruff>=0.0.243"]
77-
78-
[tool.hatch.envs.lint.scripts]
79-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
72+
[tool.mypy]
73+
install_types = true
74+
non_interactive = true
75+
check_untyped_defs = true
76+
disallow_incomplete_defs = true
8077

8178
[tool.hatch.metadata]
8279
allow-direct-references = true
@@ -161,9 +158,5 @@ exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
161158

162159
[tool.pytest.ini_options]
163160
minversion = "6.0"
164-
markers = ["unit: unit tests", "integration: integration tests"]
165-
pythonpath = ["src"]
166-
167-
[[tool.mypy.overrides]]
168-
module = ["haystack.*", "haystack_integrations.*", "pytest.*", "azure.identity.*", "mypy.*", "azure.core.*", "azure.search.documents.*"]
169-
ignore_missing_imports = true
161+
markers = ["integration: integration tests"]
162+
pythonpath = ["src"]

integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ def __init__(
2424
filters: Optional[Dict[str, Any]] = None,
2525
top_k: int = 10,
2626
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
27-
**kwargs,
27+
**kwargs: Any,
2828
):
2929
"""
3030
Create the AzureAISearchBM25Retriever component.
@@ -96,7 +96,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchBM25Retriever":
9696
return default_from_dict(cls, data)
9797

9898
@component.output_types(documents=List[Document])
99-
def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
99+
def run(
100+
self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
101+
) -> Dict[str, List[Document]]:
100102
"""Retrieve documents from the AzureAISearchDocumentStore.
101103
102104
:param query: Text of the query.
@@ -111,11 +113,11 @@ def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optio
111113

112114
top_k = top_k or self._top_k
113115
filters = filters or self._filters
114-
if filters:
115-
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
116+
117+
normalized_filters = ""
118+
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
119+
if applied_filters:
116120
normalized_filters = _normalize_filters(applied_filters)
117-
else:
118-
normalized_filters = ""
119121

120122
try:
121123
docs = self._document_store._bm25_retrieval(

integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ def __init__(
2424
filters: Optional[Dict[str, Any]] = None,
2525
top_k: int = 10,
2626
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
27-
**kwargs,
27+
**kwargs: Any,
2828
):
2929
"""
3030
Create the AzureAISearchEmbeddingRetriever component.
@@ -93,7 +93,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever":
9393
return default_from_dict(cls, data)
9494

9595
@component.output_types(documents=List[Document])
96-
def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
96+
def run(
97+
self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
98+
) -> Dict[str, List[Document]]:
9799
"""Retrieve documents from the AzureAISearchDocumentStore.
98100
99101
:param query_embedding: A list of floats representing the query embedding.
@@ -107,11 +109,11 @@ def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] =
107109

108110
top_k = top_k or self._top_k
109111
filters = filters or self._filters
110-
if filters:
111-
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
112+
113+
normalized_filters = ""
114+
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
115+
if applied_filters:
112116
normalized_filters = _normalize_filters(applied_filters)
113-
else:
114-
normalized_filters = ""
115117

116118
try:
117119
docs = self._document_store._embedding_retrieval(

integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ def __init__(
2424
filters: Optional[Dict[str, Any]] = None,
2525
top_k: int = 10,
2626
filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE,
27-
**kwargs,
27+
**kwargs: Any,
2828
):
2929
"""
3030
Create the AzureAISearchHybridRetriever component.
@@ -102,7 +102,7 @@ def run(
102102
query_embedding: List[float],
103103
filters: Optional[Dict[str, Any]] = None,
104104
top_k: Optional[int] = None,
105-
):
105+
) -> Dict[str, List[Document]]:
106106
"""Retrieve documents from the AzureAISearchDocumentStore.
107107
108108
:param query: Text of the query.
@@ -118,11 +118,11 @@ def run(
118118

119119
top_k = top_k or self._top_k
120120
filters = filters or self._filters
121-
if filters:
122-
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
121+
122+
normalized_filters = ""
123+
applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters)
124+
if applied_filters:
123125
normalized_filters = _normalize_filters(applied_filters)
124-
else:
125-
normalized_filters = ""
126126

127127
try:
128128
docs = self._document_store._hybrid_retrieval(

integrations/azure_ai_search/src/haystack_integrations/components/retrievers/py.typed

Whitespace-only changes.

integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py

Lines changed: 18 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -2,16 +2,16 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
import logging as python_logging
5-
import os
65
from datetime import datetime
7-
from typing import Any, Dict, List, Optional, Union
6+
from typing import Any, Dict, List, Optional, Type, Union
87

98
from azure.core.credentials import AzureKeyCredential
109
from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceNotFoundError
1110
from azure.core.pipeline.policies import UserAgentPolicy
1211
from azure.identity import DefaultAzureCredential
1312
from azure.search.documents import SearchClient
1413
from azure.search.documents.indexes import SearchIndexClient
14+
from azure.search.documents.indexes._generated._serialization import Model
1515
from azure.search.documents.indexes.models import (
1616
CharFilter,
1717
CorsOptions,
@@ -53,7 +53,7 @@
5353
}
5454

5555
# Map of expected field names to their corresponding classes
56-
AZURE_CLASS_MAPPING = {
56+
AZURE_CLASS_MAPPING: Dict[str, Type[Model]] = {
5757
"suggesters": SearchSuggester,
5858
"analyzers": LexicalAnalyzer,
5959
"tokenizers": LexicalTokenizer,
@@ -94,7 +94,7 @@ def __init__(
9494
embedding_dimension: int = 768,
9595
metadata_fields: Optional[Dict[str, Union[SearchField, type]]] = None,
9696
vector_search_configuration: Optional[VectorSearch] = None,
97-
**index_creation_kwargs,
97+
**index_creation_kwargs: Any,
9898
):
9999
"""
100100
A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/)
@@ -133,16 +133,8 @@ def __init__(
133133
134134
For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
135135
"""
136-
137-
azure_endpoint = azure_endpoint or os.environ.get("AZURE_AI_SEARCH_ENDPOINT") or None
138-
if not azure_endpoint:
139-
msg = "Please provide an Azure endpoint or set the environment variable AZURE_AI_SEARCH_ENDPOINT."
140-
raise ValueError(msg)
141-
142-
api_key = api_key or os.environ.get("AZURE_AI_SEARCH_API_KEY") or None
143-
144-
self._client = None
145-
self._index_client = None
136+
self._client: Optional[SearchClient] = None
137+
self._index_client: Optional[SearchIndexClient] = None
146138
self._index_fields = [] # type: List[Any] # stores all fields in the final schema of index
147139
self._api_key = api_key
148140
self._azure_endpoint = azure_endpoint
@@ -155,11 +147,8 @@ def __init__(
155147

156148
@property
157149
def client(self) -> SearchClient:
158-
# resolve secrets for authentication
159-
resolved_endpoint = (
160-
self._azure_endpoint.resolve_value() if isinstance(self._azure_endpoint, Secret) else self._azure_endpoint
161-
)
162-
resolved_key = self._api_key.resolve_value() if isinstance(self._api_key, Secret) else self._api_key
150+
resolved_endpoint = self._azure_endpoint.resolve_value()
151+
resolved_key = self._api_key.resolve_value()
163152

164153
credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential()
165154

@@ -168,8 +157,9 @@ def client(self) -> SearchClient:
168157
try:
169158
if not self._index_client:
170159
self._index_client = SearchIndexClient(
171-
resolved_endpoint,
172-
credential,
160+
# resolve_value, with Secret.from_env_var (strict=True), returns a string or raises an error
161+
endpoint=resolved_endpoint, # type: ignore[arg-type]
162+
credential=credential,
173163
user_agent=ua_policy,
174164
)
175165
if not self._index_exists(self._index_name):
@@ -287,7 +277,7 @@ def _deserialize_index_creation_kwargs(cls, data: Dict[str, Any]) -> Any:
287277
"""
288278
Deserializes the index creation kwargs to the original classes.
289279
"""
290-
result = {}
280+
result: Dict[str, Union[List[Model], Model]] = {}
291281
for key, value in data.items():
292282
if key in AZURE_CLASS_MAPPING:
293283
if isinstance(value, list):
@@ -337,7 +327,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore":
337327
else:
338328
data["init_parameters"]["metadata_fields"] = {}
339329

340-
for key, _value in AZURE_CLASS_MAPPING.items():
330+
for key in AZURE_CLASS_MAPPING:
341331
if key in data["init_parameters"]:
342332
param_value = data["init_parameters"].get(key)
343333
data["init_parameters"][key] = cls._deserialize_index_creation_kwargs({key: param_value})
@@ -421,7 +411,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
421411
if filters:
422412
normalized_filters = _normalize_filters(filters)
423413
result = self.client.search(filter=normalized_filters)
424-
return self._convert_search_result_to_documents(result)
414+
return self._convert_search_result_to_documents(list(result))
425415
else:
426416
return self.search_documents()
427417

@@ -465,7 +455,7 @@ def _index_exists(self, index_name: Optional[str]) -> bool:
465455
msg = "Index name is required to check if the index exists."
466456
raise ValueError(msg)
467457

468-
def _get_raw_documents_by_id(self, document_ids: List[str]):
458+
def _get_raw_documents_by_id(self, document_ids: List[str]) -> List[Dict]:
469459
"""
470460
Retrieves all Azure documents with a matching document_ids from the document store.
471461
@@ -499,7 +489,7 @@ def _embedding_retrieval(
499489
*,
500490
top_k: int = 10,
501491
filters: Optional[str] = None,
502-
**kwargs,
492+
**kwargs: Any,
503493
) -> List[Document]:
504494
"""
505495
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
@@ -533,7 +523,7 @@ def _bm25_retrieval(
533523
query: str,
534524
top_k: int = 10,
535525
filters: Optional[str] = None,
536-
**kwargs,
526+
**kwargs: Any,
537527
) -> List[Document]:
538528
"""
539529
Retrieves documents that are most similar to `query`, using the BM25 algorithm.
@@ -566,7 +556,7 @@ def _hybrid_retrieval(
566556
query_embedding: List[float],
567557
top_k: int = 10,
568558
filters: Optional[str] = None,
569-
**kwargs,
559+
**kwargs: Any,
570560
) -> List[Document]:
571561
"""
572562
Retrieves documents similar to query using the vector configuration in the document store and

integrations/azure_ai_search/src/haystack_integrations/document_stores/py.typed

Whitespace-only changes.

integrations/azure_ai_search/tests/conftest.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,8 +37,6 @@ def document_store(request):
3737
client.delete_index(index_name)
3838

3939
store = AzureAISearchDocumentStore(
40-
api_key=api_key,
41-
azure_endpoint=azure_endpoint,
4240
index_name=index_name,
4341
create_index=True,
4442
embedding_dimension=768,

0 commit comments

Comments
 (0)