diff --git a/newsfragments/672.feature.rst b/newsfragments/672.feature.rst new file mode 100644 index 00000000..17950163 --- /dev/null +++ b/newsfragments/672.feature.rst @@ -0,0 +1,2 @@ +Have separate parameters for template database name and database name in DatabaseJanitor. +It'll make it much clearer to understand the code and Janitor's behaviour. \ No newline at end of file diff --git a/pytest_postgresql/executor.py b/pytest_postgresql/executor.py index 5c2800b9..cc8af302 100644 --- a/pytest_postgresql/executor.py +++ b/pytest_postgresql/executor.py @@ -131,6 +131,11 @@ def __init__( }, ) + @property + def template_dbname(self) -> str: + """Return the template database name.""" + return f"{self.dbname}_tmpl" + def start(self: T) -> T: """Add check for postgresql version before starting process.""" if self.version < self.MIN_SUPPORTED_VERSION: diff --git a/pytest_postgresql/executor_noop.py b/pytest_postgresql/executor_noop.py index 68d53fcf..4f910200 100644 --- a/pytest_postgresql/executor_noop.py +++ b/pytest_postgresql/executor_noop.py @@ -56,6 +56,11 @@ def __init__( self.dbname = dbname self._version: Any = None + @property + def template_dbname(self) -> str: + """Return the template database name.""" + return f"{self.dbname}_tmpl" + @property def version(self) -> Any: """Get postgresql's version.""" diff --git a/pytest_postgresql/factories/client.py b/pytest_postgresql/factories/client.py index 389529a2..462c392b 100644 --- a/pytest_postgresql/factories/client.py +++ b/pytest_postgresql/factories/client.py @@ -75,9 +75,15 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]: ), category=DeprecationWarning, ) - with DatabaseJanitor( - pg_user, pg_host, pg_port, pg_db, proc_fixture.version, pg_password, isolation_level + user=pg_user, + host=pg_host, + port=pg_port, + dbname=pg_db, + template_dbname=proc_fixture.template_dbname, + version=proc_fixture.version, + password=pg_password, + isolation_level=isolation_level, ) as janitor: db_connection: Connection = psycopg.connect( dbname=pg_db, diff --git a/pytest_postgresql/factories/noprocess.py b/pytest_postgresql/factories/noprocess.py index 638868a0..56200584 100644 --- a/pytest_postgresql/factories/noprocess.py +++ b/pytest_postgresql/factories/noprocess.py @@ -81,12 +81,11 @@ def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor] dbname=pg_dbname, options=pg_options, ) - template_dbname = f"{noop_exec.dbname}_tmpl" with DatabaseJanitor( user=noop_exec.user, host=noop_exec.host, port=noop_exec.port, - dbname=template_dbname, + template_dbname=noop_exec.template_dbname, version=noop_exec.version, password=noop_exec.password, ) as janitor: diff --git a/pytest_postgresql/factories/process.py b/pytest_postgresql/factories/process.py index 28fd0b54..9068a609 100644 --- a/pytest_postgresql/factories/process.py +++ b/pytest_postgresql/factories/process.py @@ -136,12 +136,11 @@ def postgresql_proc_fixture( # start server with postgresql_executor: postgresql_executor.wait_for_postgres() - template_dbname = f"{postgresql_executor.dbname}_tmpl" with DatabaseJanitor( user=postgresql_executor.user, host=postgresql_executor.host, port=postgresql_executor.port, - dbname=template_dbname, + template_dbname=postgresql_executor.template_dbname, version=postgresql_executor.version, password=postgresql_executor.password, ) as janitor: diff --git a/pytest_postgresql/janitor.py b/pytest_postgresql/janitor.py index 8b69e1d1..7636f1cc 100644 --- a/pytest_postgresql/janitor.py +++ b/pytest_postgresql/janitor.py @@ -26,8 +26,9 @@ def __init__( user: str, host: str, port: Union[str, int], - dbname: str, version: Union[str, float, Version], # type: ignore[valid-type] + dbname: Optional[str] = None, + template_dbname: Optional[str] = None, password: Optional[str] = None, isolation_level: "Optional[psycopg.IsolationLevel]" = None, connection_timeout: int = 60, @@ -38,6 +39,7 @@ def __init__( :param host: postgresql host :param port: postgresql port :param dbname: database name + :param dbname: template database name :param version: postgresql version number :param password: optional postgresql password :param isolation_level: optional postgresql isolation level @@ -49,7 +51,10 @@ def __init__( self.password = password self.host = host self.port = port + # At least one of the dbname or template_dbname has to be filled. + assert any([dbname, template_dbname]) self.dbname = dbname + self.template_dbname = template_dbname self._connection_timeout = connection_timeout self.isolation_level = isolation_level if not isinstance(version, Version): @@ -59,36 +64,33 @@ def __init__( def init(self) -> None: """Create database in postgresql.""" - template_name = f"{self.dbname}_tmpl" with self.cursor() as cur: - if self.dbname.endswith("_tmpl"): - result = False - else: - cur.execute( - "SELECT EXISTS " - "(SELECT datname FROM pg_catalog.pg_database WHERE datname= %s);", - (template_name,), - ) - row = cur.fetchone() - result = (row is not None) and row[0] - if not result: + if self.is_template(): + cur.execute(f'CREATE DATABASE "{self.template_dbname}";') + elif self.template_dbname is None: cur.execute(f'CREATE DATABASE "{self.dbname}";') else: # All template database does not allow connection: - self._dont_datallowconn(cur, template_name) + self._dont_datallowconn(cur, self.template_dbname) # And make sure no-one is left connected to the template database. - # Otherwise Creating database from template will fail - self._terminate_connection(cur, template_name) - cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{template_name}";') + # Otherwise, Creating database from template will fail + self._terminate_connection(cur, self.template_dbname) + cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}";') + + def is_template(self) -> bool: + """Determine whether the DatabaseJanitor maintains template or database.""" + return self.dbname is None def drop(self) -> None: """Drop database in postgresql.""" # We cannot drop the database while there are connections to it, so we # terminate all connections first while not allowing new connections. + db_to_drop = self.template_dbname if self.is_template() else self.dbname + assert db_to_drop with self.cursor() as cur: - self._dont_datallowconn(cur, self.dbname) - self._terminate_connection(cur, self.dbname) - cur.execute(f'DROP DATABASE IF EXISTS "{self.dbname}";') + self._dont_datallowconn(cur, db_to_drop) + self._terminate_connection(cur, db_to_drop) + cur.execute(f'DROP DATABASE IF EXISTS "{db_to_drop}";') @staticmethod def _dont_datallowconn(cur: Cursor, dbname: str) -> None: @@ -113,12 +115,13 @@ def load(self, load: Union[Callable, str, Path]) -> None: * a callable that expects: host, port, user, dbname and password arguments. """ + db_to_load = self.template_dbname if self.is_template() else self.dbname _loader = build_loader(load) _loader( host=self.host, port=self.port, user=self.user, - dbname=self.dbname, + dbname=db_to_load, password=self.password, ) diff --git a/tests/test_janitor.py b/tests/test_janitor.py index de2d74f5..5eafe304 100644 --- a/tests/test_janitor.py +++ b/tests/test_janitor.py @@ -15,14 +15,18 @@ @pytest.mark.parametrize("version", (VERSION, 10, "10")) def test_version_cast(version: Any) -> None: """Test that version is cast to Version object.""" - janitor = DatabaseJanitor("user", "host", "1234", "database_name", version) + janitor = DatabaseJanitor( + user="user", host="host", port="1234", dbname="database_name", version=version + ) assert janitor.version == VERSION @patch("pytest_postgresql.janitor.psycopg.connect") def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None: """Test that the cursor requests the postgres database.""" - janitor = DatabaseJanitor("user", "host", "1234", "database_name", 10) + janitor = DatabaseJanitor( + user="user", host="host", port="1234", dbname="database_name", version=10 + ) with janitor.cursor(): connect_mock.assert_called_once_with( dbname="postgres", user="user", password=None, host="host", port="1234" @@ -32,7 +36,14 @@ def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None: @patch("pytest_postgresql.janitor.psycopg.connect") def test_cursor_connects_with_password(connect_mock: MagicMock) -> None: """Test that the cursor requests the postgres database.""" - janitor = DatabaseJanitor("user", "host", "1234", "database_name", 10, "some_password") + janitor = DatabaseJanitor( + user="user", + host="host", + port="1234", + dbname="database_name", + version=10, + password="some_password", + ) with janitor.cursor(): connect_mock.assert_called_once_with( dbname="postgres", user="user", password="some_password", host="host", port="1234"