diff --git a/test_websocket.py b/test_websocket.py index ee1855d..9877d16 100644 --- a/test_websocket.py +++ b/test_websocket.py @@ -11,6 +11,7 @@ from websocket import WebSocketHandler, WebSocketFrameDecoder from websocket import WebSocketSite, WebSocketTransport +from websocket import DecodingError from twisted.web.resource import Resource from twisted.web.server import Request, Site @@ -372,6 +373,17 @@ def setUp(self): transport._attachHandler(handler) self.decoder = WebSocketFrameDecoder(request, handler) self.decoder.MAX_LENGTH = 100 + self.decoder.MAX_BINARY_LENGTH = 1000 + + + def assertOneDecodingError(self): + """ + Assert that exactly one L{DecodingError} has been logged and return + that error. + """ + errors = self.flushLoggedErrors(DecodingError) + self.assertEquals(len(errors), 1) + return errors[0] def test_oneFrame(self): @@ -406,6 +418,7 @@ def test_missingNull(self): dropped. """ self.decoder.dataReceived("frame\xff") + self.assertOneDecodingError() self.assertTrue(self.channel.transport.disconnected) @@ -415,6 +428,7 @@ def test_missingNullAfterGoodFrame(self): frame, the connection is dropped. """ self.decoder.dataReceived("\x00frame\xfffoo") + self.assertOneDecodingError() self.assertTrue(self.channel.transport.disconnected) self.assertEquals(self.decoder.handler.frames, ["frame"]) @@ -456,6 +470,97 @@ def test_frameLengthReset(self): self.assertFalse(self.channel.transport.disconnected) + def test_oneBinaryFrame(self): + """ + A binary frame is parsed and ignored, the following text frame is + delivered. + """ + self.decoder.dataReceived("\xff\x0abinarydata\x00text frame\xff") + self.assertEquals(self.decoder.handler.frames, ["text frame"]) + + + def test_multipleBinaryFrames(self): + """ + Text frames intermingled with binary frames are parsed correctly. 
+ """ + tf1, tf2, tf3 = "\x00frame1\xff", "\x00frame2\xff", "\x00frame3\xff" + bf1, bf2, bf3 = "\xff\x01X", "\xff\x1a" + "X" * 0x1a, "\xff\x02AB" + + self.decoder.dataReceived(tf1 + bf1 + bf2 + tf2 + tf3 + bf3) + self.assertEquals(self.decoder.handler.frames, + ["frame1", "frame2", "frame3"]) + + + def test_binaryFrameMultipleLengthBytes(self): + """ + A binary frame can have its length field spread across multiple bytes. + """ + bf = "\xff\x81\x48" + "X" * 200 + tf = "\x00frame\xff" + self.decoder.dataReceived(bf + tf + bf) + self.assertEquals(self.decoder.handler.frames, ["frame"]) + + + def test_binaryAndTextSplitted(self): + """ + Intermingled binary and text frames can be split across several + C{dataReceived} calls. + """ + tf1, tf2 = "\x00text\xff", "\x00other text\xff" + bf1, bf2, bf3 = ("\xff\x01X", "\xff\x81\x48" + "X" * 200, + "\xff\x20" + "X" * 32) + + chunks = [bf1[0], bf1[1:], tf1[:2], tf1[2:] + bf2[:2], bf2[2:-2], + bf2[-2:-1], bf2[1] + tf2[:-1], tf2[-1], bf3] + for c in chunks: + self.decoder.dataReceived(c) + + self.assertEquals(self.decoder.handler.frames, ["text", "other text"]) + self.assertFalse(self.channel.transport.disconnected) + + + def test_maxBinaryLength(self): + """ + If a binary frame's declared length exceeds MAX_BINARY_LENGTH, the + connection is dropped. + """ + self.decoder.dataReceived("\xff\xff\xff\xff\xff\x01") + self.assertTrue(self.channel.transport.disconnected) + + + def test_closingHandshake(self): + """ + After receiving the closing handshake, the server sends its own closing + handshake and ignores all future data. + """ + self.decoder.dataReceived("\x00frame\xff\xff\x00random crap") + self.decoder.dataReceived("more random crap, that's discarded") + self.assertEquals(self.decoder.handler.frames, ["frame"]) + self.assertTrue(self.decoder.closing) + + + def test_invalidFrameType(self): + """ + Frame types other than 0x00 and 0xff cause the connection to be + dropped. 
class IncompleteFrame(Exception):
    """
    Not enough data to complete a WebSocket frame.
    """


class DecodingError(Exception):
    """
    The incoming data is not valid WebSocket protocol data.
    """


class WebSocketFrameDecoder(object):
    """
    Decode WebSocket frames and pass them to the attached C{WebSocketHandler}
    instance.

    The decoder is a small state machine: C{_state} names the current state
    and L{consumeData} forwards each chunk to the method whose name is
    C{_consumeData_} followed by that state name.

    @ivar MAX_LENGTH: maximum len of a text frame allowed, before calling
        C{frameLengthExceeded} on the handler.
    @type MAX_LENGTH: C{int}
    @ivar MAX_BINARY_LENGTH: like C{MAX_LENGTH}, but for 0xff type frames
    @type MAX_BINARY_LENGTH: C{int}
    @ivar closing: a flag set when the closing handshake has been received
    @type closing: C{bool}
    @ivar request: C{Request} instance.
    @type request: L{twisted.web.server.Request}
    @ivar handler: L{WebSocketHandler} instance handling the request.
    @type handler: L{WebSocketHandler}
    @ivar _data: chunks received but not yet consumed; the last item is the
        chunk currently being parsed, earlier items are the beginning of an
        unfinished text frame.
    @type _data: C{list} of C{str}
    @ivar _currentFrameLength: for a text frame, the number of payload bytes
        seen so far; for a 0xff frame, the declared length while it is being
        parsed, then the number of payload bytes still to skip.
    @type _currentFrameLength: C{int}
    @ivar _state: name of the current parser state.
    @type _state: C{str}
    """

    MAX_LENGTH = 16384
    MAX_BINARY_LENGTH = 2147483648
    closing = False

    def __init__(self, request, handler):
        """
        @param request: the request whose transport is written to and closed.
        @param handler: the object receiving C{frameReceived} and
            C{frameLengthExceeded} calls.
        """
        self.request = request
        self.handler = handler
        self.closing = False
        self._data = []
        self._currentFrameLength = 0
        self._state = "FRAME_START"

    def dataReceived(self, data):
        """
        Parse data to read WebSocket frames.

        Nothing is done for empty data or once the closing handshake has been
        received.  Invalid data is logged and the connection is dropped.

        @param data: data received over the WebSocket connection.
        @type data: C{str}
        """
        if not data or self.closing:
            return
        self._data.append(data)

        # Each consumeData call either replaces the buffer with what is left
        # to parse or raises to signal that parsing cannot go on.
        while self._data and not self.closing:
            try:
                self.consumeData(self._data[-1])
            except IncompleteFrame:
                break
            except DecodingError:
                log.err()
                self.request.transport.loseConnection()
                break

    def consumeData(self, data):
        """
        Process the last data chunk received.

        The chunk is handed to the method implementing the current state.
        After processing is done, L{IncompleteFrame} should be raised or
        L{_addRemainingData} should be called.

        @param data: last chunk of data received.
        @type data: C{str}
        @raise RuntimeError: if there is no method for the current state.
        """
        # Explicit dispatch, so that no private Twisted helper is needed.
        stateMethod = getattr(self, "_consumeData_" + self._state, None)
        if stateMethod is None:
            raise RuntimeError(
                "%r has no consumeData method in state %s" %
                (self, self._state))
        return stateMethod(data)

    def _consumeData_FRAME_START(self, data):
        """
        Read the frame type byte and select the matching parsing state.

        @raise DecodingError: if the type is neither 0x00 nor 0xff.
        """
        self._currentFrameLength = 0

        frameType = data[0]
        if frameType == "\x00":
            self._state = "PARSING_TEXT_FRAME"
        elif frameType == "\xff":
            self._state = "PARSING_LENGTH"
        else:
            raise DecodingError("Invalid frame type 0x%02x" % ord(frameType))

        self._addRemainingData(data[1:])

    def _consumeData_PARSING_TEXT_FRAME(self, data):
        """
        Look for the 0xff byte ending a text frame and deliver the frame.

        @raise IncompleteFrame: if the terminator is not in C{data}.
        """
        endIndex = data.find("\xff")
        # Count the payload bytes of this chunk exactly once.
        if endIndex == -1:
            self._currentFrameLength += len(data)
        else:
            self._currentFrameLength += endIndex

        # check length + 1 to account for the initial frame type byte
        if self._currentFrameLength + 1 > self.MAX_LENGTH:
            self.handler.frameLengthExceeded()

        if endIndex == -1:
            raise IncompleteFrame()

        # Earlier buffered chunks hold the beginning of this frame.
        frame = "".join(self._data[:-1]) + data[:endIndex]
        self.handler.frameReceived(frame)

        self._addRemainingData(data[endIndex + 1:])
        self._state = "FRAME_START"

    def _consumeData_PARSING_LENGTH(self, data):
        """
        Decode the length of a 0xff frame: base-128 digits, most significant
        first, with the high bit set on every byte but the last.

        A total length of zero is the closing handshake.

        @raise IncompleteFrame: if the last length byte has not arrived yet,
            or after the closing handshake, to stop the parsing loop.
        """
        for index, char in enumerate(data):
            byte = ord(char)
            self._currentFrameLength = (
                self._currentFrameLength * 128 + (byte & 0x7F))

            if byte & 0x80:
                # more length bytes follow
                continue

            # Only the complete length decides: a zero digit inside a longer
            # length (e.g. 0x81 0x00 == 128) is not a closing handshake.
            if not self._currentFrameLength:
                self._closingHandshake()
                raise IncompleteFrame()

            if self._currentFrameLength > self.MAX_BINARY_LENGTH:
                self.handler.frameLengthExceeded()

            self._addRemainingData(data[index + 1:])
            self._state = "PARSING_BINARY_FRAME"
            return

        raise IncompleteFrame()

    def _consumeData_PARSING_BINARY_FRAME(self, data):
        """
        Skip the payload of a 0xff frame; its content is not delivered.
        """
        available = len(data)
        if self._currentFrameLength <= available:
            self._addRemainingData(data[self._currentFrameLength:])
            self._state = "FRAME_START"
        else:
            # The whole chunk is payload: drop it and remember how much is
            # still to be skipped.
            self._currentFrameLength -= available
            self._data[:] = []

    def _addRemainingData(self, remainingData):
        """
        Replace the buffer with the data left to parse, if any.
        """
        if remainingData:
            self._data[:] = [remainingData]
        else:
            self._data[:] = []

    def _closingHandshake(self):
        """
        Mark the decoder as closing and answer the closing handshake.
        """
        self.closing = True
        # send the closing handshake
        self.request.transport.write("\xff\x00")
        # discard all buffered data
        self._data[:] = []