11import asyncio
2+ import collections
23import datetime
4+ import io
35import os
46import sys
57import websockets
8+ import websockets .extensions .permessage_deflate
9+ import websockets .framing
610
711
812DEBUG = 'WSDEBUG' in os .environ and os .environ ['WSDEBUG' ] == '1'
@@ -22,63 +26,121 @@ async def stdin(loop):
2226 return reader
2327
2428
25- async def stdin_to_amplifier (amplifier , loop ):
29+ async def stdin_to_amplifier (amplifier , loop , stats ):
2630 reader = await stdin (loop )
2731 while True :
28- amplifier .send ((await reader .readline ()).decode ('utf-8' ).strip ())
32+ d = await reader .readline ()
33+ stats ['stdin read' ] += len (d )
34+ amplifier .send (d .decode ('utf-8' ).strip ())
35+
36+
37+ def websocket_extensions_to_key (extensions ):
38+ # Convert a list of websockets extensions into a key, handling PerMessageDeflate objects with the relevant care for server-side compression dedupe
39+ def _inner ():
40+ for e in extensions :
41+ if isinstance (e , websockets .extensions .permessage_deflate .PerMessageDeflate ) and e .local_no_context_takeover :
42+ yield (websockets .extensions .permessage_deflate .PerMessageDeflate , e .remote_max_window_bits , e .local_max_window_bits , tuple (e .compress_settings .items ()))
43+ else :
44+ yield e
45+ return tuple (_inner ())
2946
3047
3148class MessageAmplifier :
32- def __init__ (self ):
33- self .queues = {}
49+ def __init__ (self , stats ):
50+ self .queues = {} # websocket -> queue
51+ self ._stats = stats
3452
3553 def register (self , websocket ):
36- self .queues [websocket ] = asyncio .Queue (maxsize = 1000 )
37- return self .queues [websocket ]
54+ q = asyncio .Queue (maxsize = 1000 )
55+ self .queues [websocket ] = q
56+ return q
3857
3958 def send (self , message ):
40- for queue in self .queues .values ():
59+ #FIXME This abuses internal API of websockets==7.0
60+ # Using the normal `websocket.send` reencodes and recompresses the message for every client.
61+ # So we construct the relevant Frame once instead and push that to the individual queues.
62+ frame = websockets .framing .Frame (fin = True , opcode = websockets .framing .OP_TEXT , data = message .encode ('utf-8' ))
63+ data = {} # tuple of extensions key → bytes
64+ for websocket , queue in self .queues .items ():
65+ extensionsKey = websocket_extensions_to_key (websocket .extensions )
66+ if extensionsKey not in data :
67+ output = io .BytesIO ()
68+ frame .write (output .write , mask = False , extensions = websocket .extensions )
69+ data [extensionsKey ] = output .getvalue ()
70+ self ._stats ['frame writes' ] += len (data [extensionsKey ])
4171 try :
42- queue .put_nowait (message )
72+ queue .put_nowait (data [ extensionsKey ] )
4373 except asyncio .QueueFull :
4474 # Pop one, try again; it should be impossible for this to fail, so no try/except here.
45- queue .get_nowait ()
46- queue .put_nowait (message )
75+ dropped = queue .get_nowait ()
76+ self ._stats ['dropped' ] += len (dropped )
77+ queue .put_nowait (data [extensionsKey ])
4778
4879 def unregister (self , websocket ):
4980 del self .queues [websocket ]
5081
5182
52- async def websocket_server (amplifier , websocket , path ):
83+ async def websocket_server (amplifier , websocket , path , stats ):
5384 queue = amplifier .register (websocket )
5485 try :
5586 while True :
56- await websocket .send (await queue .get ())
87+ #FIXME See above; this is write_frame essentially
88+ data = await queue .get ()
89+ await websocket .ensure_open ()
90+ websocket .writer .write (data )
91+ stats ['sent' ] += len (data )
92+ if websocket .writer .transport is not None :
93+ if websocket .writer_is_closing ():
94+ await asyncio .sleep (0 )
95+ try :
96+ async with websocket ._drain_lock :
97+ await websocket .writer .drain ()
98+ except ConnectionError :
99+ websocket .fail_connection ()
100+ await websocket .ensure_open ()
57101 except websockets .exceptions .ConnectionClosed : # Silence connection closures
58102 pass
59103 finally :
60104 amplifier .unregister (websocket )
61105
62106
63- async def print_status (amplifier ):
107+ async def print_status (amplifier , stats ):
108+ interval = 60
64109 previousUtime = None
110+ previousStats = {}
65111 while True :
66112 currentUtime = os .times ().user
67- cpu = (currentUtime - previousUtime ) / 60 * 100 if previousUtime is not None else float ('nan' )
68- print (f'{ datetime .datetime .now ():%Y-%m-%d %H:%M:%S} - { len (amplifier .queues )} clients, { sum (q .qsize () for q in amplifier .queues .values ())} total queue size, { cpu :.1f} % CPU, { get_rss ()/ 1048576 :.1f} MiB RSS' )
113+ cpu = (currentUtime - previousUtime ) / interval * 100 if previousUtime is not None else float ('nan' )
114+ print (f'{ datetime .datetime .now ():%Y-%m-%d %H:%M:%S} - ' +
115+ ', ' .join ([
116+ f'{ len (amplifier .queues )} clients' ,
117+ f'{ sum (q .qsize () for q in amplifier .queues .values ())} total queue size' ,
118+ f'{ cpu :.1f} % CPU' ,
119+ f'{ get_rss ()/ 1048576 :.1f} MiB RSS' ,
120+ 'throughput: ' + ', ' .join (f'{ (stats [k ] - previousStats .get (k , 0 ))/ interval / 1000 :.1f} kB/s { k } ' for k in stats ),
121+ ])
122+ )
69123 if DEBUG :
70124 for socket in amplifier .queues :
71125 print (f' { socket .remote_address } : { amplifier .queues [socket ].qsize ()} ' )
72126 previousUtime = currentUtime
73- await asyncio .sleep (60 )
127+ previousStats .update (stats )
128+ await asyncio .sleep (interval )
74129
75130
76131def main ():
77- amplifier = MessageAmplifier ()
78- start_server = websockets .serve (lambda websocket , path : websocket_server (amplifier , websocket , path ), None , 4568 )
132+ stats = {'stdin read' : 0 , 'frame writes' : 0 , 'sent' : 0 , 'dropped' : 0 }
133+ amplifier = MessageAmplifier (stats )
134+ # Disable context takeover (cf. RFC 7692) so the compression can be reused
135+ start_server = websockets .serve (
136+ lambda websocket , path : websocket_server (amplifier , websocket , path , stats ),
137+ None ,
138+ 4568 ,
139+ extensions = [websockets .extensions .permessage_deflate .ServerPerMessageDeflateFactory (server_no_context_takeover = True )]
140+ )
79141 loop = asyncio .get_event_loop ()
80142 loop .run_until_complete (start_server )
81- loop .run_until_complete (asyncio .gather (stdin_to_amplifier (amplifier , loop ), print_status (amplifier )))
143+ loop .run_until_complete (asyncio .gather (stdin_to_amplifier (amplifier , loop , stats ), print_status (amplifier , stats )))
82144
83145
84146if __name__ == '__main__' :
0 commit comments