6
6
7
7
8
8
import asyncio
9
+ from concurrent .futures ._base import TimeoutError
9
10
import functools
10
11
import inspect
11
12
import time
15
16
from . import exceptions
16
17
17
18
19
+ BAD_CONN_EXCEPTION = (
20
+ exceptions ._base .PostgresError ,
21
+ exceptions ._base .FatalPostgresError ,
22
+ exceptions ._base .UnknownPostgresError ,
23
+ TimeoutError ,
24
+ ConnectionRefusedError ,
25
+ )
26
+
27
+
18
28
class PoolConnectionProxyMeta (type ):
19
29
20
30
def __new__ (mcls , name , bases , dct , * , wrap = False ):
@@ -96,10 +106,12 @@ class PoolConnectionHolder:
96
106
'_connect_args' , '_connect_kwargs' ,
97
107
'_max_queries' , '_setup' , '_init' ,
98
108
'_max_inactive_time' , '_in_use' ,
99
- '_inactive_callback' , '_timeout' )
109
+ '_inactive_callback' , '_timeout' ,
110
+ '_max_consecutive_exceptions' , '_consecutive_exceptions' )
100
111
101
112
def __init__ (self , pool , * , connect_args , connect_kwargs ,
102
- max_queries , setup , init , max_inactive_time ):
113
+ max_queries , setup , init , max_inactive_time ,
114
+ max_consecutive_exceptions ):
103
115
104
116
self ._pool = pool
105
117
self ._con = None
@@ -108,6 +120,8 @@ def __init__(self, pool, *, connect_args, connect_kwargs,
108
120
self ._connect_kwargs = connect_kwargs
109
121
self ._max_queries = max_queries
110
122
self ._max_inactive_time = max_inactive_time
123
+ self ._max_consecutive_exceptions = max_consecutive_exceptions
124
+ self ._consecutive_exceptions = 0
111
125
self ._setup = setup
112
126
self ._init = init
113
127
self ._inactive_callback = None
@@ -259,6 +273,16 @@ def _deactivate_connection(self):
259
273
self ._con .terminate ()
260
274
self ._con = None
261
275
276
+ async def maybe_close_bad_connection (self , exc_type ):
277
+ if self ._max_consecutive_exceptions > 0 and \
278
+ isinstance (exc_type , BAD_CONN_EXCEPTION ):
279
+
280
+ self ._consecutive_exceptions += 1
281
+
282
+ if self ._consecutive_exceptions > self ._max_consecutive_exceptions :
283
+ await self .close ()
284
+ self ._consecutive_exceptions = 0
285
+
262
286
263
287
class Pool :
264
288
"""A connection pool.
@@ -285,6 +309,7 @@ def __init__(self, *connect_args,
285
309
init ,
286
310
loop ,
287
311
connection_class ,
312
+ max_consecutive_exceptions ,
288
313
** connect_kwargs ):
289
314
290
315
if loop is None :
@@ -331,6 +356,7 @@ def __init__(self, *connect_args,
331
356
connect_kwargs = connect_kwargs ,
332
357
max_queries = max_queries ,
333
358
max_inactive_time = max_inactive_connection_lifetime ,
359
+ max_consecutive_exceptions = max_consecutive_exceptions ,
334
360
setup = setup ,
335
361
init = init )
336
362
@@ -380,6 +406,7 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
380
406
async with self .acquire () as con :
381
407
return await con .execute (query , * args , timeout = timeout )
382
408
409
+
383
410
async def executemany (self , command : str , args , * , timeout : float = None ):
384
411
"""Execute an SQL *command* for each sequence of arguments in *args*.
385
412
@@ -459,7 +486,8 @@ async def _acquire_impl():
459
486
ch = await self ._queue .get () # type: PoolConnectionHolder
460
487
try :
461
488
proxy = await ch .acquire () # type: PoolConnectionProxy
462
- except Exception :
489
+ except Exception as e :
490
+ await ch .maybe_close_bad_connection (e )
463
491
self ._queue .put_nowait (ch )
464
492
raise
465
493
else :
@@ -580,6 +608,11 @@ async def __aexit__(self, *exc):
580
608
self .done = True
581
609
con = self .connection
582
610
self .connection = None
611
+ if not exc [0 ]:
612
+ con ._holder ._consecutive_exceptions = 0
613
+ else :
614
+ # Pass exception type to ConnectionHolder
615
+ await con ._holder .maybe_close_bad_connection (exc [0 ])
583
616
await self .pool .release (con )
584
617
585
618
def __await__ (self ):
@@ -592,6 +625,7 @@ def create_pool(dsn=None, *,
592
625
max_size = 10 ,
593
626
max_queries = 50000 ,
594
627
max_inactive_connection_lifetime = 300.0 ,
628
+ max_consecutive_exceptions = 0 ,
595
629
setup = None ,
596
630
init = None ,
597
631
loop = None ,
@@ -651,6 +685,12 @@ def create_pool(dsn=None, *,
651
685
Number of seconds after which inactive connections in the
652
686
pool will be closed. Pass ``0`` to disable this mechanism.
653
687
688
+ :param int max_consecutive_exceptions:
689
+ the maximum number of consecutive exceptions that may be raised by a
690
+ single connection before that connection is assumed corrupt (ex.
691
+ pointing to an old DB after a failover) and will therefore be closed.
692
+ Pass ``0`` to disable.
693
+
654
694
:param coroutine setup:
655
695
A coroutine to prepare a connection right before it is returned
656
696
from :meth:`Pool.acquire() <pool.Pool.acquire>`. An example use
@@ -699,4 +739,5 @@ def create_pool(dsn=None, *,
699
739
min_size = min_size , max_size = max_size ,
700
740
max_queries = max_queries , loop = loop , setup = setup , init = init ,
701
741
max_inactive_connection_lifetime = max_inactive_connection_lifetime ,
742
+ max_consecutive_exceptions = max_consecutive_exceptions ,
702
743
** connect_kwargs )
0 commit comments