diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 0000000..715bc4e --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,14 @@ +version: 2 +updates: + - package-ecosystem: "docker" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "monthly" + - package-ecosystem: "pip" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "monthly" + - package-ecosystem: "github-actions" # See documentation for possible values + directory: "/" # Location of package manifests + schedule: + interval: "monthly" \ No newline at end of file diff --git a/VERSION b/VERSION index 5beebea..81ef58f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -v1.4.3 \ No newline at end of file +v2.0.1 \ No newline at end of file diff --git a/data_parser.py b/data_parser.py index 4b8b8ff..aa85177 100644 --- a/data_parser.py +++ b/data_parser.py @@ -336,7 +336,6 @@ def _parse(cls, stream: BytesIO, ctx: OrderedDict): def _build(cls, obj, ctx: OrderedDotDict): return StarByteArray.build(obj.encode("utf-8"), ctx) - class Byte(Struct): @classmethod def _parse(cls, stream: BytesIO, ctx: OrderedDict): @@ -748,6 +747,7 @@ class ProtocolRequest(Struct): class ProtocolResponse(Struct): """packet type 1 """ server_response = Byte + info = Variant class ServerDisconnect(Struct): diff --git a/plugins/opensb_detector.py b/plugins/opensb_detector.py new file mode 100644 index 0000000..04ec1a8 --- /dev/null +++ b/plugins/opensb_detector.py @@ -0,0 +1,28 @@ +""" +StarryPy OpenSB Detector Plugin + +Detects zstd compression for the stream and sets server configuration accordingly +""" + +import asyncio + +from base_plugin import SimpleCommandPlugin +from utilities import send_message, Command + + +class OpenSBDetector(SimpleCommandPlugin): + name = "opensb_detector" + + def __init__(self): + super().__init__() + + async def activate(self): + await super().activate() + + async def on_protocol_response(self, data, connection): + # self.logger.debug("Received protocol response: {} from connection {}".format(data, connection)) + info = data["parsed"].get("info") + if info != None and info["compression"] == "Zstd": + self.logger.info("Detected Zstd compression. Setting server configuration.") + connection.start_zstd() + return True diff --git a/requirements.txt b/requirements.txt index 33d2cc1..5a419f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,15 @@ -aiohttp==3.8.4 +aiohappyeyeballs==2.3.7 +aiohttp==3.10.4 aiosignal==1.3.1 -async-timeout==4.0.2 -attrs==23.1.0 -charset-normalizer==3.1.0 -discord.py==2.3.1 +async-timeout==4.0.3 +attrs==24.2.0 +charset-normalizer==3.3.2 +discord.py==2.4.0 docopt==0.6.2 -frozenlist==1.3.3 -idna==3.4 +frozenlist==1.4.1 +idna==3.7 irc3==1.1.10 -multidict==6.0.4 -venusian==3.0.0 -yarl==1.9.2 +multidict==6.0.5 +venusian==3.1.0 +yarl==1.9.4 +zstandard==0.23.0 diff --git a/server.py b/server.py index b59312e..8478ef8 100644 --- a/server.py +++ b/server.py @@ -2,6 +2,7 @@ import logging import sys import signal +import traceback from configuration_manager import ConfigurationManager from data_parser import ChatReceived @@ -9,6 +10,8 @@ from pparser import build_packet from plugin_manager import PluginManager from utilities import path, read_packet, State, Direction, ChatReceiveMode +from zstd_reader import ZstdFrameReader +from zstd_writer import ZstdFrameWriter DEBUG = True @@ -21,18 +24,21 @@ logger = logging.getLogger('starrypy') logger.setLevel(loglevel) +class SwitchToZstdException(Exception): + pass + class StarryPyServer: """ Primary server class. Handles all the things. """ def __init__(self, reader, writer, config, factory): logger.debug("Initializing connection.") - self._reader = reader - self._writer = writer - self._client_reader = None - self._client_writer = None + self._reader = reader # read packets from client + self._writer = writer # writes packets to client + self._client_reader = None # read packets from server (acting as client) + self._client_writer = None # write packets to server self.factory = factory - self._client_loop_future = None + self._client_loop_future = asyncio.create_task(self.client_loop()) self._server_loop_future = asyncio.create_task(self.server_loop()) self.state = None self._alive = True @@ -42,8 +48,20 @@ def __init__(self, reader, writer, config, factory): self._client_read_future = None self._server_write_future = None self._client_write_future = None + self._expect_server_loop_death = False logger.info("Received connection from {}".format(self.client_ip)) + def start_zstd(self): + self._reader = ZstdFrameReader(self._reader, Direction.TO_SERVER) + self._client_reader= ZstdFrameReader(self._client_reader, Direction.TO_CLIENT) + self._writer = ZstdFrameWriter(self._writer, skip_packets=1) + self._client_writer = ZstdFrameWriter(self._client_writer) + self._expect_server_loop_death = True + self._server_loop_future.cancel() + self._server_loop_future = asyncio.create_task(self.server_loop()) + logger.info("Switched to zstd") + + async def server_loop(self): """ Main server loop. As clients connect to the proxy, pass the @@ -52,14 +70,15 @@ async def server_loop(self): :return: """ - (self._client_reader, self._client_writer) = \ - await asyncio.open_connection(self.config['upstream_host'], - self.config['upstream_port']) - self._client_loop_future = asyncio.create_task(self.client_loop()) + + # wait until client is available + while self._client_writer is None: + await asyncio.sleep(0.1) + try: while True: packet = await read_packet(self._reader, - Direction.TO_SERVER) + Direction.TO_SERVER) # Break in case of emergencies: # if packet['type'] not in [17, 40, 41, 43, 48, 51]: # logger.debug('c->s {}'.format(packet['type'])) @@ -74,8 +93,14 @@ async def server_loop(self): except Exception as err: logger.error("Server loop exception occurred:" "{}: {}".format(err.__class__.__name__, err)) + logger.error("Error details and traceback: {}".format(traceback.format_exc())) finally: - self.die() + if not self._expect_server_loop_death: + logger.info("Server loop ended.") + self.die() + else: + logger.info("Restarting server loop for switch to zstd.") + self._expect_server_loop_death = False async def client_loop(self): """ @@ -84,6 +109,10 @@ async def client_loop(self): :return: """ + (self._client_reader, self._client_writer) = \ + await asyncio.open_connection(self.config['upstream_host'], + self.config['upstream_port']) + try: while True: packet = await read_packet(self._client_reader, diff --git a/zstd_reader.py b/zstd_reader.py new file mode 100644 index 0000000..eb2fa71 --- /dev/null +++ b/zstd_reader.py @@ -0,0 +1,71 @@ +import asyncio +from io import BufferedReader +import io +import zstandard as zstd + +from utilities import Direction + +class ZstdFrameReader: + def __init__(self, reader: asyncio.StreamReader, direction: Direction): + self.outputbuffer = NonSeekableMemoryStream() + self.decompressor = zstd.ZstdDecompressor().stream_writer(self.outputbuffer) + self.raw_reader = reader + self.direction = direction + + async def readexactly(self, count): + # print(f"Reading exactly {count} bytes") + + while True: + # if there are enough bytes, return them + if self.outputbuffer.remaining() >= count: + # print (f"Returning {count} bytes from buffer {self.direction}") + return self.outputbuffer.read(count) + + # print(f"Reading from network since there are only {self.remaining} bytes in buffer") + await self.read_from_network(count) + + async def read_from_network(self, target_count): + while self.outputbuffer.remaining() < target_count: + + chunk = await self.raw_reader.read(32768) # Read in chunks; we'll only get what's available + # print(f"Read {len(chunk)} bytes from network") + if not chunk: + raise asyncio.CancelledError("Connection closed") + try: + self.decompressor.write(chunk) + except zstd.ZstdError: + print("Zstd error, dropping connection") + raise asyncio.CancelledError("Error in compressed data stream!") + +class NonSeekableMemoryStream(io.RawIOBase): + def __init__(self): + self.buffer = bytearray() + self.read_pos = 0 + self.write_pos = 0 + + def write(self, b): + self.buffer.extend(b) + self.write_pos += len(b) + return len(b) + + def read(self, size=-1): + if size == -1 or size > self.write_pos - self.read_pos: + size = self.write_pos - self.read_pos + if size == 0: + return b'' + data = self.buffer[self.read_pos:self.read_pos + size] + self.read_pos += size + if self.read_pos == self.write_pos: + self.buffer = bytearray() + self.read_pos = 0 + self.write_pos = 0 + return bytes(data) + + def remaining(self): + return self.write_pos - self.read_pos + + def readable(self): + return True + + def writable(self): + return True \ No newline at end of file diff --git a/zstd_writer.py b/zstd_writer.py new file mode 100644 index 0000000..e6c34b5 --- /dev/null +++ b/zstd_writer.py @@ -0,0 +1,25 @@ +import asyncio +from io import BufferedReader, BytesIO +import zstandard as zstd + +class ZstdFrameWriter: + def __init__(self, raw_writer: asyncio.StreamWriter, skip_packets=0): + self.compressor = zstd.ZstdCompressor() + self.raw_writer = raw_writer + self.skip_packets = skip_packets + + async def drain(self): + await self.raw_writer.drain() + + def close(self): + self.raw_writer.close() + self.compressor = None + + def write(self, data): + + if self.skip_packets > 0: + self.skip_packets -= 1 + self.raw_writer.write(data) + return + + self.raw_writer.write(self.compressor.compress(data)) \ No newline at end of file