2626from neo4j .bolt import ConnectionPool , ServiceUnavailable , ProtocolError , DEFAULT_PORT , connect
2727from neo4j .compat .collections import MutableSet , OrderedDict
2828from neo4j .exceptions import CypherError
29- from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS
29+ from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS , fix_statement , fix_parameters
3030from neo4j .v1 .exceptions import SessionExpired
3131from neo4j .v1 .security import SecurityPlan
3232from neo4j .v1 .session import BoltSession
@@ -150,14 +150,28 @@ def update(self, new_routing_table):
150150 self .ttl = new_routing_table .ttl
151151
152152
153- class RoutingConnectionPool (ConnectionPool ):
154- """ Connection pool with routing table.
155- """
153+ class RoutingSession (BoltSession ):
156154
157155 call_get_servers = "CALL dbms.cluster.routing.getServers"
158156 get_routing_table_param = "context"
159157 call_get_routing_table = "CALL dbms.cluster.routing.getRoutingTable({%s})" % get_routing_table_param
160158
159+ def routing_info_procedure (self , routing_context ):
160+ if ServerVersion .from_str (self ._connection .server .version ).at_least_version (3 , 2 ):
161+ return self .call_get_routing_table , {self .get_routing_table_param : routing_context }
162+ else :
163+ return self .call_get_servers , {}
164+
165+ def __run__ (self , ignored , routing_context ):
166+ # the statement is ignored as it will be get routing table procedure call.
167+ statement , parameters = self .routing_info_procedure (routing_context )
168+ return self ._run (fix_statement (statement ), fix_parameters (parameters ))
169+
170+
171+ class RoutingConnectionPool (ConnectionPool ):
172+ """ Connection pool with routing table.
173+ """
174+
161175 def __init__ (self , connector , initial_address , routing_context , * routers ):
162176 super (RoutingConnectionPool , self ).__init__ (connector )
163177 self .initial_address = initial_address
@@ -166,12 +180,6 @@ def __init__(self, connector, initial_address, routing_context, *routers):
166180 self .missing_writer = False
167181 self .refresh_lock = Lock ()
168182
169- def routing_info_procedure (self , connection ):
170- if ServerVersion .from_str (connection .server .version ).at_least_version (3 , 2 ):
171- return self .call_get_routing_table , {self .get_routing_table_param : self .routing_context }
172- else :
173- return self .call_get_servers , {}
174-
175183 def fetch_routing_info (self , address ):
176184 """ Fetch raw routing info from a given router address.
177185
@@ -182,15 +190,8 @@ def fetch_routing_info(self, address):
182190 if routing support is broken
183191 """
184192 try :
185- connections = [None ]
186-
187- def connector (_ ):
188- connection = self .acquire_direct (address )
189- connections [0 ] = connection
190- return connection
191-
192- with BoltSession (lambda _ : connector ) as session :
193- return list (session .run (* self .routing_info_procedure (connections [0 ])))
193+ with RoutingSession (lambda _ : self .acquire_direct (address )) as session :
194+ return list (session .run ("ignored" , self .routing_context ))
194195 except CypherError as error :
195196 if error .code == "Neo.ClientError.Procedure.ProcedureNotFound" :
196197 raise ServiceUnavailable ("Server {!r} does not support routing" .format (address ))
0 commit comments