Skip to content

Commit a61e384

Browse files
authored
Merge pull request #156 from evonzee/feature/opensb-support
OpenStarbound Support
2 parents 17754a8 + eefe6f4 commit a61e384

File tree

8 files changed

+192
-23
lines changed

8 files changed

+192
-23
lines changed

.github/dependabot.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
version: 2
2+
updates:
3+
- package-ecosystem: "docker" # See documentation for possible values
4+
directory: "/" # Location of package manifests
5+
schedule:
6+
interval: "monthly"
7+
- package-ecosystem: "pip" # See documentation for possible values
8+
directory: "/" # Location of package manifests
9+
schedule:
10+
interval: "monthly"
11+
- package-ecosystem: "github-actions" # See documentation for possible values
12+
directory: "/" # Location of package manifests
13+
schedule:
14+
interval: "monthly"

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
v1.4.3
1+
v2.0.1

data_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,6 @@ def _parse(cls, stream: BytesIO, ctx: OrderedDict):
336336
def _build(cls, obj, ctx: OrderedDotDict):
337337
return StarByteArray.build(obj.encode("utf-8"), ctx)
338338

339-
340339
class Byte(Struct):
341340
@classmethod
342341
def _parse(cls, stream: BytesIO, ctx: OrderedDict):
@@ -748,6 +747,7 @@ class ProtocolRequest(Struct):
748747
class ProtocolResponse(Struct):
749748
"""packet type 1 """
750749
server_response = Byte
750+
info = Variant
751751

752752

753753
class ServerDisconnect(Struct):

