@@ -131,9 +131,10 @@ def get_status(self) -> str:
131
131
async def connect (self ,
132
132
loop : typing .Optional [asyncio .AbstractEventLoop ] = None ,
133
133
** kwargs : typing .Any ) -> 'connection.Connection' :
134
- conn_info : typing .Optional [typing .Any ] = self .get_connection_spec ()
135
- conn_info .update (kwargs ) # type: ignore[union-attr]
136
- return await asyncpg .connect (loop = loop , ** conn_info ) # type: ignore[misc] # noqa: E501
134
+ conn_info = typing .cast (typing .Dict [str , typing .Any ],
135
+ self .get_connection_spec ())
136
+ conn_info .update (kwargs )
137
+ return await asyncpg .connect (loop = loop , ** conn_info )
137
138
138
139
def init (self , ** settings : str ) -> str :
139
140
"""Initialize cluster."""
@@ -307,12 +308,17 @@ def _get_connection_spec(self) -> typing.Optional[_ConnectionSpec]:
307
308
308
309
return None
309
310
310
- def get_connection_spec (self ) -> typing . Optional [ _ConnectionSpec ] :
311
+ def get_connection_spec (self ) -> _ConnectionSpec :
311
312
status = self .get_status ()
312
313
if status != 'running' :
313
314
raise ClusterError ('cluster is not running' )
314
315
315
- return self ._get_connection_spec ()
316
+ spec = self ._get_connection_spec ()
317
+
318
+ if spec is None :
319
+ raise ClusterError ('cannot determine server connection address' )
320
+
321
+ return spec
316
322
317
323
def override_connection_spec (self , ** kwargs : str ) -> None :
318
324
self ._connection_spec_override = typing .cast (_ConnectionSpec , kwargs )
0 commit comments