@@ -117,13 +117,6 @@ def __del__(self) -> None:
117
117
if self ._connection :
118
118
self ._connection .close ()
119
119
120
- async def __aexit__ (self , exc_type , exc_val , exc_tb ) -> None :
121
- """
122
- Asynchronous exit method to close MongoDB connections when the instance is destroyed.
123
- """
124
- if self ._connection_async :
125
- await self ._connection_async .close ()
126
-
127
120
@property
128
121
def connection (self ) -> Union [AsyncMongoClient , MongoClient ]:
129
122
if self ._connection :
@@ -142,53 +135,51 @@ def collection(self) -> Union[AsyncCollection, Collection]:
142
135
msg = "The collection is not established yet."
143
136
raise DocumentStoreError (msg )
144
137
145
- def _connection_is_valid (self ) -> bool :
138
+ def _connection_is_valid (self , connection : MongoClient ) -> bool :
146
139
"""
147
140
Checks if the connection to MongoDB Atlas is valid.
148
141
149
142
:returns: True if the connection is valid, False otherwise.
150
143
"""
151
144
try :
152
- self . _connection . admin .command ("ping" ) # type: ignore[union-attr]
145
+ connection . admin .command ("ping" )
153
146
return True
154
147
except Exception as e :
155
148
logger .error (f"Connection to MongoDB Atlas failed: { e } " )
156
149
return False
157
150
158
- async def _connection_is_valid_async (self ) -> bool :
151
+ async def _connection_is_valid_async (self , connection : AsyncMongoClient ) -> bool :
159
152
"""
160
153
Asynchronously checks if the connection to MongoDB Atlas is valid.
161
154
162
155
:returns: True if the connection is valid, False otherwise.
163
156
"""
164
157
try :
165
- await self . _connection_async . admin .command ("ping" ) # type: ignore[union-attr]
158
+ await connection . admin .command ("ping" )
166
159
return True
167
160
except Exception as e :
168
161
logger .error (f"Connection to MongoDB Atlas failed: { e } " )
169
162
return False
170
163
171
- def _collection_exists (self ) -> bool :
164
+ def _collection_exists (self , connection : MongoClient , database_name : str , collection_name : str ) -> bool :
172
165
"""
173
166
Checks if the collection exists in the MongoDB Atlas database.
174
167
175
168
:returns: True if the collection exists, False otherwise.
176
169
"""
177
- database = self ._connection [self .database_name ] # type: ignore[index]
178
- if self .collection_name in database .list_collection_names ():
179
- return True
180
- return False
170
+ database = connection [database_name ]
171
+ return collection_name in database .list_collection_names ()
181
172
182
- async def _collection_exists_async (self ) -> bool :
173
+ async def _collection_exists_async (
174
+ self , connection : AsyncMongoClient , database_name : str , collection_name : str
175
+ ) -> bool :
183
176
"""
184
177
Asynchronously checks if the collection exists in the MongoDB Atlas database.
185
178
186
179
:returns: True if the collection exists, False otherwise.
187
180
"""
188
- database = self ._connection_async [self .database_name ] # type: ignore[index]
189
- if self .collection_name in await database .list_collection_names ():
190
- return True
191
- return False
181
+ database = connection [database_name ]
182
+ return collection_name in await database .list_collection_names ()
192
183
193
184
def _ensure_connection_setup (self ) -> None :
194
185
"""
@@ -202,11 +193,11 @@ def _ensure_connection_setup(self) -> None:
202
193
self .mongo_connection_string .resolve_value (), driver = DriverInfo (name = "MongoDBAtlasHaystackIntegration" )
203
194
)
204
195
205
- if not self ._connection_is_valid ():
196
+ if not self ._connection_is_valid (self . _connection ):
206
197
msg = "Connection to MongoDB Atlas failed."
207
198
raise DocumentStoreError (msg )
208
199
209
- if not self ._collection_exists ():
200
+ if not self ._collection_exists (self . _connection , self . database_name , self . collection_name ):
210
201
msg = f"Collection '{ self .collection_name } ' does not exist in database '{ self .database_name } '."
211
202
raise DocumentStoreError (msg )
212
203
@@ -226,11 +217,11 @@ async def _ensure_connection_setup_async(self) -> None:
226
217
self .mongo_connection_string .resolve_value (), driver = DriverInfo (name = "MongoDBAtlasHaystackIntegration" )
227
218
)
228
219
229
- if not await self ._connection_is_valid_async ():
220
+ if not await self ._connection_is_valid_async (self . _connection_async ):
230
221
msg = "Connection to MongoDB Atlas failed."
231
222
raise DocumentStoreError (msg )
232
223
233
- if not await self ._collection_exists_async ():
224
+ if not await self ._collection_exists_async (self . _connection_async , self . database_name , self . collection_name ):
234
225
msg = f"Collection '{ self .collection_name } ' does not exist in database '{ self .database_name } '."
235
226
raise DocumentStoreError (msg )
236
227
@@ -274,7 +265,8 @@ def count_documents(self) -> int:
274
265
:returns: The number of documents in the document store.
275
266
"""
276
267
self ._ensure_connection_setup ()
277
- return self ._collection .count_documents ({}) # type: ignore[union-attr]
268
+ assert self ._collection is not None
269
+ return self ._collection .count_documents ({})
278
270
279
271
async def count_documents_async (self ) -> int :
280
272
"""
@@ -283,7 +275,8 @@ async def count_documents_async(self) -> int:
283
275
:returns: The number of documents in the document store.
284
276
"""
285
277
await self ._ensure_connection_setup_async ()
286
- return await self ._collection_async .count_documents ({}) # type: ignore[union-attr]
278
+ assert self ._collection_async is not None
279
+ return await self ._collection_async .count_documents ({})
287
280
288
281
def filter_documents (self , filters : Optional [Dict [str , Any ]] = None ) -> List [Document ]:
289
282
"""
@@ -296,8 +289,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
296
289
:returns: A list of Documents that match the given filters.
297
290
"""
298
291
self ._ensure_connection_setup ()
292
+ assert self ._collection is not None
299
293
filters = _normalize_filters (filters ) if filters else None
300
- documents = list (self ._collection .find (filters )) # type: ignore[union-attr]
294
+ documents = list (self ._collection .find (filters ))
301
295
return [self ._mongo_doc_to_haystack_doc (doc ) for doc in documents ]
302
296
303
297
async def filter_documents_async (self , filters : Optional [Dict [str , Any ]] = None ) -> List [Document ]:
@@ -311,8 +305,9 @@ async def filter_documents_async(self, filters: Optional[Dict[str, Any]] = None)
311
305
:returns: A list of Documents that match the given filters.
312
306
"""
313
307
await self ._ensure_connection_setup_async ()
308
+ assert self ._collection_async is not None
314
309
filters = _normalize_filters (filters ) if filters else None
315
- documents = await self ._collection_async .find (filters ).to_list () # type: ignore[union-attr]
310
+ documents = await self ._collection_async .find (filters ).to_list ()
316
311
return [self ._mongo_doc_to_haystack_doc (doc ) for doc in documents ]
317
312
318
313
def write_documents (self , documents : List [Document ], policy : DuplicatePolicy = DuplicatePolicy .NONE ) -> int :
@@ -327,7 +322,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
327
322
:returns: The number of documents written to the document store.
328
323
"""
329
324
self ._ensure_connection_setup ()
330
-
325
+ assert self . _collection is not None
331
326
if len (documents ) > 0 :
332
327
if not isinstance (documents [0 ], Document ):
333
328
msg = "param 'documents' must contain a list of objects of type Document"
@@ -342,15 +337,15 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
342
337
343
338
if policy == DuplicatePolicy .SKIP :
344
339
operations = [UpdateOne ({"id" : doc ["id" ]}, {"$setOnInsert" : doc }, upsert = True ) for doc in mongo_documents ]
345
- existing_documents = self ._collection .count_documents ({"id" : {"$in" : [doc .id for doc in documents ]}}) # type: ignore[union-attr]
340
+ existing_documents = self ._collection .count_documents ({"id" : {"$in" : [doc .id for doc in documents ]}})
346
341
written_docs -= existing_documents
347
342
elif policy == DuplicatePolicy .FAIL :
348
343
operations = [InsertOne (doc ) for doc in mongo_documents ]
349
344
else :
350
345
operations = [ReplaceOne ({"id" : doc ["id" ]}, upsert = True , replacement = doc ) for doc in mongo_documents ]
351
346
352
347
try :
353
- self ._collection .bulk_write (operations ) # type: ignore[union-attr]
348
+ self ._collection .bulk_write (operations )
354
349
except BulkWriteError as e :
355
350
msg = f"Duplicate documents found: { e .details ['writeErrors' ]} "
356
351
raise DuplicateDocumentError (msg ) from e
@@ -371,7 +366,7 @@ async def write_documents_async(
371
366
:returns: The number of documents written to the document store.
372
367
"""
373
368
await self ._ensure_connection_setup_async ()
374
-
369
+ assert self . _collection_async is not None
375
370
if len (documents ) > 0 :
376
371
if not isinstance (documents [0 ], Document ):
377
372
msg = "param 'documents' must contain a list of objects of type Document"
@@ -387,15 +382,17 @@ async def write_documents_async(
387
382
388
383
if policy == DuplicatePolicy .SKIP :
389
384
operations = [UpdateOne ({"id" : doc ["id" ]}, {"$setOnInsert" : doc }, upsert = True ) for doc in mongo_documents ]
390
- existing_documents = self ._collection .count_documents ({"id" : {"$in" : [doc .id for doc in documents ]}}) # type: ignore[union-attr]
385
+ existing_documents = await self ._collection_async .count_documents (
386
+ {"id" : {"$in" : [doc .id for doc in documents ]}}
387
+ )
391
388
written_docs -= existing_documents
392
389
elif policy == DuplicatePolicy .FAIL :
393
390
operations = [InsertOne (doc ) for doc in mongo_documents ]
394
391
else :
395
392
operations = [ReplaceOne ({"id" : doc ["id" ]}, upsert = True , replacement = doc ) for doc in mongo_documents ]
396
393
397
394
try :
398
- await self ._collection_async .bulk_write (operations ) # type: ignore[union-attr]
395
+ await self ._collection_async .bulk_write (operations )
399
396
except BulkWriteError as e :
400
397
msg = f"Duplicate documents found: { e .details ['writeErrors' ]} "
401
398
raise DuplicateDocumentError (msg ) from e
@@ -409,9 +406,10 @@ def delete_documents(self, document_ids: List[str]) -> None:
409
406
:param document_ids: the document ids to delete
410
407
"""
411
408
self ._ensure_connection_setup ()
409
+ assert self ._collection is not None
412
410
if not document_ids :
413
411
return
414
- self ._collection .delete_many (filter = {"id" : {"$in" : document_ids }}) # type: ignore[union-attr]
412
+ self ._collection .delete_many (filter = {"id" : {"$in" : document_ids }})
415
413
416
414
async def delete_documents_async (self , document_ids : List [str ]) -> None :
417
415
"""
@@ -420,9 +418,10 @@ async def delete_documents_async(self, document_ids: List[str]) -> None:
420
418
:param document_ids: the document ids to delete
421
419
"""
422
420
await self ._ensure_connection_setup_async ()
421
+ assert self ._collection_async is not None
423
422
if not document_ids :
424
423
return
425
- await self ._collection_async .delete_many (filter = {"id" : {"$in" : document_ids }}) # type: ignore[union-attr]
424
+ await self ._collection_async .delete_many (filter = {"id" : {"$in" : document_ids }})
426
425
427
426
def _embedding_retrieval (
428
427
self ,
@@ -441,6 +440,7 @@ def _embedding_retrieval(
441
440
:raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails.
442
441
"""
443
442
self ._ensure_connection_setup ()
443
+ assert self ._collection is not None
444
444
if not query_embedding :
445
445
msg = "Query embedding must not be empty"
446
446
raise ValueError (msg )
@@ -462,7 +462,7 @@ def _embedding_retrieval(
462
462
{"$project" : {"_id" : 0 }},
463
463
]
464
464
try :
465
- documents = list (self ._collection .aggregate (pipeline )) # type: ignore[union-attr]
465
+ documents = list (self ._collection .aggregate (pipeline ))
466
466
except Exception as e :
467
467
msg = f"Retrieval of documents from MongoDB Atlas failed: { e } "
468
468
if filters :
@@ -490,6 +490,7 @@ async def _embedding_retrieval_async(
490
490
:raises DocumentStoreError: If the retrieval of documents from MongoDB Atlas fails.
491
491
"""
492
492
await self ._ensure_connection_setup_async ()
493
+ assert self ._collection_async is not None
493
494
if not query_embedding :
494
495
msg = "Query embedding must not be empty"
495
496
raise ValueError (msg )
@@ -511,7 +512,8 @@ async def _embedding_retrieval_async(
511
512
{"$project" : {"_id" : 0 }},
512
513
]
513
514
try :
514
- documents = await self ._collection_async .aggregate (pipeline ).to_list () # type: ignore[union-attr]
515
+ cursor = await self ._collection_async .aggregate (pipeline )
516
+ documents = await cursor .to_list (length = None )
515
517
except Exception as e :
516
518
msg = f"Retrieval of documents from MongoDB Atlas failed: { e } "
517
519
if filters :
@@ -606,8 +608,9 @@ def _fulltext_retrieval(
606
608
]
607
609
608
610
self ._ensure_connection_setup ()
611
+ assert self ._collection is not None
609
612
try :
610
- documents = list (self ._collection .aggregate (pipeline )) # type: ignore[union-attr]
613
+ documents = list (self ._collection .aggregate (pipeline ))
611
614
except Exception as e :
612
615
error_msg = f"Failed to retrieve documents from MongoDB Atlas: { e } "
613
616
if filters :
@@ -698,9 +701,9 @@ async def _fulltext_retrieval_async(
698
701
]
699
702
700
703
await self ._ensure_connection_setup_async ()
701
-
704
+ assert self . _collection_async is not None
702
705
try :
703
- cursor = await self ._collection_async .aggregate (pipeline ) # type: ignore[union-attr]
706
+ cursor = await self ._collection_async .aggregate (pipeline )
704
707
documents = await cursor .to_list (length = None )
705
708
except Exception as e :
706
709
error_msg = f"Failed to retrieve documents from MongoDB Atlas: { e } "
0 commit comments