@@ -26,8 +26,9 @@ def __init__(
2626 user : str ,
2727 host : str ,
2828 port : Union [str , int ],
29- dbname : str ,
3029 version : Union [str , float , Version ], # type: ignore[valid-type]
30+ dbname : Optional [str ] = None ,
31+ template_dbname : Optional [str ] = None ,
3132 password : Optional [str ] = None ,
3233 isolation_level : "Optional[psycopg.IsolationLevel]" = None ,
3334 connection_timeout : int = 60 ,
@@ -38,6 +39,7 @@ def __init__(
3839 :param host: postgresql host
3940 :param port: postgresql port
4041 :param dbname: database name
42+ :param dbname: template database name
4143 :param version: postgresql version number
4244 :param password: optional postgresql password
4345 :param isolation_level: optional postgresql isolation level
@@ -49,7 +51,10 @@ def __init__(
4951 self .password = password
5052 self .host = host
5153 self .port = port
54+ # At least one of the dbname or template_dbname has to be filled.
55+ assert any ([dbname , template_dbname ])
5256 self .dbname = dbname
57+ self .template_dbname = template_dbname
5358 self ._connection_timeout = connection_timeout
5459 self .isolation_level = isolation_level
5560 if not isinstance (version , Version ):
@@ -59,36 +64,33 @@ def __init__(
5964
6065 def init (self ) -> None :
6166 """Create database in postgresql."""
62- template_name = f"{ self .dbname } _tmpl"
6367 with self .cursor () as cur :
64- if self .dbname .endswith ("_tmpl" ):
65- result = False
66- else :
67- cur .execute (
68- "SELECT EXISTS "
69- "(SELECT datname FROM pg_catalog.pg_database WHERE datname= %s);" ,
70- (template_name ,),
71- )
72- row = cur .fetchone ()
73- result = (row is not None ) and row [0 ]
74- if not result :
68+ if self .is_template ():
69+ cur .execute (f'CREATE DATABASE "{ self .template_dbname } ";' )
70+ elif self .template_dbname is None :
7571 cur .execute (f'CREATE DATABASE "{ self .dbname } ";' )
7672 else :
7773 # All template database does not allow connection:
78- self ._dont_datallowconn (cur , template_name )
74+ self ._dont_datallowconn (cur , self . template_dbname )
7975 # And make sure no-one is left connected to the template database.
80- # Otherwise Creating database from template will fail
81- self ._terminate_connection (cur , template_name )
82- cur .execute (f'CREATE DATABASE "{ self .dbname } " TEMPLATE "{ template_name } ";' )
76+ # Otherwise, Creating database from template will fail
77+ self ._terminate_connection (cur , self .template_dbname )
78+ cur .execute (f'CREATE DATABASE "{ self .dbname } " TEMPLATE "{ self .template_dbname } ";' )
79+
80+ def is_template (self ) -> bool :
81+ """Determine whether the DatabaseJanitor maintains template or database."""
82+ return self .dbname is None
8383
8484 def drop (self ) -> None :
8585 """Drop database in postgresql."""
8686 # We cannot drop the database while there are connections to it, so we
8787 # terminate all connections first while not allowing new connections.
88+ db_to_drop = self .template_dbname if self .is_template () else self .dbname
89+ assert db_to_drop
8890 with self .cursor () as cur :
89- self ._dont_datallowconn (cur , self . dbname )
90- self ._terminate_connection (cur , self . dbname )
91- cur .execute (f'DROP DATABASE IF EXISTS "{ self . dbname } ";' )
91+ self ._dont_datallowconn (cur , db_to_drop )
92+ self ._terminate_connection (cur , db_to_drop )
93+ cur .execute (f'DROP DATABASE IF EXISTS "{ db_to_drop } ";' )
9294
9395 @staticmethod
9496 def _dont_datallowconn (cur : Cursor , dbname : str ) -> None :
@@ -113,12 +115,13 @@ def load(self, load: Union[Callable, str, Path]) -> None:
113115 * a callable that expects: host, port, user, dbname and password arguments.
114116
115117 """
118+ db_to_load = self .template_dbname if self .is_template () else self .dbname
116119 _loader = build_loader (load )
117120 _loader (
118121 host = self .host ,
119122 port = self .port ,
120123 user = self .user ,
121- dbname = self . dbname ,
124+ dbname = db_to_load ,
122125 password = self .password ,
123126 )
124127
0 commit comments