Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .github/dependabot.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
version: 2
updates:
- package-ecosystem: "docker" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "monthly"
- package-ecosystem: "pip" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "monthly"
- package-ecosystem: "github-actions" # See documentation for possible values
directory: "/" # Location of package manifests
schedule:
interval: "monthly"
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v1.4.3
v2.0.1
2 changes: 1 addition & 1 deletion data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,6 @@ def _parse(cls, stream: BytesIO, ctx: OrderedDict):
def _build(cls, obj, ctx: OrderedDotDict):
return StarByteArray.build(obj.encode("utf-8"), ctx)


class Byte(Struct):
@classmethod
def _parse(cls, stream: BytesIO, ctx: OrderedDict):
Expand Down Expand Up @@ -748,6 +747,7 @@ class ProtocolRequest(Struct):
class ProtocolResponse(Struct):
"""packet type 1 """
server_response = Byte
info = Variant


class ServerDisconnect(Struct):
Expand Down
28 changes: 28 additions & 0 deletions plugins/opensb_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
StarryPy OpenSB Detector Plugin

Detects zstd compression for the stream and sets server configuration accordingly
"""

import asyncio

from base_plugin import SimpleCommandPlugin
from utilities import send_message, Command


class OpenSBDetector(SimpleCommandPlugin):
name = "opensb_detector"

def __init__(self):
super().__init__()

async def activate(self):
await super().activate()

async def on_protocol_response(self, data, connection):
# self.logger.debug("Received protocol response: {} from connection {}".format(data, connection))
info = data["parsed"].get("info")
if info != None and info["compression"] == "Zstd":
self.logger.info("Detected Zstd compression. Setting server configuration.")
connection.start_zstd()
return True
22 changes: 12 additions & 10 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
aiohttp==3.8.4
aiohappyeyeballs==2.3.7
aiohttp==3.10.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==23.1.0
charset-normalizer==3.1.0
discord.py==2.3.1
async-timeout==4.0.3
attrs==24.2.0
charset-normalizer==3.3.2
discord.py==2.4.0
docopt==0.6.2
frozenlist==1.3.3
idna==3.4
frozenlist==1.4.1
idna==3.7
irc3==1.1.10
multidict==6.0.4
venusian==3.0.0
yarl==1.9.2
multidict==6.0.5
venusian==3.1.0
yarl==1.9.4
zstandard==0.23.0
51 changes: 40 additions & 11 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import logging
import sys
import signal
import traceback

from configuration_manager import ConfigurationManager
from data_parser import ChatReceived
from packets import packets
from pparser import build_packet
from plugin_manager import PluginManager
from utilities import path, read_packet, State, Direction, ChatReceiveMode
from zstd_reader import ZstdFrameReader
from zstd_writer import ZstdFrameWriter


DEBUG = True
Expand All @@ -21,18 +24,21 @@
logger = logging.getLogger('starrypy')
logger.setLevel(loglevel)

class SwitchToZstdException(Exception):
pass

class StarryPyServer:
"""
Primary server class. Handles all the things.
"""
def __init__(self, reader, writer, config, factory):
logger.debug("Initializing connection.")
self._reader = reader
self._writer = writer
self._client_reader = None
self._client_writer = None
self._reader = reader # read packets from client
self._writer = writer # writes packets to client
self._client_reader = None # read packets from server (acting as client)
self._client_writer = None # write packets to server
self.factory = factory
self._client_loop_future = None
self._client_loop_future = asyncio.create_task(self.client_loop())
self._server_loop_future = asyncio.create_task(self.server_loop())
self.state = None
self._alive = True
Expand All @@ -42,8 +48,20 @@ def __init__(self, reader, writer, config, factory):
self._client_read_future = None
self._server_write_future = None
self._client_write_future = None
self._expect_server_loop_death = False
logger.info("Received connection from {}".format(self.client_ip))

def start_zstd(self):
self._reader = ZstdFrameReader(self._reader, Direction.TO_SERVER)
self._client_reader= ZstdFrameReader(self._client_reader, Direction.TO_CLIENT)
self._writer = ZstdFrameWriter(self._writer, skip_packets=1)
self._client_writer = ZstdFrameWriter(self._client_writer)
self._expect_server_loop_death = True
self._server_loop_future.cancel()
self._server_loop_future = asyncio.create_task(self.server_loop())
logger.info("Switched to zstd")


async def server_loop(self):
"""
Main server loop. As clients connect to the proxy, pass the
Expand All @@ -52,14 +70,15 @@ async def server_loop(self):

