Skip to content

Commit a4eb403

Browse files
committed
Have separate paratemeter for template and regular database name - closes #672
1 parent d821171 commit a4eb403

File tree

6 files changed

+54
-29
lines changed

6 files changed

+54
-29
lines changed

newsfragments/672.feature.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Have separate parameters for template database name and database name in DatabaseJanitor.
2+
It'll make it much clearer to understand the code and Janitor's behaviour.

pytest_postgresql/factories/client.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
6363
pg_user = proc_fixture.user
6464
pg_password = proc_fixture.password
6565
pg_options = proc_fixture.options
66-
pg_db = dbname or proc_fixture.dbname
66+
pg_db = dbname
67+
pg_template = None
68+
if not dbname:
69+
pg_db = proc_fixture.dbname
70+
pg_template = f"{pg_db}_tmpl"
6771
pg_load = load or []
6872
if pg_load:
6973
warnings.warn(
@@ -75,9 +79,15 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
7579
),
7680
category=DeprecationWarning,
7781
)
78-
7982
with DatabaseJanitor(
80-
pg_user, pg_host, pg_port, pg_db, proc_fixture.version, pg_password, isolation_level
83+
user=pg_user,
84+
host=pg_host,
85+
port=pg_port,
86+
dbname=pg_db,
87+
template_dbname=pg_template,
88+
version=proc_fixture.version,
89+
password=pg_password,
90+
isolation_level=isolation_level,
8191
) as janitor:
8292
db_connection: Connection = psycopg.connect(
8393
dbname=pg_db,

pytest_postgresql/factories/noprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor]
8686
user=noop_exec.user,
8787
host=noop_exec.host,
8888
port=noop_exec.port,
89-
dbname=template_dbname,
89+
template_dbname=template_dbname,
9090
version=noop_exec.version,
9191
password=noop_exec.password,
9292
) as janitor:

pytest_postgresql/factories/process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def postgresql_proc_fixture(
141141
user=postgresql_executor.user,
142142
host=postgresql_executor.host,
143143
port=postgresql_executor.port,
144-
dbname=template_dbname,
144+
template_dbname=template_dbname,
145145
version=postgresql_executor.version,
146146
password=postgresql_executor.password,
147147
) as janitor:

