Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from websocket import WebSocketHandler, WebSocketFrameDecoder
from websocket import WebSocketSite, WebSocketTransport
from websocket import DecodingError

from twisted.web.resource import Resource
from twisted.web.server import Request, Site
Expand Down Expand Up @@ -372,6 +373,17 @@ def setUp(self):
transport._attachHandler(handler)
self.decoder = WebSocketFrameDecoder(request, handler)
self.decoder.MAX_LENGTH = 100
self.decoder.MAX_BINARY_LENGTH = 1000


def assertOneDecodingError(self):
"""
Assert that exactly one L{DecodingError} has been logged and return
that error.
"""
errors = self.flushLoggedErrors(DecodingError)
self.assertEquals(len(errors), 1)
return errors[0]


def test_oneFrame(self):
Expand Down Expand Up @@ -406,6 +418,7 @@ def test_missingNull(self):
dropped.
"""
self.decoder.dataReceived("frame\xff")
self.assertOneDecodingError()
self.assertTrue(self.channel.transport.disconnected)


Expand All @@ -415,6 +428,7 @@ def test_missingNullAfterGoodFrame(self):
frame, the connection is dropped.
"""
self.decoder.dataReceived("\x00frame\xfffoo")
self.assertOneDecodingError()
self.assertTrue(self.channel.transport.disconnected)
self.assertEquals(self.decoder.handler.frames, ["frame"])

Expand Down Expand Up @@ -456,6 +470,97 @@ def test_frameLengthReset(self):
self.assertFalse(self.channel.transport.disconnected)


def test_oneBinaryFrame(self):
"""
A binary frame is parsed and ignored, the following text frame is
delivered.
"""
self.decoder.dataReceived("\xff\x0abinarydata\x00text frame\xff")
self.assertEquals(self.decoder.handler.frames, ["text frame"])


def test_multipleBinaryFrames(self):
"""
Text frames intermingled with binary frames are parsed correctly.
"""
tf1, tf2, tf3 = "\x00frame1\xff", "\x00frame2\xff", "\x00frame3\xff"
bf1, bf2, bf3 = "\xff\x01X", "\xff\x1a" + "X" * 0x1a, "\xff\x02AB"

self.decoder.dataReceived(tf1 + bf1 + bf2 + tf2 + tf3 + bf3)
self.assertEquals(self.decoder.handler.frames,
["frame1", "frame2", "frame3"])


def test_binaryFrameMultipleLengthBytes(self):
"""
A binary frame can have its length field spread across multiple bytes.
"""
bf = "\xff\x81\x48" + "X" * 200
tf = "\x00frame\xff"
self.decoder.dataReceived(bf + tf + bf)
self.assertEquals(self.decoder.handler.frames, ["frame"])


def test_binaryAndTextSplitted(self):
"""
Intermingled binary and text frames can be split across several
C{dataReceived} calls.
"""
tf1, tf2 = "\x00text\xff", "\x00other text\xff"
bf1, bf2, bf3 = ("\xff\x01X", "\xff\x81\x48" + "X" * 200,
"\xff\x20" + "X" * 32)

chunks = [bf1[0], bf1[1:], tf1[:2], tf1[2:] + bf2[:2], bf2[2:-2],
bf2[-2:-1], bf2[1] + tf2[:-1], tf2[-1], bf3]
for c in chunks:
self.decoder.dataReceived(c)

self.assertEquals(self.decoder.handler.frames, ["text", "other text"])
self.assertFalse(self.channel.transport.disconnected)


def test_maxBinaryLength(self):
"""
If a binary frame's declared length exceeds MAX_BINARY_LENGTH, the
connection is dropped.
"""
self.decoder.dataReceived("\xff\xff\xff\xff\xff\x01")
self.assertTrue(self.channel.transport.disconnected)


def test_closingHandshake(self):
"""
After receiving the closing handshake, the server sends its own closing
handshake and ignores all future data.
"""
self.decoder.dataReceived("\x00frame\xff\xff\x00random crap")
self.decoder.dataReceived("more random crap, that's discarded")
self.assertEquals(self.decoder.handler.frames, ["frame"])
self.assertTrue(self.decoder.closing)


def test_invalidFrameType(self):
"""
Frame types other than 0x00 and 0xff cause the connection to be
dropped.
"""
ok = "\x00ok\xff"
wrong = "\x05foo\xff"

self.decoder.dataReceived(ok + wrong + ok)
self.assertEquals(self.decoder.handler.frames, ["ok"])
error = self.assertOneDecodingError()
self.assertTrue(self.channel.transport.disconnected)


def test_emptyFrame(self):
"""
An empty text frame is correctly parsed.
"""
self.decoder.dataReceived("\x00\xff")
self.assertEquals(self.decoder.handler.frames, [""])
self.assertFalse(self.channel.transport.disconnected)


class WebSocketHandlerTestCase(TestCase):
"""
Expand Down
160 changes: 131 additions & 29 deletions websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import struct

from twisted.internet import interfaces
from twisted.python import log
from twisted.web._newclient import makeStatefulDispatcher
from twisted.web.http import datetimeToString
from twisted.web.http import _IdentityTransferDecoder
from twisted.web.server import Request, Site, version, unquote
Expand Down Expand Up @@ -417,15 +419,30 @@ def connectionLost(self, reason):
"""


