diff --git a/txsockjs/protocols/websocket.py b/txsockjs/protocols/websocket.py index 0f4e36f..790b55c 100644 --- a/txsockjs/protocols/websocket.py +++ b/txsockjs/protocols/websocket.py @@ -37,18 +37,7 @@ from txsockjs.utils import normalize import json, re - -class PeerOverrideProtocol(ProtocolWrapper): - def getPeer(self): - if self.parent._options["proxy_header"] and self.request.requestHeaders.hasHeader(self.parent._options["proxy_header"]): - ip = self.request.requestHeaders.getRawHeaders(self.parent._options["proxy_header"])[0].split(",")[-1].strip() - if re.match("\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", ip): - return address.IPv4Address("TCP", ip, None) - else: - return address.IPv6Address("TCP", ip, None) - return ProtocolWrapper.getPeer(self) - -class JsonProtocol(PeerOverrideProtocol): +class JsonProtocol(ProtocolWrapper): def makeConnection(self, transport): directlyProvides(self, providedBy(transport)) Protocol.makeConnection(self, transport) @@ -76,7 +65,7 @@ def loseConnection(self): def connectionLost(self, reason=None): if self.heartbeat_timer.active(): self.heartbeat_timer.cancel() - PeerOverrideProtocol.connectionLost(self, reason) + ProtocolWrapper.connectionLost(self, reason) def dataReceived(self, data): if not data: @@ -94,9 +83,6 @@ def heartbeat(self): self.transport.write('h') self.heartbeat_timer = reactor.callLater(self.parent._options['heartbeat'], self.heartbeat) -class PeerOverrideFactory(WrappingFactory): - protocol = PeerOverrideProtocol - class JsonFactory(WrappingFactory): protocol = JsonProtocol @@ -105,15 +91,23 @@ def __init__(self): self._factory = None def _makeFactory(self): - f = PeerOverrideFactory(self.parent._factory) WebSocketsResource.__init__(self, self.parent._factory) OldWebSocketsResource.__init__(self, self.parent._factory) + def getPeer(self, request): + if self.parent._options["proxy_header"] and request.requestHeaders.hasHeader(self.parent._options["proxy_header"]): + ip = request.requestHeaders.getRawHeaders(self.parent._options["proxy_header"])[0].split(",")[-1].strip() + if re.match("\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", ip): + return address.IPv4Address("TCP", ip, None) + else: + return address.IPv6Address("TCP", ip, None) + return request.transport.getPeer() + def lookupProtocol(self, protocolNames, request, old = False): if old: - protocol = self._oldfactory.buildProtocol(request.transport.getPeer()) + protocol = self._oldfactory.buildProtocol(self.getPeer(request)) else: - protocol = self._factory.buildProtocol(request.transport.getPeer()) + protocol = self._factory.buildProtocol(self.getPeer(request)) protocol.request = request protocol.parent = self.parent return protocol, None