:return:
"""
(self._client_reader, self._client_writer) = \
await asyncio.open_connection(self.config['upstream_host'],
self.config['upstream_port'])
self._client_loop_future = asyncio.create_task(self.client_loop())

# wait until client is available
while self._client_writer is None:
await asyncio.sleep(0.1)

try:
while True:
packet = await read_packet(self._reader,
Direction.TO_SERVER)
Direction.TO_SERVER)
# Break in case of emergencies:
# if packet['type'] not in [17, 40, 41, 43, 48, 51]:
# logger.debug('c->s {}'.format(packet['type']))
Expand All @@ -74,8 +93,14 @@ async def server_loop(self):
except Exception as err:
logger.error("Server loop exception occurred:"
"{}: {}".format(err.__class__.__name__, err))
logger.error("Error details and traceback: {}".format(traceback.format_exc()))
finally:
self.die()
if not self._expect_server_loop_death:
logger.info("Server loop ended.")
self.die()
else:
logger.info("Restarting server loop for switch to zstd.")
self._expect_server_loop_death = False

async def client_loop(self):
"""
Expand All @@ -84,6 +109,10 @@ async def client_loop(self):

:return:
"""
(self._client_reader, self._client_writer) = \
await asyncio.open_connection(self.config['upstream_host'],
self.config['upstream_port'])

try:
while True:
packet = await read_packet(self._client_reader,
Expand Down
71 changes: 71 additions & 0 deletions zstd_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import asyncio
from io import BufferedReader
import io
import zstandard as zstd

from utilities import Direction

class ZstdFrameReader:
def __init__(self, reader: asyncio.StreamReader, direction: Direction):
self.outputbuffer = NonSeekableMemoryStream()
self.decompressor = zstd.ZstdDecompressor().stream_writer(self.outputbuffer)
self.raw_reader = reader
self.direction = direction

async def readexactly(self, count):
# print(f"Reading exactly {count} bytes")

while True:
# if there are enough bytes, return them
if self.outputbuffer.remaining() >= count:
# print (f"Returning {count} bytes from buffer {self.direction}")
return self.outputbuffer.read(count)

# print(f"Reading from network since there are only {self.remaining} bytes in buffer")
await self.read_from_network(count)

async def read_from_network(self, target_count):
while self.outputbuffer.remaining() < target_count:

chunk = await self.raw_reader.read(32768) # Read in chunks; we'll only get what's available
# print(f"Read {len(chunk)} bytes from network")
if not chunk:
raise asyncio.CancelledError("Connection closed")
try:
self.decompressor.write(chunk)
except zstd.ZstdError:
print("Zstd error, dropping connection")
raise asyncio.CancelledError("Error in compressed data stream!")

class NonSeekableMemoryStream(io.RawIOBase):
def __init__(self):
self.buffer = bytearray()
self.read_pos = 0
self.write_pos = 0

def write(self, b):
self.buffer.extend(b)
self.write_pos += len(b)
return len(b)

def read(self, size=-1):
if size == -1 or size > self.write_pos - self.read_pos:
size = self.write_pos - self.read_pos
if size == 0:
return b''
data = self.buffer[self.read_pos:self.read_pos + size]
self.read_pos += size
if self.read_pos == self.write_pos:
self.buffer = bytearray()
self.read_pos = 0
self.write_pos = 0
return bytes(data)

def remaining(self):
return self.write_pos - self.read_pos

def readable(self):
return True

def writable(self):
return True
25 changes: 25 additions & 0 deletions zstd_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import asyncio
from io import BufferedReader, BytesIO
import zstandard as zstd

class ZstdFrameWriter:
def __init__(self, raw_writer: asyncio.StreamWriter, skip_packets=0):
self.compressor = zstd.ZstdCompressor()
self.raw_writer = raw_writer
self.skip_packets = skip_packets

async def drain(self):
await self.raw_writer.drain()

def close(self):
self.raw_writer.close()
self.compressor = None

def write(self, data):

if self.skip_packets > 0:
self.skip_packets -= 1
self.raw_writer.write(data)
return

self.raw_writer.write(self.compressor.compress(data))