plugins/opensb_detector.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
StarryPy OpenSB Detector Plugin
3+
4+
Detects zstd compression for the stream and sets server configuration accordingly
5+
"""
6+
7+
import asyncio
8+
9+
from base_plugin import SimpleCommandPlugin
10+
from utilities import send_message, Command
11+
12+
13+
class OpenSBDetector(SimpleCommandPlugin):
14+
name = "opensb_detector"
15+
16+
def __init__(self):
17+
super().__init__()
18+
19+
async def activate(self):
20+
await super().activate()
21+
22+
async def on_protocol_response(self, data, connection):
23+
# self.logger.debug("Received protocol response: {} from connection {}".format(data, connection))
24+
info = data["parsed"].get("info")
25+
if info != None and info["compression"] == "Zstd":
26+
self.logger.info("Detected Zstd compression. Setting server configuration.")
27+
connection.start_zstd()
28+
return True

requirements.txt

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
aiohttp==3.8.4
1+
aiohappyeyeballs==2.3.7
2+
aiohttp==3.10.4
23
aiosignal==1.3.1
3-
async-timeout==4.0.2
4-
attrs==23.1.0
5-
charset-normalizer==3.1.0
6-
discord.py==2.3.1
4+
async-timeout==4.0.3
5+
attrs==24.2.0
6+
charset-normalizer==3.3.2
7+
discord.py==2.4.0
78
docopt==0.6.2
8-
frozenlist==1.3.3
9-
idna==3.4
9+
frozenlist==1.4.1
10+
idna==3.7
1011
irc3==1.1.10
11-
multidict==6.0.4
12-
venusian==3.0.0
13-
yarl==1.9.2
12+
multidict==6.0.5
13+
venusian==3.1.0
14+
yarl==1.9.4
15+
zstandard==0.23.0

server.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import logging
33
import sys
44
import signal
5+
import traceback
56

67
from configuration_manager import ConfigurationManager
78
from data_parser import ChatReceived
89
from packets import packets
910
from pparser import build_packet
1011
from plugin_manager import PluginManager
1112
from utilities import path, read_packet, State, Direction, ChatReceiveMode
13+
from zstd_reader import ZstdFrameReader
14+
from zstd_writer import ZstdFrameWriter
1215

1316

1417
DEBUG = True
@@ -21,18 +24,21 @@
2124
logger = logging.getLogger('starrypy')
2225
logger.setLevel(loglevel)
2326

27+
class SwitchToZstdException(Exception):
28+
pass
29+
2430
class StarryPyServer:
2531
"""
2632
Primary server class. Handles all the things.
2733
"""
2834
def __init__(self, reader, writer, config, factory):
2935
logger.debug("Initializing connection.")
30-
self._reader = reader
31-
self._writer = writer
32-
self._client_reader = None
33-
self._client_writer = None
36+
self._reader = reader # read packets from client
37+
self._writer = writer # writes packets to client
38+
self._client_reader = None # read packets from server (acting as client)
39+
self._client_writer = None # write packets to server
3440
self.factory = factory
35-
self._client_loop_future = None
41+
self._client_loop_future = asyncio.create_task(self.client_loop())
3642
self._server_loop_future = asyncio.create_task(self.server_loop())
3743
self.state = None
3844
self._alive = True
@@ -42,8 +48,20 @@ def __init__(self, reader, writer, config, factory):
4248
self._client_read_future = None
4349
self._server_write_future = None
4450
self._client_write_future = None
51+
self._expect_server_loop_death = False
4552
logger.info("Received connection from {}".format(self.client_ip))
4653

54+
def start_zstd(self):
55+
self._reader = ZstdFrameReader(self._reader, Direction.TO_SERVER)
56+
self._client_reader= ZstdFrameReader(self._client_reader, Direction.TO_CLIENT)
57+
self._writer = ZstdFrameWriter(self._writer, skip_packets=1)
58+
self._client_writer = ZstdFrameWriter(self._client_writer)
59+
self._expect_server_loop_death = True
60+
self._server_loop_future.cancel()
61+
self._server_loop_future = asyncio.create_task(self.server_loop())
62+
logger.info("Switched to zstd")
63+
64+
4765
async def server_loop(self):
4866
"""
4967
Main server loop. As clients connect to the proxy, pass the
@@ -52,14 +70,15 @@ async def server_loop(self):
5270
5371
:return:
5472
"""
55-
(self._client_reader, self._client_writer) = \
56-
await asyncio.open_connection(self.config['upstream_host'],
57-
self.config['upstream_port'])
58-
self._client_loop_future = asyncio.create_task(self.client_loop())
73+
74+
# wait until client is available
75+
while self._client_writer is None:
76+
await asyncio.sleep(0.1)
77+
5978
try:
6079
while True:
6180
packet = await read_packet(self._reader,
62-
Direction.TO_SERVER)
81+
Direction.TO_SERVER)
6382
# Break in case of emergencies:
6483
# if packet['type'] not in [17, 40, 41, 43, 48, 51]:
6584
# logger.debug('c->s {}'.format(packet['type']))
@@ -74,8 +93,14 @@ async def server_loop(self):
7493
except Exception as err:
7594
logger.error("Server loop exception occurred:"
7695
"{}: {}".format(err.__class__.__name__, err))
96+
logger.error("Error details and traceback: {}".format(traceback.format_exc()))
7797
finally:
78-
self.die()
98+
if not self._expect_server_loop_death:
99+
logger.info("Server loop ended.")
100+
self.die()
101+
else:
102+
logger.info("Restarting server loop for switch to zstd.")
103+
self._expect_server_loop_death = False
79104

80105
async def client_loop(self):
81106
"""
@@ -84,6 +109,10 @@ async def client_loop(self):
84109
85110
:return:
86111
"""
112+
(self._client_reader, self._client_writer) = \
113+
await asyncio.open_connection(self.config['upstream_host'],
114+
self.config['upstream_port'])
115+
87116
try:
88117
while True:
89118
packet = await read_packet(self._client_reader,

zstd_reader.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import asyncio
2+
from io import BufferedReader
3+
import io
4+
import zstandard as zstd
5+
6+
from utilities import Direction
7+
8+
class ZstdFrameReader:
9+
def __init__(self, reader: asyncio.StreamReader, direction: Direction):
10+
self.outputbuffer = NonSeekableMemoryStream()
11+
self.decompressor = zstd.ZstdDecompressor().stream_writer(self.outputbuffer)
12+
self.raw_reader = reader
13+
self.direction = direction
14+
15+
async def readexactly(self, count):
16+
# print(f"Reading exactly {count} bytes")
17+
18+
while True:
19+
# if there are enough bytes, return them
20+
if self.outputbuffer.remaining() >= count:
21+
# print (f"Returning {count} bytes from buffer {self.direction}")
22+
return self.outputbuffer.read(count)
23+
24+
# print(f"Reading from network since there are only {self.remaining} bytes in buffer")
25+
await self.read_from_network(count)
26+
27+
async def read_from_network(self, target_count):
28+
while self.outputbuffer.remaining() < target_count:
29+
30+
chunk = await self.raw_reader.read(32768) # Read in chunks; we'll only get what's available
31+
# print(f"Read {len(chunk)} bytes from network")
32+
if not chunk:
33+
raise asyncio.CancelledError("Connection closed")
34+
try:
35+
self.decompressor.write(chunk)
36+
except zstd.ZstdError:
37+
print("Zstd error, dropping connection")
38+
raise asyncio.CancelledError("Error in compressed data stream!")
39+
40+
class NonSeekableMemoryStream(io.RawIOBase):
41+
def __init__(self):
42+
self.buffer = bytearray()
43+
self.read_pos = 0
44+
self.write_pos = 0
45+
46+
def write(self, b):
47+
self.buffer.extend(b)
48+
self.write_pos += len(b)
49+
return len(b)
50+
51+
def read(self, size=-1):
52+
if size == -1 or size > self.write_pos - self.read_pos:
53+
size = self.write_pos - self.read_pos
54+
if size == 0:
55+
return b''
56+
data = self.buffer[self.read_pos:self.read_pos + size]
57+
self.read_pos += size
58+
if self.read_pos == self.write_pos:
59+
self.buffer = bytearray()
60+
self.read_pos = 0
61+
self.write_pos = 0
62+
return bytes(data)
63+
64+
def remaining(self):
65+
return self.write_pos - self.read_pos
66+
67+
def readable(self):
68+
return True
69+
70+
def writable(self):
71+
return True

zstd_writer.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import asyncio
2+
from io import BufferedReader, BytesIO
3+
import zstandard as zstd
4+
5+
class ZstdFrameWriter:
6+
def __init__(self, raw_writer: asyncio.StreamWriter, skip_packets=0):
7+
self.compressor = zstd.ZstdCompressor()
8+
self.raw_writer = raw_writer
9+
self.skip_packets = skip_packets
10+
11+
async def drain(self):
12+
await self.raw_writer.drain()
13+
14+
def close(self):
15+
self.raw_writer.close()
16+
self.compressor = None
17+
18+
def write(self, data):
19+
20+
if self.skip_packets > 0:
21+
self.skip_packets -= 1
22+
self.raw_writer.write(data)
23+
return
24+
25+
self.raw_writer.write(self.compressor.compress(data))

0 commit comments

Comments
 (0)