pytest_postgresql/janitor.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ def __init__(
2626
user: str,
2727
host: str,
2828
port: Union[str, int],
29-
dbname: str,
3029
version: Union[str, float, Version], # type: ignore[valid-type]
30+
dbname: Optional[str] = None,
31+
template_dbname: Optional[str] = None,
3132
password: Optional[str] = None,
3233
isolation_level: "Optional[psycopg.IsolationLevel]" = None,
3334
connection_timeout: int = 60,
@@ -49,7 +50,10 @@ def __init__(
4950
self.password = password
5051
self.host = host
5152
self.port = port
53+
# At least one of the dbname or template_dbname has to be filled.
54+
assert any([dbname, template_dbname])
5255
self.dbname = dbname
56+
self.template_dbname = template_dbname
5357
self._connection_timeout = connection_timeout
5458
self.isolation_level = isolation_level
5559
if not isinstance(version, Version):
@@ -59,36 +63,33 @@ def __init__(
5963

6064
def init(self) -> None:
6165
"""Create database in postgresql."""
62-
template_name = f"{self.dbname}_tmpl"
6366
with self.cursor() as cur:
64-
if self.dbname.endswith("_tmpl"):
65-
result = False
66-
else:
67-
cur.execute(
68-
"SELECT EXISTS "
69-
"(SELECT datname FROM pg_catalog.pg_database WHERE datname= %s);",
70-
(template_name,),
71-
)
72-
row = cur.fetchone()
73-
result = (row is not None) and row[0]
74-
if not result:
67+
if self.is_template():
68+
cur.execute(f'CREATE DATABASE "{self.template_dbname}";')
69+
elif self.template_dbname is None:
7570
cur.execute(f'CREATE DATABASE "{self.dbname}";')
7671
else:
7772
# All template database does not allow connection:
78-
self._dont_datallowconn(cur, template_name)
73+
self._dont_datallowconn(cur, self.template_dbname)
7974
# And make sure no-one is left connected to the template database.
80-
# Otherwise Creating database from template will fail
81-
self._terminate_connection(cur, template_name)
82-
cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{template_name}";')
75+
# Otherwise, Creating database from template will fail
76+
self._terminate_connection(cur, self.template_dbname)
77+
cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}";')
78+
79+
def is_template(self) -> bool:
80+
"""Determine whether the DatabaseJanitor maintains template or database."""
81+
return self.dbname is None
8382

8483
def drop(self) -> None:
8584
"""Drop database in postgresql."""
8685
# We cannot drop the database while there are connections to it, so we
8786
# terminate all connections first while not allowing new connections.
87+
db_to_drop = self.template_dbname if self.is_template() else self.dbname
88+
assert db_to_drop
8889
with self.cursor() as cur:
89-
self._dont_datallowconn(cur, self.dbname)
90-
self._terminate_connection(cur, self.dbname)
91-
cur.execute(f'DROP DATABASE IF EXISTS "{self.dbname}";')
90+
self._dont_datallowconn(cur, db_to_drop)
91+
self._terminate_connection(cur, db_to_drop)
92+
cur.execute(f'DROP DATABASE IF EXISTS "{db_to_drop}";')
9293

9394
@staticmethod
9495
def _dont_datallowconn(cur: Cursor, dbname: str) -> None:
@@ -113,12 +114,13 @@ def load(self, load: Union[Callable, str, Path]) -> None:
113114
* a callable that expects: host, port, user, dbname and password arguments.
114115
115116
"""
117+
db_to_load = self.template_dbname if self.is_template() else self.dbname
116118
_loader = build_loader(load)
117119
_loader(
118120
host=self.host,
119121
port=self.port,
120122
user=self.user,
121-
dbname=self.dbname,
123+
dbname=db_to_load,
122124
password=self.password,
123125
)
124126

tests/test_janitor.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@
1515
@pytest.mark.parametrize("version", (VERSION, 10, "10"))
1616
def test_version_cast(version: Any) -> None:
1717
"""Test that version is cast to Version object."""
18-
janitor = DatabaseJanitor("user", "host", "1234", "database_name", version)
18+
janitor = DatabaseJanitor(
19+
user="user", host="host", port="1234", dbname="database_name", version=version
20+
)
1921
assert janitor.version == VERSION
2022

2123

2224
@patch("pytest_postgresql.janitor.psycopg.connect")
2325
def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None:
2426
"""Test that the cursor requests the postgres database."""
25-
janitor = DatabaseJanitor("user", "host", "1234", "database_name", 10)
27+
janitor = DatabaseJanitor(
28+
user="user", host="host", port="1234", dbname="database_name", version=10
29+
)
2630
with janitor.cursor():
2731
connect_mock.assert_called_once_with(
2832
dbname="postgres", user="user", password=None, host="host", port="1234"
@@ -32,7 +36,14 @@ def test_cursor_selects_postgres_database(connect_mock: MagicMock) -> None:
3236
@patch("pytest_postgresql.janitor.psycopg.connect")
3337
def test_cursor_connects_with_password(connect_mock: MagicMock) -> None:
3438
"""Test that the cursor requests the postgres database."""
35-
janitor = DatabaseJanitor("user", "host", "1234", "database_name", 10, "some_password")
39+
janitor = DatabaseJanitor(
40+
user="user",
41+
host="host",
42+
port="1234",
43+
dbname="database_name",
44+
version=10,
45+
password="some_password",
46+
)
3647
with janitor.cursor():
3748
connect_mock.assert_called_once_with(
3849
dbname="postgres", user="user", password="some_password", host="host", port="1234"

0 commit comments

Comments
 (0)