Skip to content

Commit b25a5bd

Browse files
authored
Merge pull request #874 from ClearcodeHQ/path-type
Address the issue where pytest-postgresql incorrectly determined the …
2 parents c3be976 + 44524cc commit b25a5bd

File tree

15 files changed

+123
-43
lines changed

15 files changed

+123
-43
lines changed

README.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,10 @@ Client specific loads the database each test
104104

105105
.. code-block:: python
106106
107+
from pathlib import Path
107108
postgresql_my_with_schema = factories.postgresql(
108109
'postgresql_my_proc',
109-
load=["schemafile.sql", "otherschema.sql", "import.path.to.function", "import.path.to:otherfunction", load_this]
110+
load=[Path("schemafile.sql"), Path("otherschema.sql"), "import.path.to.function", "import.path.to:otherfunction", load_this]
110111
)
111112
112113
.. warning::
@@ -115,12 +116,13 @@ Client specific loads the database each test
115116

116117

117118
The process fixture performs the load once per test session, and loads the data into the template database.
118-
Client fixture then creates test database out of the template database each test, which significantly speeds up the tests.
119+
Client fixture then creates test database out of the template database each test, which significantly **speeds up the tests**.
119120

120121
.. code-block:: python
121122
123+
from pathlib import Path
122124
postgresql_my_proc = factories.postgresql_proc(
123-
load=["schemafile.sql", "otherschema.sql", "import.path.to.function", "import.path.to:otherfunction", load_this]
125+
load=[Path("schemafile.sql"), Path("otherschema.sql"), "import.path.to.function", "import.path.to:otherfunction", load_this]
124126
)
125127
126128

newsfragments/638.feature.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Now all sql files used to initialise database for tests, has to be passed as pathlib.Path instance.
2+
3+
This helps the DatabaseJanitor choose correct behaviour based on parameter.

pytest_postgresql/config.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Plugin's configuration."""
22

3-
from typing import Any, List, Optional, TypedDict
3+
from pathlib import Path
4+
from typing import Any, List, Optional, TypedDict, Union
45

56
from pytest import FixtureRequest
67

@@ -17,7 +18,7 @@ class PostgresqlConfigDict(TypedDict):
1718
startparams: str
1819
unixsocketdir: str
1920
dbname: str
20-
load: List[str]
21+
load: List[Union[Path, str]]
2122
postgres_options: str
2223

2324

@@ -28,6 +29,8 @@ def get_postgresql_option(option: str) -> Any:
2829
name = "postgresql_" + option
2930
return request.config.getoption(name) or request.config.getini(name)
3031

32+
load_paths = detect_paths(get_postgresql_option("load"))
33+
3134
return PostgresqlConfigDict(
3235
exec=get_postgresql_option("exec"),
3336
host=get_postgresql_option("host"),
@@ -38,6 +41,17 @@ def get_postgresql_option(option: str) -> Any:
3841
startparams=get_postgresql_option("startparams"),
3942
unixsocketdir=get_postgresql_option("unixsocketdir"),
4043
dbname=get_postgresql_option("dbname"),
41-
load=get_postgresql_option("load"),
44+
load=load_paths,
4245
postgres_options=get_postgresql_option("postgres_options"),
4346
)
47+
48+
49+
def detect_paths(load_paths: List[str]) -> List[Union[Path, str]]:
50+
"""Covnerts path to sql files to Path instances."""
51+
converted_load_paths: List[Union[Path, str]] = []
52+
for path in load_paths:
53+
if path.endswith(".sql"):
54+
converted_load_paths.append(Path(path))
55+
else:
56+
converted_load_paths.append(path)
57+
return converted_load_paths

pytest_postgresql/factories/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# You should have received a copy of the GNU Lesser General Public License
1717
# along with pytest-dbfixtures. If not, see <http://www.gnu.org/licenses/>.
1818
"""Fixture factory for postgresql client."""
19+
from pathlib import Path
1920
from typing import Callable, Iterator, List, Optional, Union
2021

2122
import psycopg
@@ -31,7 +32,7 @@
3132
def postgresql(
3233
process_fixture_name: str,
3334
dbname: Optional[str] = None,
34-
load: Optional[List[Union[Callable, str]]] = None,
35+
load: Optional[List[Union[Callable, str, Path]]] = None,
3536
isolation_level: "Optional[psycopg.IsolationLevel]" = None,
3637
) -> Callable[[FixtureRequest], Iterator[Connection]]:
3738
"""Return connection fixture factory for PostgreSQL.

pytest_postgresql/factories/noprocess.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
# along with pytest-dbfixtures. If not, see <http://www.gnu.org/licenses/>.
1818
"""Fixture factory for existing postgresql server."""
1919
import os
20+
from pathlib import Path
2021
from typing import Callable, Iterator, List, Optional, Union
2122

2223
import pytest
@@ -42,7 +43,7 @@ def postgresql_noproc(
4243
password: Optional[str] = None,
4344
dbname: Optional[str] = None,
4445
options: str = "",
45-
load: Optional[List[Union[Callable, str]]] = None,
46+
load: Optional[List[Union[Callable, str, Path]]] = None,
4647
) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]:
4748
"""Postgresql noprocess factory.
4849

