22import logging
33import sys
44import signal
5+ import traceback
56
67from configuration_manager import ConfigurationManager
78from data_parser import ChatReceived
89from packets import packets
910from pparser import build_packet
1011from plugin_manager import PluginManager
1112from utilities import path , read_packet , State , Direction , ChatReceiveMode
13+ from zstd_reader import ZstdFrameReader
14+ from zstd_writer import ZstdFrameWriter
1215
1316
1417DEBUG = True
2124logger = logging .getLogger ('starrypy' )
2225logger .setLevel (loglevel )
2326
27+ class SwitchToZstdException (Exception ):
28+ pass
29+
2430class StarryPyServer :
2531 """
2632 Primary server class. Handles all the things.
2733 """
2834 def __init__ (self , reader , writer , config , factory ):
2935 logger .debug ("Initializing connection." )
30- self ._reader = reader
31- self ._writer = writer
32- self ._client_reader = None
33- self ._client_writer = None
36+ self ._reader = reader # read packets from client
37+ self ._writer = writer # writes packets to client
38+ self ._client_reader = None # read packets from server (acting as client)
39+ self ._client_writer = None # write packets to server
3440 self .factory = factory
35- self ._client_loop_future = None
41+ self ._client_loop_future = asyncio . create_task ( self . client_loop ())
3642 self ._server_loop_future = asyncio .create_task (self .server_loop ())
3743 self .state = None
3844 self ._alive = True
@@ -42,8 +48,20 @@ def __init__(self, reader, writer, config, factory):
4248 self ._client_read_future = None
4349 self ._server_write_future = None
4450 self ._client_write_future = None
51+ self ._expect_server_loop_death = False
4552 logger .info ("Received connection from {}" .format (self .client_ip ))
4653
54+ def start_zstd (self ):
55+ self ._reader = ZstdFrameReader (self ._reader , Direction .TO_SERVER )
56+ self ._client_reader = ZstdFrameReader (self ._client_reader , Direction .TO_CLIENT )
57+ self ._writer = ZstdFrameWriter (self ._writer , skip_packets = 1 )
58+ self ._client_writer = ZstdFrameWriter (self ._client_writer )
59+ self ._expect_server_loop_death = True
60+ self ._server_loop_future .cancel ()
61+ self ._server_loop_future = asyncio .create_task (self .server_loop ())
62+ logger .info ("Switched to zstd" )
63+
64+
4765 async def server_loop (self ):
4866 """
4967 Main server loop. As clients connect to the proxy, pass the
@@ -52,14 +70,15 @@ async def server_loop(self):
5270
5371 :return:
5472 """
55- (self ._client_reader , self ._client_writer ) = \
56- await asyncio .open_connection (self .config ['upstream_host' ],
57- self .config ['upstream_port' ])
58- self ._client_loop_future = asyncio .create_task (self .client_loop ())
73+
74+ # wait until client is available
75+ while self ._client_writer is None :
76+ await asyncio .sleep (0.1 )
77+
5978 try :
6079 while True :
6180 packet = await read_packet (self ._reader ,
62- Direction .TO_SERVER )
81+ Direction .TO_SERVER )
6382 # Break in case of emergencies:
6483 # if packet['type'] not in [17, 40, 41, 43, 48, 51]:
6584 # logger.debug('c->s {}'.format(packet['type']))
@@ -74,8 +93,14 @@ async def server_loop(self):
7493 except Exception as err :
7594 logger .error ("Server loop exception occurred:"
7695 "{}: {}" .format (err .__class__ .__name__ , err ))
96+ logger .error ("Error details and traceback: {}" .format (traceback .format_exc ()))
7797 finally :
78- self .die ()
98+ if not self ._expect_server_loop_death :
99+ logger .info ("Server loop ended." )
100+ self .die ()
101+ else :
102+ logger .info ("Restarting server loop for switch to zstd." )
103+ self ._expect_server_loop_death = False
79104
80105 async def client_loop (self ):
81106 """
@@ -84,6 +109,10 @@ async def client_loop(self):
84109
85110 :return:
86111 """
112+ (self ._client_reader , self ._client_writer ) = \
113+ await asyncio .open_connection (self .config ['upstream_host' ],
114+ self .config ['upstream_port' ])
115+
87116 try :
88117 while True :
89118 packet = await read_packet (self ._client_reader ,
0 commit comments