@@ -305,7 +305,7 @@ class Pool:
305305    """ 
306306
307307    __slots__  =  (
308-         '_queue' , '_loop' , '_minsize' , '_maxsize' ,
308+         '_queue' , '_loop' , '_minsize' , '_maxsize' ,  '_middlewares' , 
309309        '_init' , '_connect_args' , '_connect_kwargs' ,
310310        '_working_addr' , '_working_config' , '_working_params' ,
311311        '_holders' , '_initialized' , '_initializing' , '_closing' ,
@@ -320,6 +320,7 @@ def __init__(self, *connect_args,
320320                 max_inactive_connection_lifetime ,
321321                 setup ,
322322                 init ,
323+                  middlewares ,
323324                 loop ,
324325                 connection_class ,
325326                 ** connect_kwargs ):
@@ -377,6 +378,7 @@ def __init__(self, *connect_args,
377378        self ._closed  =  False 
378379        self ._generation  =  0 
379380        self ._init  =  init 
381+         self ._middlewares  =  middlewares 
380382        self ._connect_args  =  connect_args 
381383        self ._connect_kwargs  =  connect_kwargs 
382384
@@ -469,6 +471,7 @@ async def _get_new_connection(self):
469471                * self ._connect_args ,
470472                loop = self ._loop ,
471473                connection_class = self ._connection_class ,
474+                 middlewares = self ._middlewares ,
472475                ** self ._connect_kwargs )
473476
474477            self ._working_addr  =  con ._addr 
@@ -483,6 +486,7 @@ async def _get_new_connection(self):
483486                addr = self ._working_addr ,
484487                timeout = self ._working_params .connect_timeout ,
485488                config = self ._working_config ,
489+                 middlewares = self ._middlewares ,
486490                params = self ._working_params ,
487491                connection_class = self ._connection_class )
488492
@@ -784,13 +788,37 @@ def __await__(self):
784788        return  self .pool ._acquire (self .timeout ).__await__ ()
785789
786790
791+ def  middleware (f ):
792+     """Decorator for adding a middleware 
793+ 
794+     Can be used like such 
795+ 
796+     .. code-block:: python 
797+ 
798+         @pool.middleware 
799+         async def my_middleware(query, args, limit, 
800+                                 timeout, return_status, *, handler, conn): 
801+             print('do something before') 
802+             result, stmt = await handler(query, args, limit, 
803+                                          timeout, return_status) 
804+             print('do something after') 
805+             return result, stmt 
806+ 
807+         my_pool = await pool.create_pool(middlewares=[my_middleware]) 
808+     """ 
809+     async  def  middleware_factory (connection , handler ):
810+         return  functools .partial (f , connection = connection , handler = handler )
811+     return  middleware_factory 
812+ 
813+ 
787814def  create_pool (dsn = None , * ,
788815                min_size = 10 ,
789816                max_size = 10 ,
790817                max_queries = 50000 ,
791818                max_inactive_connection_lifetime = 300.0 ,
792819                setup = None ,
793820                init = None ,
821+                 middlewares = None ,
794822                loop = None ,
795823                connection_class = connection .Connection ,
796824                ** connect_kwargs ):
@@ -866,6 +894,19 @@ def create_pool(dsn=None, *,
866894        or :meth:`Connection.set_type_codec() <\ 
867895        asyncpg.connection.Connection.set_type_codec>`. 
868896
897+     :param middlewares: 
898+         A list of middleware functions to be middleware just 
899+         before a connection excecutes a statement. 
900+         Syntax of a middleware is as follows: 
901+         async def middleware_factory(connection, handler): 
902+             async def middleware(query, args, limit, timeout, return_status): 
903+                 print('do something before') 
904+                 result, stmt = await handler(query, args, limit, 
905+                                              timeout, return_status) 
906+                 print('do something after') 
907+                 return result, stmt 
908+             return middleware 
909+ 
869910    :param loop: 
870911        An asyncio event loop instance.  If ``None``, the default 
871912        event loop will be used. 
@@ -893,6 +934,7 @@ def create_pool(dsn=None, *,
893934        dsn ,
894935        connection_class = connection_class ,
895936        min_size = min_size , max_size = max_size ,
896-         max_queries = max_queries , loop = loop , setup = setup , init = init ,
937+         max_queries = max_queries , loop = loop , setup = setup ,
938+         middlewares = middlewares , init = init ,
897939        max_inactive_connection_lifetime = max_inactive_connection_lifetime ,
898940        ** connect_kwargs )
0 commit comments