66
77
88import  asyncio 
9+ from  concurrent .futures ._base  import  TimeoutError 
910import  functools 
1011import  inspect 
1112import  time 
1516from  . import  exceptions 
1617
1718
19+ BAD_CONN_EXCEPTION  =  (
20+     exceptions ._base .PostgresError ,
21+     exceptions ._base .FatalPostgresError ,
22+     exceptions ._base .UnknownPostgresError ,
23+     TimeoutError ,
24+     ConnectionRefusedError ,
25+ )
26+ 
27+ 
1828class  PoolConnectionProxyMeta (type ):
1929
2030    def  __new__ (mcls , name , bases , dct , * , wrap = False ):
@@ -96,10 +106,12 @@ class PoolConnectionHolder:
96106                 '_connect_args' , '_connect_kwargs' ,
97107                 '_max_queries' , '_setup' , '_init' ,
98108                 '_max_inactive_time' , '_in_use' ,
99-                  '_inactive_callback' , '_timeout' )
109+                  '_inactive_callback' , '_timeout' ,
110+                  '_max_consecutive_exceptions' , '_consecutive_exceptions' )
100111
101112    def  __init__ (self , pool , * , connect_args , connect_kwargs ,
102-                  max_queries , setup , init , max_inactive_time ):
113+                  max_queries , setup , init , max_inactive_time ,
114+                  max_consecutive_exceptions ):
103115
104116        self ._pool  =  pool 
105117        self ._con  =  None 
@@ -108,6 +120,8 @@ def __init__(self, pool, *, connect_args, connect_kwargs,
108120        self ._connect_kwargs  =  connect_kwargs 
109121        self ._max_queries  =  max_queries 
110122        self ._max_inactive_time  =  max_inactive_time 
123+         self ._max_consecutive_exceptions  =  max_consecutive_exceptions 
124+         self ._consecutive_exceptions  =  0 
111125        self ._setup  =  setup 
112126        self ._init  =  init 
113127        self ._inactive_callback  =  None 
@@ -259,6 +273,16 @@ def _deactivate_connection(self):
259273        self ._con .terminate ()
260274        self ._con  =  None 
261275
276+     async  def  maybe_close_bad_connection (self , exc_type ):
277+         if  self ._max_consecutive_exceptions  >  0  and  \
278+                 isinstance (exc_type , BAD_CONN_EXCEPTION ):
279+ 
280+             self ._consecutive_exceptions  +=  1 
281+ 
282+             if  self ._consecutive_exceptions  >  self ._max_consecutive_exceptions :
283+                 await  self .close ()
284+                 self ._consecutive_exceptions  =  0 
285+ 
262286
263287class  Pool :
264288    """A connection pool. 
@@ -285,6 +309,7 @@ def __init__(self, *connect_args,
285309                 init ,
286310                 loop ,
287311                 connection_class ,
312+                  max_consecutive_exceptions ,
288313                 ** connect_kwargs ):
289314
290315        if  loop  is  None :
@@ -331,6 +356,7 @@ def __init__(self, *connect_args,
331356                connect_kwargs = connect_kwargs ,
332357                max_queries = max_queries ,
333358                max_inactive_time = max_inactive_connection_lifetime ,
359+                 max_consecutive_exceptions = max_consecutive_exceptions ,
334360                setup = setup ,
335361                init = init )
336362
@@ -459,7 +485,8 @@ async def _acquire_impl():
459485            ch  =  await  self ._queue .get ()  # type: PoolConnectionHolder 
460486            try :
461487                proxy  =  await  ch .acquire ()  # type: PoolConnectionProxy 
462-             except  Exception :
488+             except  Exception  as  e :
489+                 await  ch .maybe_close_bad_connection (e )
463490                self ._queue .put_nowait (ch )
464491                raise 
465492            else :
@@ -580,6 +607,11 @@ async def __aexit__(self, *exc):
580607        self .done  =  True 
581608        con  =  self .connection 
582609        self .connection  =  None 
610+         if  not  exc [0 ]:
611+             con ._holder ._consecutive_exceptions  =  0 
612+         else :
613+             # Pass exception type to ConnectionHolder 
614+             await  con ._holder .maybe_close_bad_connection (exc [0 ])
583615        await  self .pool .release (con )
584616
585617    def  __await__ (self ):
@@ -592,6 +624,7 @@ def create_pool(dsn=None, *,
592624                max_size = 10 ,
593625                max_queries = 50000 ,
594626                max_inactive_connection_lifetime = 300.0 ,
627+                 max_consecutive_exceptions = 0 ,
595628                setup = None ,
596629                init = None ,
597630                loop = None ,
@@ -651,6 +684,12 @@ def create_pool(dsn=None, *,
651684        Number of seconds after which inactive connections in the 
652685        pool will be closed.  Pass ``0`` to disable this mechanism. 
653686
687+     :param int max_consecutive_exceptions: 
688+         the maximum number of consecutive exceptions that may be raised by a 
689+         single connection before that connection is assumed corrupt (ex. 
690+         pointing to an old DB after a failover) and will therefore be closed. 
691+         Pass ``0`` to disable. 
692+ 
654693    :param coroutine setup: 
655694        A coroutine to prepare a connection right before it is returned 
656695        from :meth:`Pool.acquire() <pool.Pool.acquire>`.  An example use 
@@ -699,4 +738,5 @@ def create_pool(dsn=None, *,
699738        min_size = min_size , max_size = max_size ,
700739        max_queries = max_queries , loop = loop , setup = setup , init = init ,
701740        max_inactive_connection_lifetime = max_inactive_connection_lifetime ,
741+         max_consecutive_exceptions = max_consecutive_exceptions ,
702742        ** connect_kwargs )
0 commit comments