class IncompleteFrame(Exception):
"""
Not enough data to complete a WebSocket frame.
"""


class DecodingError(Exception):
"""
The incoming data is not valid WebSocket protocol data.
"""


class WebSocketFrameDecoder(object):
"""
Decode WebSocket frames and pass them to the attached C{WebSocketHandler}
instance.

@ivar MAX_LENGTH: maximum len of the frame allowed, before calling
@ivar MAX_LENGTH: maximum len of a text frame allowed, before calling
C{frameLengthExceeded} on the handler.
@type MAX_LENGTH: C{int}
@ivar MAX_BINARY_LENGTH: like C{MAX_LENGTH}, but for 0xff type frames
@type MAX_BINARY_LENGTH: C{int}
@ivar closing: a flag set when the closing handshake has been received
@type closing: C{bool}
@ivar request: C{Request} instance.
@type request: L{twisted.web.server.Request}
@ivar handler: L{WebSocketHandler} instance handling the request.
Expand All @@ -438,13 +455,16 @@ class WebSocketFrameDecoder(object):
"""

MAX_LENGTH = 16384

MAX_BINARY_LENGTH = 2147483648
closing = False

def __init__(self, request, handler):
self.request = request
self.handler = handler
self.closing = False
self._data = []
self._currentFrameLength = 0
self._state = "FRAME_START"

def dataReceived(self, data):
"""
Expand All @@ -453,37 +473,119 @@ def dataReceived(self, data):
@param data: data received over the WebSocket connection.
@type data: C{str}
"""
if not data:
if not data or self.closing:
return
while True:
endIndex = data.find("\xff")
if endIndex != -1:
self._currentFrameLength += endIndex
if self._currentFrameLength > self.MAX_LENGTH:
self.handler.frameLengthExceeded()
break
self._currentFrameLength = 0
frame = "".join(self._data) + data[:endIndex]
self._data[:] = []
if frame[0] != "\x00":
self.request.transport.loseConnection()
break
self.handler.frameReceived(frame[1:])
data = data[endIndex + 1:]
if not data:
break
if data[0] != "\x00":
self.request.transport.loseConnection()
break
else:
self._currentFrameLength += len(data)
if self._currentFrameLength > self.MAX_LENGTH + 1:
self._data.append(data)

while self._data and not self.closing:
try:
self.consumeData(self._data[-1])
except IncompleteFrame:
break
except DecodingError:
log.err()
self.request.transport.loseConnection()
break

def consumeData(self, data):
"""
Process the last data chunk received.

After processing is done, L{IncompleteFrame} should be raised or
L{_addRemainingData} should be called.

@param data: last chunk of data received.
@type data: C{str}
"""
consumeData = makeStatefulDispatcher("consumeData", consumeData)

def _consumeData_FRAME_START(self, data):
self._currentFrameLength = 0

if data[0] == "\x00":
self._state = "PARSING_TEXT_FRAME"
elif data[0] == "\xff":
self._state = "PARSING_LENGTH"
else:
raise DecodingError("Invalid frame type 0x%s" %
data[0].encode("hex"))

self._addRemainingData(data[1:])

def _consumeData_PARSING_TEXT_FRAME(self, data):
endIndex = data.find("\xff")
if endIndex == -1:
self._currentFrameLength += len(data)
else:
self._currentFrameLength += endIndex

self._currentFrameLength += endIndex
# check length + 1 to account for the initial frame type byte
if self._currentFrameLength + 1 > self.MAX_LENGTH:
self.handler.frameLengthExceeded()

if endIndex == -1:
raise IncompleteFrame()

frame = "".join(self._data[:-1]) + data[:endIndex]
self.handler.frameReceived(frame)

remainingData = data[endIndex + 1:]
self._addRemainingData(remainingData)

self._state = "FRAME_START"

def _consumeData_PARSING_LENGTH(self, data):
current = 0
available = len(data)

while current < available:
byte = ord(data[current])
length, more = byte & 0x7F, bool(byte & 0x80)

if not length:
self._closingHandshake()
raise IncompleteFrame()

self._currentFrameLength *= 128
self._currentFrameLength += length

current += 1

if not more:
if self._currentFrameLength > self.MAX_BINARY_LENGTH:
self.handler.frameLengthExceeded()
else:
self._data.append(data)

remainingData = data[current:]
self._addRemainingData(remainingData)
self._state = "PARSING_BINARY_FRAME"
break
else:
raise IncompleteFrame()

def _consumeData_PARSING_BINARY_FRAME(self, data):
available = len(data)

if self._currentFrameLength <= available:
remainingData = data[self._currentFrameLength:]
self._addRemainingData(remainingData)
self._state = "FRAME_START"
else:
self._currentFrameLength -= available
self._data[:] = []

__all__ = ["WebSocketHandler", "WebSocketSite"]
def _addRemainingData(self, remainingData):
if remainingData:
self._data[:] = [remainingData]
else:
self._data[:] = []

def _closingHandshake(self):
self.closing = True
# send the closing handshake
self.request.transport.write("\xff\x00")
# discard all buffered data
self._data[:] = []


__all__ = ["WebSocketHandler", "WebSocketSite"]