3232from select import select
3333from socket import socket , SOL_SOCKET , SO_KEEPALIVE , SHUT_RDWR , error as SocketError , timeout as SocketTimeout , AF_INET , AF_INET6
3434from struct import pack as struct_pack , unpack as struct_unpack
35- from threading import RLock
35+ from threading import RLock , Condition
3636
3737from neo4j .addressing import SocketAddress , is_ip_address
3838from neo4j .bolt .cert import KNOWN_HOSTS
3939from neo4j .bolt .response import InitResponse , AckFailureResponse , ResetResponse
4040from neo4j .compat .ssl import SSL_AVAILABLE , HAS_SNI , SSLError
41- from neo4j .exceptions import ProtocolError , SecurityError , ServiceUnavailable
41+ from neo4j .exceptions import ClientError , ProtocolError , SecurityError , ServiceUnavailable
4242from neo4j .meta import version
4343from neo4j .packstream import Packer , Unpacker
4444from neo4j .util import import_best as _import_best
45+ from time import clock
4546
4647ChunkedInputBuffer = _import_best ("neo4j.bolt._io" , "neo4j.bolt.io" ).ChunkedInputBuffer
4748ChunkedOutputBuffer = _import_best ("neo4j.bolt._io" , "neo4j.bolt.io" ).ChunkedOutputBuffer
4849
4950
51+ INFINITE = - 1
52+ DEFAULT_MAX_CONNECTION_LIFETIME = INFINITE
53+ DEFAULT_MAX_CONNECTION_POOL_SIZE = INFINITE
5054DEFAULT_CONNECTION_TIMEOUT = 5.0
55+ DEFAULT_CONNECTION_ACQUISITION_TIMEOUT = 60
5156DEFAULT_PORT = 7687
5257DEFAULT_USER_AGENT = "neo4j-python/%s" % version
5358
@@ -178,6 +183,8 @@ def __init__(self, address, sock, error_handler, **config):
178183 self .packer = Packer (self .output_buffer )
179184 self .unpacker = Unpacker ()
180185 self .responses = deque ()
186+ self ._max_connection_lifetime = config .get ("max_connection_lifetime" , DEFAULT_MAX_CONNECTION_LIFETIME )
187+ self ._creation_timestamp = clock ()
181188
182189 # Determine the user agent and ensure it is a Unicode value
183190 user_agent = config .get ("user_agent" , DEFAULT_USER_AGENT )
@@ -201,6 +208,7 @@ def __init__(self, address, sock, error_handler, **config):
201208 # Pick up the server certificate, if any
202209 self .der_encoded_server_certificate = config .get ("der_encoded_server_certificate" )
203210
211+ def Init (self ):
204212 response = InitResponse (self )
205213 self .append (INIT , (self .user_agent , self .auth_dict ), response = response )
206214 self .sync ()
@@ -360,6 +368,9 @@ def _unpack(self):
360368 more = False
361369 return details , summary_signature , summary_metadata
362370
371+ def timedout (self ):
372+ return 0 <= self ._max_connection_lifetime <= clock () - self ._creation_timestamp
373+
363374 def sync (self ):
364375 """ Send and fetch all outstanding messages.
365376
@@ -396,11 +407,14 @@ class ConnectionPool(object):
396407
397408 _closed = False
398409
399- def __init__ (self , connector , connection_error_handler ):
410+ def __init__ (self , connector , connection_error_handler , ** config ):
400411 self .connector = connector
401412 self .connection_error_handler = connection_error_handler
402413 self .connections = {}
403414 self .lock = RLock ()
415+ self .cond = Condition (self .lock )
416+ self ._max_connection_pool_size = config .get ("max_connection_pool_size" , DEFAULT_MAX_CONNECTION_POOL_SIZE )
417+ self ._connection_acquisition_timeout = config .get ("connection_acquisition_timeout" , DEFAULT_CONNECTION_ACQUISITION_TIMEOUT )
404418
405419 def __enter__ (self ):
406420 return self
@@ -424,23 +438,42 @@ def acquire_direct(self, address):
424438 connections = self .connections [address ]
425439 except KeyError :
426440 connections = self .connections [address ] = deque ()
427- for connection in list (connections ):
428- if connection .closed () or connection .defunct ():
429- connections .remove (connection )
430- continue
431- if not connection .in_use :
432- connection .in_use = True
433- return connection
434- try :
435- connection = self .connector (address , self .connection_error_handler )
436- except ServiceUnavailable :
437- self .remove (address )
438- raise
439- else :
440- connection .pool = self
441- connection .in_use = True
442- connections .append (connection )
443- return connection
441+
442+ connection_acquisition_start_timestamp = clock ()
443+ while True :
444+ # try to find a free connection in pool
445+ for connection in list (connections ):
446+ if connection .closed () or connection .defunct () or connection .timedout ():
447+ connections .remove (connection )
448+ continue
449+ if not connection .in_use :
450+ connection .in_use = True
451+ return connection
452+ # all connections in pool are in-use
453+ can_create_new_connection = self ._max_connection_pool_size == INFINITE or len (connections ) < self ._max_connection_pool_size
454+ if can_create_new_connection :
455+ try :
456+ connection = self .connector (address , self .connection_error_handler )
457+ except ServiceUnavailable :
458+ self .remove (address )
459+ raise
460+ else :
461+ connection .pool = self
462+ connection .in_use = True
463+ connections .append (connection )
464+ return connection
465+
466+ # failed to obtain a connection from pool because the pool is full and no free connection in the pool
467+ span_timeout = self ._connection_acquisition_timeout - (clock () - connection_acquisition_start_timestamp )
468+ if span_timeout > 0 :
469+ self .cond .wait (span_timeout )
470+ # if timed out, then we throw error. This time computation is needed, as with python 2.7, we cannot
471+ # tell if the condition is notified or timed out when we come to this line
472+ if self ._connection_acquisition_timeout <= (clock () - connection_acquisition_start_timestamp ):
473+ raise ClientError ("Failed to obtain a connection from pool within {!r}s" .format (
474+ self ._connection_acquisition_timeout ))
475+ else :
476+ raise ClientError ("Failed to obtain a connection from pool within {!r}s" .format (self ._connection_acquisition_timeout ))
444477
445478 def acquire (self , access_mode = None ):
446479 """ Acquire a connection to a server that can satisfy a set of parameters.
@@ -454,6 +487,7 @@ def release(self, connection):
454487 """
455488 with self .lock :
456489 connection .in_use = False
490+ self .cond .notify_all ()
457491
458492 def in_use_connection_count (self , address ):
459493 """ Count the number of connections currently in use to a given
@@ -600,8 +634,10 @@ def connect(address, ssl_context=None, error_handler=None, **config):
600634 s .shutdown (SHUT_RDWR )
601635 s .close ()
602636 elif agreed_version == 1 :
603- return Connection (address , s , der_encoded_server_certificate = der_encoded_server_certificate ,
637+ connection = Connection (address , s , der_encoded_server_certificate = der_encoded_server_certificate ,
604638 error_handler = error_handler , ** config )
639+ connection .Init ()
640+ return connection
605641 elif agreed_version == 0x48545450 :
606642 log_error ("S: [CLOSE]" )
607643 s .close ()
0 commit comments