2626from neo4j .bolt import ConnectionPool , ServiceUnavailable , ProtocolError , DEFAULT_PORT , connect
2727from neo4j .compat .collections import MutableSet , OrderedDict
2828from neo4j .exceptions import CypherError
29- from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS
29+ from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS , fix_statement , fix_parameters
3030from neo4j .v1 .exceptions import SessionExpired
3131from neo4j .v1 .security import SecurityPlan
3232from neo4j .v1 .session import BoltSession
33+ from neo4j .util import ServerVersion
3334
3435
3536class RoundRobinSet (MutableSet ):
@@ -131,11 +132,12 @@ def __init__(self, routers=(), readers=(), writers=(), ttl=0):
131132 self .last_updated_time = self .timer ()
132133 self .ttl = ttl
133134
134- def is_fresh (self ):
135+ def is_fresh (self , access_mode ):
135136 """ Indicator for whether routing information is still usable.
136137 """
137138 expired = self .last_updated_time + self .ttl <= self .timer ()
138- return not expired and len (self .routers ) > 1 and self .readers and self .writers
139+ has_server_for_mode = (access_mode == READ_ACCESS and self .readers ) or (access_mode == WRITE_ACCESS and self .writers )
140+ return not expired and self .routers and has_server_for_mode
139141
140142 def update (self , new_routing_table ):
141143 """ Update the current routing table with new routing information
@@ -148,16 +150,34 @@ def update(self, new_routing_table):
148150 self .ttl = new_routing_table .ttl
149151
150152
153+ class RoutingSession (BoltSession ):
154+
155+ call_get_servers = "CALL dbms.cluster.routing.getServers"
156+ get_routing_table_param = "context"
157+ call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({%s})" % get_routing_table_param
158+
159+ def routing_info_procedure (self , routing_context ):
160+ if ServerVersion .from_str (self ._connection .server .version ).at_least_version (3 , 2 ):
161+ return self .call_get_routing_table , {self .get_routing_table_param : routing_context }
162+ else :
163+ return self .call_get_servers , {}
164+
165+ def __run__ (self , ignored , routing_context ):
166+ # the statement is ignored as it will be get routing table procedure call.
167+ statement , parameters = self .routing_info_procedure (routing_context )
168+ return self ._run (fix_statement (statement ), fix_parameters (parameters ))
169+
170+
151171class RoutingConnectionPool (ConnectionPool ):
152172 """ Connection pool with routing table.
153173 """
154174
155- routing_info_procedure = "dbms.cluster.routing.getServers"
156-
157- def __init__ (self , connector , initial_address , * routers ):
175+ def __init__ (self , connector , initial_address , routing_context , * routers ):
158176 super (RoutingConnectionPool , self ).__init__ (connector )
159177 self .initial_address = initial_address
178+ self .routing_context = routing_context
160179 self .routing_table = RoutingTable (routers )
180+ self .missing_writer = False
161181 self .refresh_lock = Lock ()
162182
163183 def fetch_routing_info (self , address ):
@@ -170,8 +190,8 @@ def fetch_routing_info(self, address):
170190 if routing support is broken
171191 """
172192 try :
173- with BoltSession (lambda _ : self .acquire_direct (address )) as session :
174- return list (session .run ("CALL %s" % self .routing_info_procedure ))
193+ with RoutingSession (lambda _ : self .acquire_direct (address )) as session :
194+ return list (session .run ("ignored" , self .routing_context ))
175195 except CypherError as error :
176196 if error .code == "Neo.ClientError.Procedure.ProcedureNotFound" :
177197 raise ServiceUnavailable ("Server {!r} does not support routing" .format (address ))
@@ -200,6 +220,11 @@ def fetch_routing_table(self, address):
200220 num_readers = len (new_routing_table .readers )
201221 num_writers = len (new_routing_table .writers )
202222
223+ # No writers are available. This likely indicates a temporary state,
224+ # such as leader switching, so we should not signal an error.
225+ # When no writers available, then we flag we are reading in absence of writer
226+ self .missing_writer = (num_writers == 0 )
227+
203228 # No routers
204229 if num_routers == 0 :
205230 raise ProtocolError ("No routing servers returned from server %r" % (address ,))
@@ -208,12 +233,6 @@ def fetch_routing_table(self, address):
208233 if num_readers == 0 :
209234 raise ProtocolError ("No read servers returned from server %r" % (address ,))
210235
211- # No writers
212- if num_writers == 0 :
213- # No writers are available. This likely indicates a temporary state,
214- # such as leader switching, so we should not signal an error.
215- return None
216-
217236 # At least one of each is fine, so return this table
218237 return new_routing_table
219238
@@ -234,21 +253,30 @@ def update_routing_table(self):
234253 """
235254 # copied because it can be modified
236255 copy_of_routers = list (self .routing_table .routers )
256+
257+ has_tried_initial_routers = False
258+ if self .missing_writer :
259+ has_tried_initial_routers = True
260+ if self .update_routing_table_with_routers (resolve (self .initial_address )):
261+ return
262+
237263 if self .update_routing_table_with_routers (copy_of_routers ):
238264 return
239265
240- initial_routers = resolve (self .initial_address )
241- for router in copy_of_routers :
242- if router in initial_routers :
243- initial_routers .remove (router )
244- if initial_routers :
245- if self .update_routing_table_with_routers (initial_routers ):
246- return
266+ if not has_tried_initial_routers :
267+ initial_routers = resolve (self .initial_address )
268+ for router in copy_of_routers :
269+ if router in initial_routers :
270+ initial_routers .remove (router )
271+ if initial_routers :
272+ if self .update_routing_table_with_routers (initial_routers ):
273+ return
274+
247275
248276 # None of the routers have been successful, so just fail
249277 raise ServiceUnavailable ("Unable to retrieve routing information" )
250278
251- def refresh_routing_table (self ):
279+ def ensure_routing_table_is_fresh (self , access_mode ):
252280 """ Update the routing table if stale.
253281
254282 This method performs two freshness checks, before and after acquiring
@@ -261,10 +289,13 @@ def refresh_routing_table(self):
261289
262290 :return: `True` if an update was required, `False` otherwise.
263291 """
264- if self .routing_table .is_fresh ():
292+ if self .routing_table .is_fresh (access_mode ):
265293 return False
266294 with self .refresh_lock :
267- if self .routing_table .is_fresh ():
295+ if self .routing_table .is_fresh (access_mode ):
296+ if access_mode == READ_ACCESS :
297+ # if reader is fresh but writers is not fresh, then we are reading in absence of writer
298+ self .missing_writer = not self .routing_table .is_fresh (WRITE_ACCESS )
268299 return False
269300 self .update_routing_table ()
270301 return True
@@ -278,18 +309,20 @@ def acquire(self, access_mode=None):
278309 server_list = self .routing_table .writers
279310 else :
280311 raise ValueError ("Unsupported access mode {}" .format (access_mode ))
312+
313+ self .ensure_routing_table_is_fresh (access_mode )
281314 while True :
282- address = None
283- while address is None :
284- self .refresh_routing_table ()
285- address = next (server_list )
315+ address = next (server_list )
316+ if address is None :
317+ break
286318 try :
287319 connection = self .acquire_direct (address ) # should always be a resolved address
288320 connection .Error = SessionExpired
289321 except ServiceUnavailable :
290322 self .remove (address )
291323 else :
292324 return connection
325+ raise SessionExpired ("Failed to obtain connection towards '%s' server." % access_mode )
293326
294327 def remove (self , address ):
295328 """ Remove an address from the connection pool, if present, closing
@@ -313,6 +346,7 @@ def __init__(self, uri, **config):
313346 self .initial_address = initial_address = SocketAddress .from_uri (uri , DEFAULT_PORT )
314347 self .security_plan = security_plan = SecurityPlan .build (** config )
315348 self .encrypted = security_plan .encrypted
349+ routing_context = SocketAddress .parse_routing_context (uri )
316350 if not security_plan .routing_compatible :
317351 # this error message is case-specific as there is only one incompatible
318352 # scenario right now
@@ -321,7 +355,7 @@ def __init__(self, uri, **config):
321355 def connector (a ):
322356 return connect (a , security_plan .ssl_context , ** config )
323357
324- pool = RoutingConnectionPool (connector , initial_address , * resolve (initial_address ))
358+ pool = RoutingConnectionPool (connector , initial_address , routing_context , * resolve (initial_address ))
325359 try :
326360 pool .update_routing_table ()
327361 except :
0 commit comments