Skip to content

Commit e41b65b

Browse files
committed
Add new more specific error types
To avoid comparing strings when handling routing errors.
1 parent b666884 commit e41b65b

File tree

2 files changed

+52
-17
lines changed

2 files changed

+52
-17
lines changed

neo4j/exceptions.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,9 @@ def hydrate(cls, message=None, code=None, **metadata):
6565
classification = "DatabaseError"
6666
category = "General"
6767
title = "UnknownError"
68-
if classification == "ClientError":
69-
try:
70-
error_class = client_errors[code]
71-
except KeyError:
72-
error_class = ClientError
73-
elif classification == "DatabaseError":
74-
error_class = DatabaseError
75-
elif classification == "TransientError":
76-
error_class = TransientError
77-
else:
78-
error_class = cls
68+
69+
error_class = cls._extract_error_class(classification, code)
70+
7971
inst = error_class(message)
8072
inst.message = message
8173
inst.code = code
@@ -85,6 +77,26 @@ def hydrate(cls, message=None, code=None, **metadata):
8577
inst.metadata = metadata
8678
return inst
8779

80+
@classmethod
81+
def _extract_error_class(cls, classification, code):
82+
if classification == "ClientError":
83+
try:
84+
return client_errors[code]
85+
except KeyError:
86+
return ClientError
87+
88+
elif classification == "TransientError":
89+
try:
90+
return transient_errors[code]
91+
except KeyError:
92+
return TransientError
93+
94+
elif classification == "DatabaseError":
95+
return DatabaseError
96+
97+
else:
98+
return cls
99+
88100

89101
class ClientError(CypherError):
90102
""" The Client sent a bad request - changing the request might yield a successful outcome.
@@ -101,6 +113,11 @@ class TransientError(CypherError):
101113
"""
102114

103115

116+
class DatabaseUnavailableError(TransientError):
117+
"""
118+
"""
119+
120+
104121
class ConstraintError(ClientError):
105122
"""
106123
"""
@@ -116,11 +133,21 @@ class CypherTypeError(ClientError):
116133
"""
117134

118135

136+
class NotALeaderError(ClientError):
137+
"""
138+
"""
139+
140+
119141
class Forbidden(ClientError, SecurityError):
120142
"""
121143
"""
122144

123145

146+
class ForbiddenOnReadOnlyDatabaseError(Forbidden):
147+
"""
148+
"""
149+
150+
124151
class AuthError(ClientError, SecurityError):
125152
""" Raised when authentication failure occurs.
126153
"""
@@ -144,7 +171,7 @@ class AuthError(ClientError, SecurityError):
144171
"Neo.ClientError.Statement.TypeError": CypherTypeError,
145172

146173
# Forbidden
147-
"Neo.ClientError.General.ForbiddenOnReadOnlyDatabase": Forbidden,
174+
"Neo.ClientError.General.ForbiddenOnReadOnlyDatabase": ForbiddenOnReadOnlyDatabaseError,
148175
"Neo.ClientError.General.ReadOnly": Forbidden,
149176
"Neo.ClientError.Schema.ForbiddenOnConstraintIndex": Forbidden,
150177
"Neo.ClientError.Schema.IndexBelongsToConstraint": Forbidden,
@@ -155,4 +182,12 @@ class AuthError(ClientError, SecurityError):
155182
"Neo.ClientError.Security.AuthorizationFailed": AuthError,
156183
"Neo.ClientError.Security.Unauthorized": AuthError,
157184

185+
# NotALeaderError
186+
"Neo.ClientError.Cluster.NotALeader": NotALeaderError
187+
}
188+
189+
transient_errors = {
190+
191+
# DatabaseUnavailableError
192+
"Neo.TransientError.General.DatabaseUnavailable": DatabaseUnavailableError
158193
}

neo4j/v1/routing.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from neo4j.addressing import SocketAddress, resolve
2626
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect
2727
from neo4j.compat.collections import MutableSet, OrderedDict
28-
from neo4j.exceptions import CypherError
28+
from neo4j.exceptions import CypherError, DatabaseUnavailableError, NotALeaderError, ForbiddenOnReadOnlyDatabaseError
2929
from neo4j.util import ServerVersion
3030
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters
3131
from neo4j.v1.exceptions import SessionExpired
@@ -251,8 +251,8 @@ class RoutingConnectionPool(ConnectionPool):
251251
""" Connection pool with routing table.
252252
"""
253253

254-
FAILURE_CODES = ("Neo.TransientError.General.DatabaseUnavailable")
255-
WRITE_FAILURE_CODES = ("Neo.ClientError.Cluster.NotALeader", "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase")
254+
CLUSTER_MEMBER_FAILURE_ERRORS = (ServiceUnavailable, SessionExpired, DatabaseUnavailableError)
255+
WRITE_FAILURE_ERRORS = (NotALeaderError, ForbiddenOnReadOnlyDatabaseError)
256256

257257
def __init__(self, connector, initial_address, routing_context, *routers, **config):
258258
super(RoutingConnectionPool, self).__init__(connector)
@@ -412,9 +412,9 @@ def acquire(self, access_mode=None):
412412
def _handle_connection_error(self, address, error):
413413
""" Handle routing connection send or receive error.
414414
"""
415-
if isinstance(error, (SessionExpired, ServiceUnavailable)) or error.code in self.FAILURE_CODES:
415+
if isinstance(error, self.CLUSTER_MEMBER_FAILURE_ERRORS):
416416
self.remove(address)
417-
elif error.code in self.WRITE_FAILURE_CODES:
417+
elif isinstance(error, self.WRITE_FAILURE_ERRORS):
418418
self._remove_writer(address)
419419

420420
def remove(self, address):

0 commit comments

Comments
 (0)