pytest_postgresql/factories/process.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os.path
2020
import platform
2121
import subprocess
22+
from pathlib import Path
2223
from typing import Callable, Iterator, List, Optional, Set, Tuple, Union
2324

2425
import pytest
@@ -54,7 +55,7 @@ def postgresql_proc(
5455
startparams: Optional[str] = None,
5556
unixsocketdir: Optional[str] = None,
5657
postgres_options: Optional[str] = None,
57-
load: Optional[List[Union[Callable, str]]] = None,
58+
load: Optional[List[Union[Callable, str, Path]]] = None,
5859
) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]:
5960
"""Postgresql process factory.
6061

pytest_postgresql/janitor.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
"""Database Janitor."""
22

3-
import re
43
from contextlib import contextmanager
5-
from functools import partial
4+
from pathlib import Path
65
from types import TracebackType
76
from typing import Callable, Iterator, Optional, Type, TypeVar, Union
87

98
import psycopg
109
from packaging.version import parse
1110
from psycopg import Connection, Cursor
1211

12+
from pytest_postgresql.loader import build_loader
1313
from pytest_postgresql.retry import retry
14-
from pytest_postgresql.sql import loader
1514

1615
Version = type(parse("1"))
1716

@@ -104,23 +103,17 @@ def _terminate_connection(cur: Cursor, dbname: str) -> None:
104103
(dbname,),
105104
)
106105

107-
def load(self, load: Union[Callable, str]) -> None:
106+
def load(self, load: Union[Callable, str, Path]) -> None:
108107
"""Load data into a database.
109108
110-
Either runs a passed loader if it's callback,
111-
or runs predefined loader if it's sql file.
109+
Expects:
110+
111+
* a Path to sql file, that'll be loaded
112+
* an import path to import callable
113+
* a callable that expects: host, port, user, dbname and password arguments.
114+
112115
"""
113-
if isinstance(load, str):
114-
if "/" in load:
115-
_loader: Callable = partial(loader, load)
116-
else:
117-
loader_parts = re.split("[.:]", load, 2)
118-
import_path = ".".join(loader_parts[:-1])
119-
loader_name = loader_parts[-1]
120-
_temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name])
121-
_loader = getattr(_temp_import, loader_name)
122-
else:
123-
_loader = load
116+
_loader = build_loader(load)
124117
_loader(
125118
host=self.host,
126119
port=self.port,

pytest_postgresql/loader.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Loader helper functions."""
2+
3+
import re
4+
from functools import partial
5+
from pathlib import Path
6+
from typing import Any, Callable, Union
7+
8+
import psycopg
9+
10+
11+
def build_loader(load: Union[Callable, str, Path]) -> Callable:
12+
"""Build a loader callable."""
13+
if isinstance(load, Path):
14+
return partial(sql, load)
15+
elif isinstance(load, str):
16+
loader_parts = re.split("[.:]", load, 2)
17+
import_path = ".".join(loader_parts[:-1])
18+
loader_name = loader_parts[-1]
19+
_temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name])
20+
_loader: Callable = getattr(_temp_import, loader_name)
21+
return _loader
22+
else:
23+
return load
24+
25+
26+
def sql(sql_filename: Path, **kwargs: Any) -> None:
27+
"""Database loader for sql files."""
28+
db_connection = psycopg.connect(**kwargs)
29+
with open(sql_filename, "r") as _fd:
30+
with db_connection.cursor() as cur:
31+
cur.execute(_fd.read())
32+
db_connection.commit()

pytest_postgresql/sql.py

Lines changed: 0 additions & 14 deletions
This file was deleted.

tests/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Tests main conftest file."""
22

33
import os
4+
from pathlib import Path
45

56
from pytest_postgresql import factories
67
from pytest_postgresql.plugin import * # noqa: F403,F401
@@ -10,18 +11,20 @@
1011

1112

1213
TEST_SQL_DIR = os.path.dirname(os.path.abspath(__file__)) + "/test_sql/"
14+
TEST_SQL_FILE = Path(TEST_SQL_DIR + "test.sql")
15+
TEST_SQL_FILE2 = Path(TEST_SQL_DIR + "test2.sql")
1316

1417
postgresql_proc2 = factories.postgresql_proc(port=None)
1518
postgresql2 = factories.postgresql("postgresql_proc2", dbname="test-db")
1619
postgresql_load_1 = factories.postgresql(
1720
"postgresql_proc2",
1821
dbname="test-load-db",
1922
load=[
20-
TEST_SQL_DIR + "test.sql",
23+
TEST_SQL_FILE,
2124
],
2225
)
2326
postgresql_load_2 = factories.postgresql(
2427
"postgresql_proc2",
2528
dbname="test-load-moredb",
26-
load=[TEST_SQL_DIR + "test.sql", TEST_SQL_DIR + "test2.sql"],
29+
load=[TEST_SQL_FILE, TEST_SQL_FILE2],
2730
)

0 commit comments

Comments
 (0)