@@ -119,6 +119,26 @@ def supports_bytes(self):
119119 return self .version_info () >= (3 , 2 )
120120
121121
122+ class ConnectionErrorHandler (object ):
123+ """ A handler for send and receive errors.
124+ """
125+
126+ def __init__ (self , handlers_by_error_class = None ):
127+ if handlers_by_error_class is None :
128+ handlers_by_error_class = {}
129+
130+ self .handlers_by_error_class = handlers_by_error_class
131+ self .known_errors = tuple (handlers_by_error_class .keys ())
132+
133+ def handle (self , error , address ):
134+ try :
135+ error_class = error .__class__
136+ handler = self .handlers_by_error_class [error_class ]
137+ handler (address )
138+ except KeyError :
139+ pass
140+
141+
122142class Connection (object ):
123143 """ Server connection for Bolt protocol v1.
124144
@@ -148,8 +168,10 @@ class Connection(object):
148168
149169 _last_run_statement = None
150170
151- def __init__ (self , sock , ** config ):
171+ def __init__ (self , address , sock , error_handler , ** config ):
172+ self .address = address
152173 self .socket = sock
174+ self .error_handler = error_handler
153175 self .server = ServerInfo (SocketAddress .from_socket (sock ))
154176 self .input_buffer = ChunkedInputBuffer ()
155177 self .output_buffer = ChunkedOutputBuffer ()
@@ -237,6 +259,13 @@ def reset(self):
237259 self .sync ()
238260
239261 def send (self ):
262+ try :
263+ self ._send ()
264+ except self .error_handler .known_errors as error :
265+ self .error_handler .handle (error , self .address )
266+ raise error
267+
268+ def _send (self ):
240269 """ Send all queued messages to the server.
241270 """
242271 data = self .output_buffer .view ()
@@ -250,6 +279,13 @@ def send(self):
250279 self .output_buffer .clear ()
251280
252281 def fetch (self ):
282+ try :
283+ return self ._fetch ()
284+ except self .error_handler .known_errors as error :
285+ self .error_handler .handle (error , self .address )
286+ raise error
287+
288+ def _fetch (self ):
253289 """ Receive at least one message from the server, if available.
254290
255291 :return: 2-tuple of number of detail messages and number of summary messages fetched
@@ -360,8 +396,9 @@ class ConnectionPool(object):
360396
361397 _closed = False
362398
363- def __init__ (self , connector ):
399+ def __init__ (self , connector , connection_error_handler ):
364400 self .connector = connector
401+ self .connection_error_handler = connection_error_handler
365402 self .connections = {}
366403 self .lock = RLock ()
367404
@@ -395,7 +432,7 @@ def acquire_direct(self, address):
395432 connection .in_use = True
396433 return connection
397434 try :
398- connection = self .connector (address )
435+ connection = self .connector (address , self . connection_error_handler )
399436 except ServiceUnavailable :
400437 self .remove (address )
401438 raise
@@ -457,7 +494,7 @@ def closed(self):
457494 return self ._closed
458495
459496
460- def connect (address , ssl_context = None , ** config ):
497+ def connect (address , ssl_context = None , error_handler = None , ** config ):
461498 """ Connect and perform a handshake and return a valid Connection object, assuming
462499 a protocol version can be agreed.
463500 """
@@ -563,7 +600,8 @@ def connect(address, ssl_context=None, **config):
563600 s .shutdown (SHUT_RDWR )
564601 s .close ()
565602 elif agreed_version == 1 :
566- return Connection (s , der_encoded_server_certificate = der_encoded_server_certificate , ** config )
603+ return Connection (address , s , der_encoded_server_certificate = der_encoded_server_certificate ,
604+ error_handler = error_handler , ** config )
567605 elif agreed_version == 0x48545450 :
568606 log_error ("S: [CLOSE]" )
569607 s .close ()
0 commit comments