From 3474a15f4e9ffd2606ea7fd8fde91a7be8fb671d Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Mon, 3 Feb 2025 21:07:35 +0100 Subject: [PATCH] wip fix tmppath issue fix tests make error message bad again --- disk_objectstore/backup_utils.py | 1 + disk_objectstore/container.py | 49 ++++++++++++++++++-------------- disk_objectstore/database.py | 6 ++-- tests/conftest.py | 13 ++------- tests/test_cli.py | 2 +- tests/test_container.py | 20 ++++++------- 6 files changed, 44 insertions(+), 47 deletions(-) diff --git a/disk_objectstore/backup_utils.py b/disk_objectstore/backup_utils.py index d536f04..4240679 100644 --- a/disk_objectstore/backup_utils.py +++ b/disk_objectstore/backup_utils.py @@ -221,6 +221,7 @@ def call_rsync( # pylint: disable=too-many-arguments,too-many-branches def get_existing_backup_folders(self): """Get all folders matching the backup folder name pattern.""" + success, stdout = self.run_cmd( [ "find", diff --git a/disk_objectstore/container.py b/disk_objectstore/container.py index dbea118..312da30 100644 --- a/disk_objectstore/container.py +++ b/disk_objectstore/container.py @@ -121,6 +121,7 @@ def __init__(self, folder: str | Path) -> None: self._folder = Path(folder).resolve() # Will be populated by the _get_session function self._session: Session | None = None + self._keep_open_session: Session | None = None # These act as caches and will be populated by the corresponding properties # IMPORANT! IF YOU ADD MORE, REMEMBER TO CLEAR THEM IN `init_container()`! @@ -134,9 +135,17 @@ def get_folder(self) -> Path: def close(self) -> None: """Close open files (in particular, the connection to the SQLite DB).""" if self._session is not None: + engine = self._session.bind self._session.close() + engine.dispose() self._session = None + if self._keep_open_session is not None: + engine = self._keep_open_session.bind + self._keep_open_session.close() + engine.dispose() + self._keep_open_session = None + def __enter__(self) -> Container: """Return a context manager that will close the session when exiting the context.""" return self @@ -180,32 +189,21 @@ def _get_config_file(self) -> Path: """Return the path to the container config file.""" return self._folder / "config.json" - @overload - def _get_session( - self, create: bool = False, raise_if_missing: Literal[True] = True + def _create_init_session( + self ) -> Session: - ... - - @overload - def _get_session( - self, create: bool = False, raise_if_missing: Literal[False] = False - ) -> Session | None: - ... - - def _get_session( - self, create: bool = False, raise_if_missing: bool = False - ) -> Session | None: """Return a new session to connect to the pack-index SQLite DB. :param create: if True, creates the sqlite file and schema. :param raise_if_missing: ignored if create==True. If create==False, and the index file is missing, either raise an exception (FileNotFoundError) if this flag is True, or return None """ - return get_session( - self._get_pack_index_path(), - create=create, - raise_if_missing=raise_if_missing, - ) + if self._keep_open_session is None: + self._keep_open_session = get_session( + self._get_pack_index_path(), + create=True, + ) + return self._keep_open_session def _get_cached_session(self) -> Session: """Return the SQLAlchemy session to access the SQLite file, @@ -214,7 +212,10 @@ def _get_cached_session(self) -> Session: # the latter means that in the previous run the pack file was missing # but maybe by now it has been created! if self._session is None: - self._session = self._get_session(create=False, raise_if_missing=True) + self._session = get_session( + self._get_pack_index_path(), + create=False, + ) return self._session def _get_loose_path_from_hashkey(self, hashkey: str) -> Path: @@ -332,6 +333,7 @@ def init_container( raise ValueError(f'Unknown hash type "{hash_type}"') if clear: + self.close() if self._folder.exists(): shutil.rmtree(self._folder) @@ -391,7 +393,7 @@ def init_container( ]: os.makedirs(folder) - self._get_session(create=True) + self._create_init_session() def _get_repository_config(self) -> dict[str, int | str]: """Return the repository config.""" @@ -1141,7 +1143,7 @@ def get_total_size(self) -> TotalSize: retval["total_size_packindexes_on_disk"] = ( self._get_pack_index_path().stat().st_size - ) + ) total_size_loose = 0 for loose_hashkey in self._list_loose(): @@ -1916,6 +1918,9 @@ def add_objects_to_pack( # pylint: disable=too-many-arguments :return: a list of object hash keys """ + # TODO should be custom error but not sure what + if not self.is_initialised: + raise ValueError("Invalid use of function, please first initialise the container.") stream_list: list[StreamSeekBytesType] = [ io.BytesIO(content) for content in content_list ] diff --git a/disk_objectstore/database.py b/disk_objectstore/database.py index 08c9436..6a3aafd 100644 --- a/disk_objectstore/database.py +++ b/disk_objectstore/database.py @@ -32,7 +32,7 @@ class Obj(Base): # pylint: disable=too-few-public-methods def get_session( - path: Path, create: bool = False, raise_if_missing: bool = False + path: Path, create: bool = False ) -> Optional[Session]: """Return a new session to connect to the pack-index SQLite DB. @@ -41,9 +41,7 @@ def get_session( is missing, either raise an exception (FileNotFoundError) if this flag is True, or return None """ if not create and not path.exists(): - if raise_if_missing: - raise FileNotFoundError("Pack index does not exist") - return None + raise FileNotFoundError("Pack index does not exist") engine = create_engine(f"sqlite:///{path}", future=True) diff --git a/tests/conftest.py b/tests/conftest.py index 84fbc04..8f1443f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,21 +81,14 @@ def temp_container(temp_dir): # pylint: disable=redefined-outer-name @pytest.fixture(scope="function") -def temp_dir(): +def temp_dir(tmp_path): """Get a temporary directory. :return: The path to the directory :rtype: str """ - import gc - gc.collect() - - try: - dirpath = tempfile.mkdtemp() - yield Path(dirpath) - finally: - # after the test function has completed, remove the directory again - shutil.rmtree(dirpath) + dirpath = tempfile.mkdtemp(dir=str(tmp_path)) + yield Path(dirpath) @pytest.fixture(scope="function") diff --git a/tests/test_cli.py b/tests/test_cli.py index 5b7e2b9..8cc47c0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -223,7 +223,7 @@ def test_backup(temp_container, temp_dir, remote, verbosity): if verbosity: args += [f"--verbosity={verbosity}"] - result = CliRunner().invoke(cli.backup, args, obj=obj) + result = CliRunner().invoke(cli.backup, args, obj=obj, catch_exceptions=False) assert result.exit_code == 0 assert path.exists() diff --git a/tests/test_container.py b/tests/test_container.py index 9a3cda5..a81342e 100644 --- a/tests/test_container.py +++ b/tests/test_container.py @@ -677,15 +677,16 @@ def test_initialisation(temp_dir): # Check that the session cannot be obtained before initialising with pytest.raises(FileNotFoundError): - container._get_session(create=False, raise_if_missing=True) - assert container._get_session(create=False, raise_if_missing=False) is None + container._get_cached_session() container.init_container() assert container.is_initialised + container.close() # This call should go through container.init_container(clear=True) assert container.is_initialised + container.close() with pytest.raises(FileExistsError) as excinfo: container.init_container() @@ -717,31 +718,32 @@ def test_initialisation(temp_dir): @pytest.mark.parametrize("hash_type", ["sha256", "sha1"]) @pytest.mark.parametrize("compress", [True, False]) -def test_check_hash_computation(temp_container, hash_type, compress): +def test_check_hash_computation(temp_dir, hash_type, compress): """Check that the hashes are correctly computed, when storing loose, directly to packs, and while repacking all loose. Check both compressed and uncompressed packed objects. """ # Re-init the container with the correct hash type - temp_container.init_container(hash_type=hash_type, clear=True) + container = Container(temp_dir) + container.init_container(hash_type=hash_type, clear=True) content1 = b"1" content2 = b"222" content3 = b"n2fwd" expected_hasher = getattr(hashlib, hash_type) - hashkey1 = temp_container.add_object(content1) + hashkey1 = container.add_object(content1) assert hashkey1 == expected_hasher(content1).hexdigest() - hashkey2, hashkey3 = temp_container.add_objects_to_pack( + hashkey2, hashkey3 = container.add_objects_to_pack( [content2, content3], compress=compress ) assert hashkey2 == expected_hasher(content2).hexdigest() assert hashkey3 == expected_hasher(content3).hexdigest() # No exceptions should be aised - temp_container.pack_all_loose(compress=compress, validate_objects=True) + container.pack_all_loose(compress=compress, validate_objects=True) @pytest.mark.parametrize("validate_objects", [True, False]) @@ -1064,9 +1066,7 @@ def test_sizes( temp_container, generate_random_data, compress_packs, compression_algorithm ): """Check that the information on size is reliable.""" - temp_container.init_container( - clear=True, compression_algorithm=compression_algorithm - ) + temp_container.init_container( clear=True, compression_algorithm=compression_algorithm) size_info = temp_container.get_total_size() assert size_info["total_size_packed"] == 0 assert size_info["total_size_packed_on_disk"] == 0