Skip to content

Commit d1162f2

Browse files
authored
fix: fix MongoDB types + add py.typed (#1967)
* fix: fix MongoDB types + add py.typed * try removing type:ignore * simplify methods
1 parent 0666586 commit d1162f2

File tree

5 files changed

+55
-61
lines changed

5 files changed

+55
-61
lines changed

.github/workflows/mongodb_atlas.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,9 @@ jobs:
4646
- name: Install Hatch
4747
run: pip install --upgrade hatch
4848

49-
# TODO: Once this integration is properly typed, use hatch run test:types
50-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
5149
- name: Lint
5250
if: matrix.python-version == '3.9' && runner.os == 'Linux'
53-
run: hatch run fmt-check && hatch run lint:typing
51+
run: hatch run fmt-check && hatch run test:types
5452

5553
- name: Generate docs
5654
if: matrix.python-version == '3.9' && runner.os == 'Linux'

integrations/mongodb_atlas/pyproject.toml

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -69,18 +69,14 @@ integration = 'pytest -m "integration" {args:tests}'
6969
all = 'pytest {args:tests}'
7070
cov-retry = 'all --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x'
7171

72-
types = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
72+
types = """mypy -p haystack_integrations.components.retrievers.mongodb_atlas \
73+
-p haystack_integrations.document_stores.mongodb_atlas {args}"""
7374

74-
# TODO: remove lint environment once this integration is properly typed
75-
# test environment should be used instead
76-
# https://github.com/deepset-ai/haystack-core-integrations/issues/1771
77-
[tool.hatch.envs.lint]
78-
installer = "uv"
79-
detached = true
80-
dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
81-
82-
[tool.hatch.envs.lint.scripts]
83-
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
75+
[tool.mypy]
76+
install_types = true
77+
non_interactive = true
78+
check_untyped_defs = true
79+
disallow_incomplete_defs = true
8480

8581
[tool.black]
8682
target-version = ["py38"]
@@ -134,6 +130,8 @@ ignore = [
134130
"PLR0912",
135131
"PLR0913",
136132
"PLR0915",
133+
# Allow assert statements
134+
"S101",
137135
]
138136
unfixable = [
139137
# Don't touch unused imports
@@ -163,11 +161,6 @@ omit = ["*/tests/*", "*/__init__.py"]
163161
show_missing = true
164162
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
165163

166-
167-
[[tool.mypy.overrides]]
168-
module = ["haystack.*", "haystack_integrations.*", "pymongo.*", "pytest.*"]
169-
ignore_missing_imports = true
170-
171164
[tool.pytest.ini_options]
172165
addopts = "--strict-markers"
173166
markers = [

integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/py.typed

Whitespace-only changes.

integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,6 @@ def __del__(self) -> None:
117117
if self._connection:
118118
self._connection.close()
119119

120-
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
121-
"""
122-
Asynchronous exit method to close MongoDB connections when the instance is destroyed.
123-
"""
124-
if self._connection_async:
125-
await self._connection_async.close()
126-
127120
@property
128121
def connection(self) -> Union[AsyncMongoClient, MongoClient]:
129122
if self._connection:
@@ -142,53 +135,51 @@ def collection(self) -> Union[AsyncCollection, Collection]:
142135
msg = "The collection is not established yet."
143136
raise DocumentStoreError(msg)
144137

145-
def _connection_is_valid(self) -> bool:
138+
def _connection_is_valid(self, connection: MongoClient) -> bool:
146139
"""
147140
Checks if the connection to MongoDB Atlas is valid.
148141
149142
:returns: True if the connection is valid, False otherwise.
150143
"""
151144
try:
152-
self._connection.admin.command("ping") # type: ignore[union-attr]
145+
connection.admin.command("ping")
153146
return True
154147
except Exception as e:
155148
logger.error(f"Connection to MongoDB Atlas failed: {e}")
156149
return False
157150

158-
async def _connection_is_valid_async(self) -> bool:
151+
async def _connection_is_valid_async(self, connection: AsyncMongoClient) -> bool:
159152
"""
160153
Asynchronously checks if the connection to MongoDB Atlas is valid.
161154
162155
:returns: True if the connection is valid, False otherwise.
163156
"""
164157
try:
165-
await self._connection_async.admin.command("ping") # type: ignore[union-attr]
158+
await connection.admin.command("ping")
166159
return True
167160
except Exception as e:
168161
logger.error(f"Connection to MongoDB Atlas failed: {e}")
169162
return False
170163

171-
def _collection_exists(self) -> bool:
164+
def _collection_exists(self, connection: MongoClient, database_name: str, collection_name: str) -> bool:
172165
"""
173166
Checks if the collection exists in the MongoDB Atlas database.
174167
175168
:returns: True if the collection exists, False otherwise.
176169
"""
177-
database = self._connection[self.database_name] # type: ignore[index]
178-
if self.collection_name in database.list_collection_names():
179-
return True
180-
return False
170+
database = connection[database_name]
171+
return collection_name in database.list_collection_names()
181172

182-
async def _collection_exists_async(self) -> bool:
173+
async def _collection_exists_async(
174+
self, connection: AsyncMongoClient, database_name: str, collection_name: str
175+
) -> bool:
183176
"""
184177
Asynchronously checks if the collection exists in the MongoDB Atlas database.
185178
186179
:returns: True if the collection exists, False otherwise.
187180
"""
188-
database = self._connection_async[self.database_name] # type: ignore[index]
189-
if self.collection_name in await database.list_collection_names():
190-
return True
191-
return False
181+
database = connection[database_name]
182+
return collection_name in await database.list_collection_names()
192183

193184
def _ensure_connection_setup(self) -> None:
194185
"""
@@ -202,11 +193,11 @@ def _ensure_connection_setup(self) -> None:
202193
self.mongo_connection_string.resolve_value(), driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
203194
)
204195

205-
if not self._connection_is_valid():
196+
if not self._connection_is_valid(self._connection):
206197
msg = "Connection to MongoDB Atlas failed."
207198
raise DocumentStoreError(msg)
208199

209-
if not self._collection_exists():
200+
if not self._collection_exists(self._connection, self.database_name, self.collection_name):
210201
msg = f"Collection '{self.collection_name}' does not exist in database '{self.database_name}'."
211202
raise DocumentStoreError(msg)
212203

@@ -226,11 +217,11 @@ async def _ensure_connection_setup_async(self) -> None:
226217
self.mongo_connection_string.resolve_value(), driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
227218
)
228219

229-
if not await self._connection_is_valid_async():
220+
if not await self._connection_is_valid_async(self._connection_async):
230221
msg = "Connection to MongoDB Atlas failed."
231222
raise DocumentStoreError(msg)
232223

233-
if not await self._collection_exists_async():
224+
if not await self._collection_exists_async(self._connection_async, self.database_name, self.collection_name):
234225
msg = f"Collection '{self.collection_name}' does not exist in database '{self.database_name}'."
235226
raise DocumentStoreError(msg)
236227

@@ -274,7 +265,8 @@ def count_documents(self) -> int:
274265
:returns: The number of documents in the document store.
275266
"""
276267
self._ensure_connection_setup()
277-
return self._collection.count_documents({}) # type: ignore[union-attr]
268+
assert self._collection is not None
269+
return self._collection.count_documents({})
278270

279271
async def count_documents_async(self) -> int:
280272
"""
@@ -283,7 +275,8 @@ async def count_documents_async(self) -> int:
283275
:returns: The number of documents in the document store.
284276
"""
285277
await self._ensure_connection_setup_async()
286-
return await self._collection_async.count_documents({}) # type: ignore[union-attr]
278+
assert self._collection_async is not None
279+
return await self._collection_async.count_documents({})
287280

288281
def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
289282
"""
@@ -296,8 +289,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
296289
:returns: A list of Documents that match the given filters.
297290
"""
298291
self._ensure_connection_setup()
292+
assert self._collection is not None
299293
filters = _normalize_filters(filters) if filters else None
300-
documents = list(self._collection.find(filters)) # type: ignore[union-attr]
294+
documents = list(self._collection.find(filters))
301295
return [self._mongo_doc_to_haystack_doc(doc) for doc in documents]
302296

303297
async def filter_documents_async(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
@@ -311,8 +305,9 @@ async def filter_documents_async(self, filters: Optional[Dict[str, Any]] = None)
311305
:returns: A list of Documents that match the given filters.
312306
"""
313307
await self._ensure_connection_setup_async()
308+
assert self._collection_async is not None
314309
filters = _normalize_filters(filters) if filters else None
315-
documents = await self._collection_async.find(filters).to_list() # type: ignore[union-attr]
310+
documents = await self._collection_async.find(filters).to_list()
316311
return [self._mongo_doc_to_haystack_doc(doc) for doc in documents]
317312

318313
def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
@@ -327,7 +322,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
327322
:returns: The number of documents written to the document store.
328323
"""
329324
self._ensure_connection_setup()
330-
325+
assert self._collection is not None
331326
if len(documents) > 0:
332327
if not isinstance(documents[0], Document):
333328
msg = "param 'documents' must contain a list of objects of type Document"
@@ -342,15 +337,15 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
342337

343338
if policy == DuplicatePolicy.SKIP:
344339
operations = [UpdateOne({"id": doc["id"]}, {"$setOnInsert": doc}, upsert=True) for doc in mongo_documents]
345-
existing_documents = self._collection.count_documents({"id": {"$in": [doc.id for doc in documents]}}) # type: ignore[union-attr]
340+
existing_documents = self._collection.count_documents({"id": {"$in": [doc.id for doc in documents]}})
346341
written_docs -= existing_documents
347342
elif policy == DuplicatePolicy.FAIL:
348343
operations = [InsertOne(doc) for doc in mongo_documents]
349344
else:
350345
operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in mongo_documents]
351346

352347
try:
353-
self._collection.bulk_write(operations) # type: ignore[union-attr]
348+
self._collection.bulk_write(operations)
354349
except BulkWriteError as e:
355350
msg = f"Duplicate documents found: {e.details['writeErrors']}"
356351
raise DuplicateDocumentError(msg) from e
@@ -371,7 +366,7 @@ async def write_documents_async(
371366
:returns: The number of documents written to the document store.
372367
"""
373368
await self._ensure_connection_setup_async()
374-
369+
assert self._collection_async is not None
375370
if len(documents) > 0:
376371
if not isinstance(documents[0], Document):
377372
msg = "param 'documents' must contain a list of objects of type Document"
@@ -387,15 +382,17 @@ async def write_documents_async(
387382

388383
if policy == DuplicatePolicy.SKIP:
389384
operations = [UpdateOne({"id": doc["id"]}, {"$setOnInsert": doc}, upsert=True) for doc in mongo_documents]
390-
existing_documents = self._collection.count_documents({"id": {"$in": [doc.id for doc in documents]}}) # type: ignore[union-attr]
385+
existing_documents = await self._collection_async.count_documents(
386+
{"id": {"$in": [doc.id for doc in documents]}}
387+
)
391388
written_docs -= existing_documents
392389
elif policy == DuplicatePolicy.FAIL:
393390
operations = [InsertOne(doc) for doc in mongo_documents]
394391
else:
395392
operations = [ReplaceOne({"id": doc["id"]}, upsert=True, replacement=doc) for doc in mongo_documents]
396393

397394
try:
398-
await self._collection_async.bulk_write(operations) # type: ignore[union-attr]
395+
await self._collection_async.bulk_write(operations)
399396
except BulkWriteError as e:
400397
msg = f"Duplicate documents found: {e.details['writeErrors']}"
401398
raise DuplicateDocumentError(msg) from e
@@ -409,9 +406,10 @@ def delete_documents(self, document_ids: List[str]) -> None:
409406
:param document_ids: the document ids to delete
410407
"""
411408
self._ensure_connection_setup()
409+
assert self._collection is not None
412410
if not document_ids:
413411
return
414-
self._collection.delete_many(filter={"id": {"$in": document_ids}}) # type: ignore[union-attr]
412+
self._collection.delete_many(filter={"id": {"$in": document_ids}})
415413

416414
async def delete_documents_async(self, document_ids: List[str]) -> None:
417415
"""
@@ -420,9 +418,10 @@ async def delete_documents_async(self, document_ids: List[str]) -> None:
420418
:param document_ids: the document ids to delete
421419
"""
422420
await self._ensure_connection_setup_async()
421+
assert self._collection_async is not None
423422
if not document_ids:
424423
return
425-
await self._collection_async.delete_many(filter={"id": {"$in": document_ids}}) # type: ignore[union-attr]
424+
await self._collection_async.delete_many(filter={"id": {"$in": document_ids}})
426425

427426
def _embedding_retrieval(
428427
self,
@@ -441,6 +440,7 @@ def _embedding_retrieval(
441440
:raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails.
442441
"""
443442
self._ensure_connection_setup()
443+
assert self._collection is not None
444444
if not query_embedding:
445445
msg = "Query embedding must not be empty"
446446
raise ValueError(msg)
@@ -462,7 +462,7 @@ def _embedding_retrieval(
462462
{"$project": {"_id": 0}},
463463
]
464464
try:
465-
documents = list(self._collection.aggregate(pipeline)) # type: ignore[union-attr]
465+
documents = list(self._collection.aggregate(pipeline))
466466
except Exception as e:
467467
msg = f"Retrieval of documents from MongoDB Atlas failed: {e}"
468468
if filters:
@@ -490,6 +490,7 @@ async def _embedding_retrieval_async(
490490
:raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails.
491491
"""
492492
await self._ensure_connection_setup_async()
493+
assert self._collection_async is not None
493494
if not query_embedding:
494495
msg = "Query embedding must not be empty"
495496
raise ValueError(msg)
@@ -511,7 +512,8 @@ async def _embedding_retrieval_async(
511512
{"$project": {"_id": 0}},
512513
]
513514
try:
514-
documents = await self._collection_async.aggregate(pipeline).to_list() # type: ignore[union-attr]
515+
cursor = await self._collection_async.aggregate(pipeline)
516+
documents = await cursor.to_list(length=None)
515517
except Exception as e:
516518
msg = f"Retrieval of documents from MongoDB Atlas failed: {e}"
517519
if filters:
@@ -606,8 +608,9 @@ def _fulltext_retrieval(
606608
]
607609

608610
self._ensure_connection_setup()
611+
assert self._collection is not None
609612
try:
610-
documents = list(self._collection.aggregate(pipeline)) # type: ignore[union-attr]
613+
documents = list(self._collection.aggregate(pipeline))
611614
except Exception as e:
612615
error_msg = f"Failed to retrieve documents from MongoDB Atlas: {e}"
613616
if filters:
@@ -698,9 +701,9 @@ async def _fulltext_retrieval_async(
698701
]
699702

700703
await self._ensure_connection_setup_async()
701-
704+
assert self._collection_async is not None
702705
try:
703-
cursor = await self._collection_async.aggregate(pipeline) # type: ignore[union-attr]
706+
cursor = await self._collection_async.aggregate(pipeline)
704707
documents = await cursor.to_list(length=None)
705708
except Exception as e:
706709
error_msg = f"Failed to retrieve documents from MongoDB Atlas: {e}"

integrations/mongodb_atlas/src/haystack_integrations/document_stores/py.typed

Whitespace-only changes.

0 commit comments

Comments
 (0)