Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Have separate paratemeter for template and regular database name - cl… #891

Merged
merged 1 commit into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions newsfragments/672.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Have separate parameters for template database name and database name in DatabaseJanitor.
It'll make it much clearer to understand the code and Janitor's behaviour.
5 changes: 5 additions & 0 deletions pytest_postgresql/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def __init__(
},
)

@property
def template_dbname(self) -> str:
"""Return the template database name."""
return f"{self.dbname}_tmpl"

def start(self: T) -> T:
"""Add check for postgresql version before starting process."""
if self.version < self.MIN_SUPPORTED_VERSION:
Expand Down
5 changes: 5 additions & 0 deletions pytest_postgresql/executor_noop.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def __init__(
self.dbname = dbname
self._version: Any = None

@property
def template_dbname(self) -> str:
"""Return the template database name."""
return f"{self.dbname}_tmpl"

@property
def version(self) -> Any:
"""Get postgresql's version."""
Expand Down
10 changes: 8 additions & 2 deletions pytest_postgresql/factories/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,15 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
),
category=DeprecationWarning,
)

with DatabaseJanitor(
pg_user, pg_host, pg_port, pg_db, proc_fixture.version, pg_password, isolation_level
user=pg_user,
host=pg_host,
port=pg_port,
dbname=pg_db,
template_dbname=proc_fixture.template_dbname,
version=proc_fixture.version,
password=pg_password,
isolation_level=isolation_level,
) as janitor:
db_connection: Connection = psycopg.connect(
dbname=pg_db,
Expand Down
3 changes: 1 addition & 2 deletions pytest_postgresql/factories/noprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,11 @@ def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor]
dbname=pg_dbname,
options=pg_options,
)
template_dbname = f"{noop_exec.dbname}_tmpl"
with DatabaseJanitor(
user=noop_exec.user,
host=noop_exec.host,
port=noop_exec.port,
dbname=template_dbname,
template_dbname=noop_exec.template_dbname,
version=noop_exec.version,
password=noop_exec.password,
) as janitor:
Expand Down
3 changes: 1 addition & 2 deletions pytest_postgresql/factories/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,11 @@ def postgresql_proc_fixture(
# start server
with postgresql_executor:
postgresql_executor.wait_for_postgres()
template_dbname = f"{postgresql_executor.dbname}_tmpl"
with DatabaseJanitor(
user=postgresql_executor.user,
host=postgresql_executor.host,
port=postgresql_executor.port,
dbname=template_dbname,
template_dbname=postgresql_executor.template_dbname,
version=postgresql_executor.version,
password=postgresql_executor.password,
) as janitor:
Expand Down
45 changes: 24 additions & 21 deletions pytest_postgresql/janitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ def __init__(
user: str,
host: str,
port: Union[str, int],
dbname: str,
version: Union[str, float, Version], # type: ignore[valid-type]
dbname: Optional[str] = None,
template_dbname: Optional[str] = None,
password: Optional[str] = None,
isolation_level: "Optional[psycopg.IsolationLevel]" = None,
connection_timeout: int = 60,
Expand All @@ -38,6 +39,7 @@ def __init__(
:param host: postgresql host
:param port: postgresql port
:param dbname: database name
:param dbname: template database name
:param version: postgresql version number
:param password: optional postgresql password
:param isolation_level: optional postgresql isolation level
Expand All @@ -49,7 +51,10 @@ def __init__(
self.password = password
self.host = host
self.port = port
# At least one of the dbname or template_dbname has to be filled.
assert any([dbname, template_dbname])
self.dbname = dbname
self.template_dbname = template_dbname
self._connection_timeout = connection_timeout
self.isolation_level = isolation_level
if not isinstance(version, Version):
Expand All @@ -59,36 +64,33 @@ def __init__(

def init(self) -> None:
"""Create database in postgresql."""
template_name = f"{self.dbname}_tmpl"
with self.cursor() as cur:
if self.dbname.endswith("_tmpl"):
result = False
else:
cur.execute(
"SELECT EXISTS "
"(SELECT datname FROM pg_catalog.pg_database WHERE datname= %s);",
(template_name,),
)
row = cur.fetchone()
result = (row is not None) and row[0]
if not result:
if self.is_template():
cur.execute(f'CREATE DATABASE "{self.template_dbname}";')
elif self.template_dbname is None:
cur.execute(f'CREATE DATABASE "{self.dbname}";')
else:
# All template database does not allow connection:
self._dont_datallowconn(cur, template_name)
self._dont_datallowconn(cur, self.template_dbname)
# And make sure no-one is left connected to the template database.
# Otherwise Creating database from template will fail
self._terminate_connection(cur, template_name)
cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{template_name}";')
# Otherwise, Creating database from template will fail
self._terminate_connection(cur, self.template_dbname)
cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}";')

def is_template(self) -> bool:
"""Determine whether the DatabaseJanitor maintains template or database."""
return self.dbname is None

def drop(self) -> None:
"""Drop database in postgresql."""
# We cannot drop the database while there are connections to it, so we
# terminate all connections first while not allowing new connections.
db_to_drop = self.template_dbname if self.is_template() else self.dbname
assert db_to_drop
with self.cursor() as cur:
self._dont_datallowconn(cur, self.dbname)
self._terminate_connection(cur, self.dbname)
cur.execute(f'DROP DATABASE IF EXISTS "{self.dbname}";')
self._dont_datallowconn(cur, db_to_drop)
self._terminate_connection(cur, db_to_drop)
cur.execute(f'DROP DATABASE IF EXISTS "{db_to_drop}";')

@staticmethod
def _dont_datallowconn(cur: Cursor, dbname: str) -> None:
Expand All @@ -113,12 +115,13 @@ def load(self, load: Union[Callable, str, Path]) -> None:
* a callable that expects: host, port, user, dbname and password arguments.

"""
db_to_load = self.template_dbname if self.is_template() else self.dbname
_loader = build_loader(load)
_loader(
host=self.host,
port=self.port,
user=self.user,
dbname=self.dbname,
dbname=db_to_load,
password=self.password,
)

Expand Down
17 changes: 14 additions & 3 deletions tests/test_janitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
@pytest.mark.parametrize("version", (VERSION, 10, "10"))
def test_version_cast(version: Any) -> None:
"""Test that version is cast to Version object."""
janitor = DatabaseJanitor("user", "host", "1234", "database_name", version)
janitor = DatabaseJanitor(
user="user", host="host", port="1234", dbname="database_name", version=version
)
assert janitor.version == VERSION


@patch("pytest_postgresql.janitor.psycopg.connect")
def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None:
"""Test that the cursor requests the postgres database."""
janitor = DatabaseJanitor("user", "host", "1234", "database_name", 10)
janitor = DatabaseJanitor(
user="user", host="host", port="1234", dbname="database_name", version=10
)
with janitor.cursor():
connect_mock.assert_called_once_with(
dbname="postgres", user="user", password=None, host="host", port="1234"
Expand All @@ -32,7 +36,14 @@ def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None:
@patch("pytest_postgresql.janitor.psycopg.connect")
def test_cursor_connects_with_password(connect_mock: MagicMock) -> None:
"""Test that the cursor requests the postgres database."""
janitor = DatabaseJanitor("user", "host", "1234", "database_name", 10, "some_password")
janitor = DatabaseJanitor(
user="user",
host="host",
port="1234",
dbname="database_name",
version=10,
password="some_password",
)
with janitor.cursor():
connect_mock.assert_called_once_with(
dbname="postgres", user="user", password="some_password", host="host", port="1234"
Expand Down