Skip to content

Commit 7415557

Browse files
committed
fix GridFSURIStore and increase testing
1 parent 2632269 commit 7415557

File tree

3 files changed

+57
-10
lines changed

3 files changed

+57
-10
lines changed

src/maggma/stores/gridfs.py

+39-6
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
self.searchable_fields = [] if searchable_fields is None else searchable_fields
8888
self.kwargs = kwargs
8989
self.ssh_tunnel = ssh_tunnel
90+
self._fs = None
9091

9192
if auth_source is None:
9293
auth_source = self.database
@@ -157,9 +158,9 @@ def connect(self, force_reset: bool = False):
157158
db = conn[self.database]
158159
self._coll = gridfs.GridFS(db, self.collection_name)
159160
self._files_collection = db[f"{self.collection_name}.files"]
160-
self._files_store = MongoStore.from_collection(self._files_collection)
161-
self._files_store.last_updated_field = f"metadata.{self.last_updated_field}"
162-
self._files_store.key = self.key
161+
self._fs = MongoStore.from_collection(self._files_collection)
162+
self._fs.last_updated_field = f"metadata.{self.last_updated_field}"
163+
self._fs.key = self.key
163164
self._chunks_collection = db[f"{self.collection_name}.chunks"]
164165

165166
@property
@@ -169,6 +170,13 @@ def _collection(self):
169170
raise StoreError("Must connect Mongo-like store before attempting to use it")
170171
return self._coll
171172

173+
@property
174+
def _files_store(self):
175+
"""Property referring to MongoStore associated to the files_collection."""
176+
if self._fs is None:
177+
raise StoreError("Must connect Mongo-like store before attempting to use it")
178+
return self._fs
179+
172180
@property
173181
def last_updated(self) -> datetime:
174182
"""
@@ -448,6 +456,7 @@ def __init__(
448456
ensure_metadata: bool = False,
449457
searchable_fields: Optional[list[str]] = None,
450458
mongoclient_kwargs: Optional[dict] = None,
459+
ssh_tunnel: Optional[SSHTunnel] = None,
451460
**kwargs,
452461
):
453462
"""
@@ -463,6 +472,10 @@ def __init__(
463472
"""
464473
self.uri = uri
465474

475+
if ssh_tunnel:
476+
raise ValueError(f"At the moment ssh_tunnel is not supported for {self.__class__.__name__}")
477+
self.ssh_tunnel = None
478+
466479
# parse the dbname from the uri
467480
if database is None:
468481
d_uri = uri_parser.parse_uri(uri)
@@ -479,6 +492,7 @@ def __init__(
479492
self.searchable_fields = [] if searchable_fields is None else searchable_fields
480493
self.kwargs = kwargs
481494
self.mongoclient_kwargs = mongoclient_kwargs or {}
495+
self._fs = None
482496

483497
if "key" not in kwargs:
484498
kwargs["key"] = "_id"
@@ -497,7 +511,26 @@ def connect(self, force_reset: bool = False):
497511
db = conn[self.database]
498512
self._coll = gridfs.GridFS(db, self.collection_name)
499513
self._files_collection = db[f"{self.collection_name}.files"]
500-
self._files_store = MongoStore.from_collection(self._files_collection)
501-
self._files_store.last_updated_field = f"metadata.{self.last_updated_field}"
502-
self._files_store.key = self.key
514+
self._fs = MongoStore.from_collection(self._files_collection)
515+
self._fs.last_updated_field = f"metadata.{self.last_updated_field}"
516+
self._fs.key = self.key
503517
self._chunks_collection = db[f"{self.collection_name}.chunks"]
518+
519+
@property
520+
def name(self) -> str:
521+
"""
522+
Return a string representing this data source.
523+
"""
524+
# TODO: This is not very safe since it exposes the username/password info
525+
return self.uri
526+
527+
def __eq__(self, other: object) -> bool:
528+
"""
529+
Check equality for GridFSURIStore
530+
other: other GridFSURIStore to compare with.
531+
"""
532+
if not isinstance(other, GridFSStore):
533+
return False
534+
535+
fields = ["uri", "database", "collection_name"]
536+
return all(getattr(self, f) == getattr(other, f) for f in fields)

src/maggma/stores/mongolike.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,9 @@ def __init__(
460460
ensure determinacy in query results.
461461
"""
462462
self.uri = uri
463-
self.ssh_tunnel = ssh_tunnel
463+
if ssh_tunnel:
464+
raise ValueError(f"At the moment ssh_tunnel is not supported for {self.__class__.__name__}")
465+
self.ssh_tunnel = None
464466
self.default_sort = default_sort
465467
self.safe_update = safe_update
466468
self.mongoclient_kwargs = mongoclient_kwargs or {}

tests/stores/test_gridfs.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,20 @@ def mongostore():
2020
store._collection.drop()
2121

2222

23-
@pytest.fixture()
24-
def gridfsstore():
25-
store = GridFSStore("maggma_test", "test", key="task_id")
23+
@pytest.fixture(params=["std", "uri"])
24+
def gridfsstore(request):
25+
"""
26+
Fixture providing both a standard GridFSStore and a GridFSURIStore.
27+
"""
28+
store_type = request.param
29+
if store_type == "std":
30+
store = GridFSStore("maggma_test", "test", key="task_id")
31+
elif store_type == "uri":
32+
store = GridFSURIStore(
33+
uri="mongodb://localhost:27017", database="maggma_test", collection_name="test", key="task_id"
34+
)
35+
else:
36+
raise ValueError(f"Unknown store_type {store_type}")
2637
store.connect()
2738
yield store
2839
store._files_collection.drop()
@@ -243,6 +254,7 @@ def test_gridfs_uri():
243254
is_name = store.name is uri
244255
# This is try and keep the secret safe
245256
assert is_name
257+
store.close()
246258

247259

248260
def test_gridfs_uri_dbname_parse():

0 commit comments

Comments
 (0)