Skip to content

Commit 3a91d2b

Browse files
authored
fix: Astra - fix types + add py.typed (#2011)
1 parent 415a2b5 commit 3a91d2b

File tree

7 files changed

+31
-40
lines changed

7 files changed

+31
-40
lines changed

.github/workflows/astra.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,9 @@ jobs:
5050
- name: Install Hatch
5151
run: pip install --upgrade hatch
5252

53-
# TODO: Once this integration is properly typed, use hatch run test:types
54-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
5553
- name: Lint
5654
if: matrix.python-version == '3.9' && runner.os == 'Linux'
57-
run: hatch run fmt-check && hatch run lint:typing
55+
run: hatch run fmt-check && hatch run test:types
5856

5957
- name: Generate docs
6058
if: matrix.python-version == '3.9' && runner.os == 'Linux'

integrations/astra/pyproject.toml

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,14 @@ unit = 'pytest -m "not integration" {args:tests}'
6464
integration = 'pytest -m "integration" {args:tests}'
6565
all = 'pytest {args:tests}'
6666
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
67-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
67+
types = """mypy -p haystack_integrations.document_stores.astra \
68+
-p haystack_integrations.components.retrievers.astra {args}"""
6869

69-
# TODO: remove lint environment once this integration is properly typed
70-
# test environment should be used instead
71-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
72-
[tool.hatch.envs.lint]
73-
installer = "uv"
74-
detached = true
75-
dependencies = ["pip", "mypy>=1.0.0", "ruff>=0.0.243"]
76-
[tool.hatch.envs.lint.scripts]
77-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
70+
[tool.mypy]
71+
install_types = true
72+
non_interactive = true
73+
check_untyped_defs = true
74+
disallow_incomplete_defs = true
7875

7976
[tool.hatch.metadata]
8077
allow-direct-references = true
@@ -156,16 +153,4 @@ exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
156153

157154
[tool.pytest.ini_options]
158155
minversion = "6.0"
159-
markers = ["unit: unit tests", "integration: integration tests"]
160-
161-
[[tool.mypy.overrides]]
162-
module = [
163-
"astra_client.*",
164-
"astrapy.*",
165-
"pydantic.*",
166-
"haystack.*",
167-
"haystack_integrations.*",
168-
"pytest.*",
169-
"openpyxl.*",
170-
]
171-
ignore_missing_imports = true
156+
markers = ["integration: integration tests"]

integrations/astra/src/haystack_integrations/components/retrievers/astra/retriever.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,12 @@ def __init__(
5858
raise Exception(message)
5959

6060
@component.output_types(documents=List[Document])
61-
def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
61+
def run(
62+
self,
63+
query_embedding: List[float],
64+
filters: Optional[Dict[str, Any]] = None,
65+
top_k: Optional[int] = None,
66+
) -> Dict[str, List[Document]]:
6267
"""Retrieve documents from the AstraDocumentStore.
6368
6469
:param query_embedding: floats representing the query embedding

integrations/astra/src/haystack_integrations/components/retrievers/py.typed

Whitespace-only changes.

integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def batch_generator(chunks, batch_size):
278278

279279
return formatted_docs
280280

281-
def insert(self, documents: List[Dict]):
281+
def insert(self, documents: List[Dict]) -> List[str]:
282282
"""
283283
Insert documents into the Astra index.
284284
@@ -290,7 +290,7 @@ def insert(self, documents: List[Dict]):
290290

291291
return inserted_ids
292292

293-
def update_document(self, document: Dict, id_key: str):
293+
def update_document(self, document: Dict, id_key: str) -> bool:
294294
"""
295295
Update a document in the Astra index.
296296

integrations/astra/src/haystack_integrations/document_stores/astra/document_store.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from haystack.document_stores.types import DuplicatePolicy
1010
from haystack.utils import Secret, deserialize_secrets_inplace
1111

12-
from .astra_client import AstraClient
12+
from .astra_client import AstraClient, QueryResponse
1313
from .errors import AstraDocumentStoreFilterError
1414
from .filters import _convert_filters
1515

@@ -150,7 +150,7 @@ def write_documents(
150150
self,
151151
documents: List[Document],
152152
policy: DuplicatePolicy = DuplicatePolicy.NONE,
153-
):
153+
) -> int:
154154
"""
155155
Indexes documents for later queries.
156156
@@ -176,7 +176,7 @@ def write_documents(
176176

177177
batch_size = MAX_BATCH_SIZE
178178

179-
def _convert_input_document(document: Union[dict, Document]):
179+
def _convert_input_document(document: Union[dict, Document]) -> Dict[str, Any]:
180180
if isinstance(document, Document):
181181
document_dict = document.to_dict(flatten=False)
182182
elif isinstance(document, dict):
@@ -217,7 +217,7 @@ def _convert_input_document(document: Union[dict, Document]):
217217
documents_to_write = [_convert_input_document(doc) for doc in documents]
218218

219219
duplicate_documents = []
220-
new_documents: List[Document] = []
220+
new_documents: List[Dict] = []
221221
i = 0
222222
while i < len(documents_to_write):
223223
doc = documents_to_write[i]
@@ -238,7 +238,7 @@ def _convert_input_document(document: Union[dict, Document]):
238238
if policy == DuplicatePolicy.SKIP:
239239
if len(new_documents) > 0:
240240
for batch in _batches(new_documents, batch_size):
241-
inserted_ids = self.index.insert(batch) # type: ignore
241+
inserted_ids = self.index.insert(batch)
242242
insertion_counter += len(inserted_ids)
243243
logger.info(f"write_documents inserted documents with id {inserted_ids}")
244244
else:
@@ -247,7 +247,7 @@ def _convert_input_document(document: Union[dict, Document]):
247247
elif policy == DuplicatePolicy.OVERWRITE:
248248
if len(new_documents) > 0:
249249
for batch in _batches(new_documents, batch_size):
250-
inserted_ids = self.index.insert(batch) # type: ignore
250+
inserted_ids = self.index.insert(batch)
251251
insertion_counter += len(inserted_ids)
252252
logger.info(f"write_documents inserted documents with id {inserted_ids}")
253253
else:
@@ -256,7 +256,7 @@ def _convert_input_document(document: Union[dict, Document]):
256256
if len(duplicate_documents) > 0:
257257
updated_ids = []
258258
for duplicate_doc in duplicate_documents:
259-
updated = self.index.update_document(duplicate_doc, "_id") # type: ignore
259+
updated = self.index.update_document(duplicate_doc, "_id")
260260
if updated:
261261
updated_ids.append(duplicate_doc["_id"])
262262
insertion_counter = insertion_counter + len(updated_ids)
@@ -267,7 +267,7 @@ def _convert_input_document(document: Union[dict, Document]):
267267
elif policy == DuplicatePolicy.FAIL:
268268
if len(new_documents) > 0:
269269
for batch in _batches(new_documents, batch_size):
270-
inserted_ids = self.index.insert(batch) # type: ignore
270+
inserted_ids = self.index.insert(batch)
271271
insertion_counter = insertion_counter + len(inserted_ids)
272272
logger.info(f"write_documents inserted documents with id {inserted_ids}")
273273
else:
@@ -326,15 +326,18 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
326326
return documents
327327

328328
@staticmethod
329-
def _get_result_to_documents(results) -> List[Document]:
329+
def _get_result_to_documents(results: QueryResponse) -> List[Document]:
330330
documents = []
331331
for match in results.matches:
332+
metadata = match.metadata
333+
blob = metadata.pop("blob", None) if metadata else None
334+
meta = metadata.pop("meta", {}) if metadata else {}
332335
document = Document(
333336
content=match.text,
334337
id=match.document_id,
335338
embedding=match.values,
336-
blob=match.metadata.pop("blob", None),
337-
meta=match.metadata.pop("meta", None),
339+
blob=blob,
340+
meta=meta,
338341
score=match.score,
339342
)
340343
documents.append(document)

integrations/astra/src/haystack_integrations/document_stores/py.typed

Whitespace-only changes.

0 commit comments

Comments
 (0)