2929from __future__ import division
3030
3131from base64 import b64encode
32- from collections import deque
32+ from collections import deque , namedtuple
3333from io import BytesIO
3434import logging
3535from os import makedirs , open as os_open , write as os_write , close as os_close , O_CREAT , O_APPEND , O_WRONLY
8181log_error = log .error
8282
8383
84+ Address = namedtuple ("Address" , ["host" , "port" ])
85+ ServerInfo = namedtuple ("ServerInfo" , ["address" , "version" ])
86+
87+
8488class BufferingSocket (object ):
8589
8690 def __init__ (self , connection ):
8791 self .connection = connection
8892 self .socket = connection .socket
89- self .address = self .socket .getpeername ()
93+ self .address = Address ( * self .socket .getpeername () )
9094 self .buffer = bytearray ()
9195
9296 def fill (self ):
@@ -132,7 +136,7 @@ class ChunkChannel(object):
132136
133137 def __init__ (self , sock ):
134138 self .socket = sock
135- self .address = sock .getpeername ()
139+ self .address = Address ( * sock .getpeername () )
136140 self .raw = BytesIO ()
137141 self .output_buffer = []
138142 self .output_size = 0
@@ -206,6 +210,22 @@ def on_ignored(self, metadata=None):
206210 pass
207211
208212
213+ class InitResponse (Response ):
214+
215+ def on_success (self , metadata ):
216+ super (InitResponse , self ).on_success (metadata )
217+ connection = self .connection
218+ address = Address (* connection .socket .getpeername ())
219+ version = metadata .get ("server" )
220+ connection .server = ServerInfo (address , version )
221+
222+ def on_failure (self , metadata ):
223+ code = metadata .get ("code" )
224+ error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else
225+ ServiceUnavailable )
226+ raise error (metadata .get ("message" , "INIT failed" ))
227+
228+
209229class Connection (object ):
210230 """ Server connection for Bolt protocol v1.
211231
@@ -227,10 +247,12 @@ class Connection(object):
227247 #: The pool of which this connection is a member
228248 pool = None
229249
250+ #: Server version details
251+ server = None
252+
230253 def __init__ (self , sock , ** config ):
231254 self .socket = sock
232255 self .buffering_socket = BufferingSocket (self )
233- self .address = sock .getpeername ()
234256 self .channel = ChunkChannel (sock )
235257 self .packer = Packer (self .channel )
236258 self .unpacker = Unpacker ()
@@ -251,19 +273,7 @@ def __init__(self, sock, **config):
251273 # Pick up the server certificate, if any
252274 self .der_encoded_server_certificate = config .get ("der_encoded_server_certificate" )
253275
254- def on_success (metadata ):
255- self .server_version = metadata .get ("server" )
256-
257- def on_failure (metadata ):
258- code = metadata .get ("code" )
259- error = (Unauthorized if code == "Neo.ClientError.Security.Unauthorized" else
260- ServiceUnavailable )
261- raise error (metadata .get ("message" , "INIT failed" ))
262-
263- response = Response (self )
264- response .on_success = on_success
265- response .on_failure = on_failure
266-
276+ response = InitResponse (self )
267277 self .append (INIT , (self .user_agent , self .auth_dict ), response = response )
268278 self .sync ()
269279
@@ -323,18 +333,18 @@ def send(self):
323333 """ Send all queued messages to the server.
324334 """
325335 if self .closed :
326- raise ServiceUnavailable ("Failed to write to closed connection %r" % (self .address ,))
336+ raise ServiceUnavailable ("Failed to write to closed connection %r" % (self .server . address ,))
327337 if self .defunct :
328- raise ServiceUnavailable ("Failed to write to defunct connection %r" % (self .address ,))
338+ raise ServiceUnavailable ("Failed to write to defunct connection %r" % (self .server . address ,))
329339 self .channel .send ()
330340
331341 def fetch (self ):
332342 """ Receive exactly one message from the server.
333343 """
334344 if self .closed :
335- raise ServiceUnavailable ("Failed to read from closed connection %r" % (self .address ,))
345+ raise ServiceUnavailable ("Failed to read from closed connection %r" % (self .server . address ,))
336346 if self .defunct :
337- raise ServiceUnavailable ("Failed to read from defunct connection %r" % (self .address ,))
347+ raise ServiceUnavailable ("Failed to read from defunct connection %r" % (self .server . address ,))
338348 try :
339349 message_data = self .buffering_socket .read_message ()
340350 except ProtocolError :
0 commit comments