3939
4040
4141if  t .TYPE_CHECKING :
42+     from  ssl  import  SSLContext 
43+ 
44+     import  typing_extensions  as  te 
45+ 
4246    from  ..._deadline  import  Deadline 
47+     from  ...addressing  import  (
48+         Address ,
49+         ResolvedAddress ,
50+     )
4351
4452
4553log  =  logging .getLogger ("neo4j.io" )
@@ -63,7 +71,11 @@ def __str__(self):
6371
6472
6573class  AsyncBoltSocket (AsyncBoltSocketBase ):
66-     async  def  _parse_handshake_response_v1 (self , ctx , response ):
74+     async  def  _parse_handshake_response_v1 (
75+         self ,
76+         ctx : HandshakeCtx ,
77+         response : bytes ,
78+     ) ->  tuple [int , int ]:
6779        agreed_version  =  response [- 1 ], response [- 2 ]
6880        log .debug (
6981            "[#%04X]  S: <HANDSHAKE> 0x%06X%02X" ,
@@ -73,7 +85,11 @@ async def _parse_handshake_response_v1(self, ctx, response):
7385        )
7486        return  agreed_version 
7587
76-     async  def  _parse_handshake_response_v2 (self , ctx , response ):
88+     async  def  _parse_handshake_response_v2 (
89+         self ,
90+         ctx : HandshakeCtx ,
91+         response : bytes ,
92+     ) ->  tuple [int , int ]:
7793        ctx .ctx  =  "handshake v2 offerings count" 
7894        num_offerings  =  await  self ._read_varint (ctx )
7995        offerings  =  []
@@ -85,7 +101,7 @@ async def _parse_handshake_response_v2(self, ctx, response):
85101        ctx .ctx  =  "handshake v2 capabilities" 
86102        _capabilities_offer  =  await  self ._read_varint (ctx )
87103
88-         if  log .getEffectiveLevel () > =  logging .DEBUG :
104+         if  log .getEffectiveLevel () < =  logging .DEBUG :
89105            log .debug (
90106                "[#%04X]  S: <HANDSHAKE> %s [%i] %s %s" ,
91107                ctx .local_port ,
@@ -125,7 +141,7 @@ async def _parse_handshake_response_v2(self, ctx, response):
125141
126142        return  chosen_version 
127143
128-     async  def  _read_varint (self , ctx ) :
144+     async  def  _read_varint (self , ctx :  HandshakeCtx )  ->   int :
129145        next_byte  =  (await  self ._handshake_read (ctx , 1 ))[0 ]
130146        res  =  next_byte  &  0x7F 
131147        i  =  0 
@@ -136,15 +152,15 @@ async def _read_varint(self, ctx):
136152        return  res 
137153
138154    @staticmethod  
139-     def  _encode_varint (n ) :
155+     def  _encode_varint (n :  int )  ->   bytearray :
140156        res  =  bytearray ()
141157        while  n  >=  0x80 :
142158            res .append (n  &  0x7F  |  0x80 )
143159            n  >>=  7 
144160        res .append (n )
145161        return  res 
146162
147-     async  def  _handshake_read (self , ctx , n ) :
163+     async  def  _handshake_read (self , ctx :  HandshakeCtx , n :  int )  ->   bytes :
148164        original_timeout  =  self .gettimeout ()
149165        self .settimeout (ctx .deadline .to_timeout ())
150166        try :
@@ -193,7 +209,11 @@ async def _handshake_send(self, ctx, data):
193209        finally :
194210            self .settimeout (original_timeout )
195211
196-     async  def  _handshake (self , resolved_address , deadline ):
212+     async  def  _handshake (
213+         self ,
214+         resolved_address : ResolvedAddress ,
215+         deadline : Deadline ,
216+     ) ->  tuple [tuple [int , int ], bytes , bytes ]:
197217        """ 
198218        Perform BOLT handshake. 
199219
@@ -204,16 +224,16 @@ async def _handshake(self, resolved_address, deadline):
204224        """ 
205225        local_port  =  self .getsockname ()[1 ]
206226
207-         if   log . getEffectiveLevel ()  >=   logging . DEBUG : 
208-              handshake   =   self . Bolt . get_handshake () 
209-             handshake  =  struct .unpack (">16B" , handshake )
210-             handshake  =  [
211-                 handshake [i  : i  +  4 ] for  i  in  range (0 , len (handshake ), 4 )
227+         handshake   =   self . Bolt . get_handshake () 
228+         if   log . getEffectiveLevel ()  <=   logging . DEBUG : 
229+             handshake_bytes :  t . Sequence  =  struct .unpack (">16B" , handshake )
230+             handshake_bytes  =  [
231+                 handshake [i  : i  +  4 ] for  i  in  range (0 , len (handshake_bytes ), 4 )
212232            ]
213233
214234            supported_versions  =  [
215235                f"0x{ vx [0 ]:02X} { vx [1 ]:02X} { vx [2 ]:02X} { vx [3 ]:02X}  " 
216-                 for  vx  in  handshake 
236+                 for  vx  in  handshake_bytes 
217237            ]
218238
219239            log .debug (
@@ -227,7 +247,7 @@ async def _handshake(self, resolved_address, deadline):
227247                * supported_versions ,
228248            )
229249
230-         request  =  self .Bolt .MAGIC_PREAMBLE  +  self . Bolt . get_handshake () 
250+         request  =  self .Bolt .MAGIC_PREAMBLE  +  handshake 
231251
232252        ctx  =  HandshakeCtx (
233253            ctx = "handshake opening" ,
@@ -273,14 +293,14 @@ async def _handshake(self, resolved_address, deadline):
273293    @classmethod  
274294    async  def  connect (
275295        cls ,
276-         address ,
296+         address :  Address ,
277297        * ,
278-         tcp_timeout ,
279-         deadline ,
280-         custom_resolver ,
281-         ssl_context ,
282-         keep_alive ,
283-     ):
298+         tcp_timeout :  float   |   None ,
299+         deadline :  Deadline ,
300+         custom_resolver :  t . Callable   |   None ,
301+         ssl_context :  SSLContext   |   None ,
302+         keep_alive :  bool ,
303+     )  ->   tuple [ te . Self ,  tuple [ int ,  int ],  bytes ,  bytes ] :
284304        """ 
285305        Connect and perform a handshake. 
286306
@@ -313,10 +333,10 @@ async def connect(
313333                )
314334                return  s , agreed_version , handshake , response 
315335            except  (BoltError , DriverError , OSError ) as  error :
316-                 try : 
317-                      local_port   =   s . getsockname ()[ 1 ] 
318-                 except   (OSError , AttributeError , TypeError ):
319-                     local_port  =  0 
336+                 local_port   =   0 
337+                 if   isinstance ( s ,  cls ): 
338+                      with   suppress (OSError , AttributeError , TypeError ):
339+                          local_port  =  s . getsockname ()[ 1 ] 
320340                err_str  =  error .__class__ .__name__ 
321341                if  str (error ):
322342                    err_str  +=  ": "  +  str (error )
@@ -331,10 +351,10 @@ async def connect(
331351                errors .append (error )
332352                failed_addresses .append (resolved_address )
333353            except  asyncio .CancelledError :
334-                 try : 
335-                      local_port   =   s . getsockname ()[ 1 ] 
336-                 except   (OSError , AttributeError , TypeError ):
337-                     local_port  =  0 
354+                 local_port   =   0 
355+                 if   isinstance ( s ,  cls ): 
356+                      with   suppress (OSError , AttributeError , TypeError ):
357+                          local_port  =  s . getsockname ()[ 1 ] 
338358                log .debug (
339359                    "[#%04X]  C: <CANCELED> %s" , local_port , resolved_address 
340360                )
0 commit comments