11from __future__ import annotations
22
33from functools import partial
4- from typing import Awaitable , Callable , Dict , Optional , Tuple
4+ from typing import Awaitable , Callable , Dict , Optional , Set , Tuple
55
66from aioquic .buffer import Buffer
77from aioquic .h3 .connection import H3_ALPN
2222from .h3 import H3Protocol
2323from ..config import Config
2424from ..events import Closed , Event , RawData
25- from ..typing import AppWrapper , TaskGroup , WorkerContext
25+ from ..typing import AppWrapper , TaskGroup , WorkerContext , Timer
26+
27+
28+ class ConnectionState :
29+ def __init__ (self , connection : QuicConnection ):
30+ self .connection = connection
31+ self .timer : Optional [Timer ] = None
32+ self .cids : Set [bytes ] = set ()
33+ self .h3_protocol : Optional [H3Protocol ] = None
34+
35+ def add_cid (self , cid : bytes ) -> None :
36+ self .cids .add (cid )
37+
38+ def remove_cid (self , cid : bytes ) -> None :
39+ self .cids .remove (cid )
2640
2741
2842class QuicProtocol :
@@ -38,8 +52,7 @@ def __init__(
3852 self .app = app
3953 self .config = config
4054 self .context = context
41- self .connections : Dict [bytes , QuicConnection ] = {}
42- self .http_connections : Dict [QuicConnection , H3Protocol ] = {}
55+ self .connections : Dict [bytes , ConnectionState ] = {}
4356 self .send = send
4457 self .server = server
4558 self .task_group = task_group
@@ -49,7 +62,7 @@ def __init__(
4962
5063 @property
5164 def idle (self ) -> bool :
52- return len (self .connections ) == 0 and len ( self . http_connections ) == 0
65+ return len (self .connections ) == 0
5366
5467 async def handle (self , event : Event ) -> None :
5568 if isinstance (event , RawData ):
@@ -69,9 +82,13 @@ async def handle(self, event: Event) -> None:
6982 await self .send (RawData (data = data , address = event .address ))
7083 return
7184
72- connection = self .connections .get (header .destination_cid )
85+ state = self .connections .get (header .destination_cid )
86+ if state is not None :
87+ connection = state .connection
88+ else :
89+ connection = None
7390 if (
74- connection is None
91+ state is None
7592 and len (event .data ) >= 1200
7693 and header .packet_type == PACKET_TYPE_INITIAL
7794 and not self .context .terminated .is_set ()
@@ -80,12 +97,18 @@ async def handle(self, event: Event) -> None:
8097 configuration = self .quic_config ,
8198 original_destination_connection_id = header .destination_cid ,
8299 )
83- self .connections [header .destination_cid ] = connection
84- self .connections [connection .host_cid ] = connection
100+ # This partial() needs python >= 3.8
101+ state = ConnectionState (connection )
102+ timer = self .task_group .create_timer (partial (self ._timeout , state ))
103+ state .timer = timer
104+ state .add_cid (header .destination_cid )
105+ self .connections [header .destination_cid ] = state
106+ state .add_cid (connection .host_cid )
107+ self .connections [connection .host_cid ] = state
85108
86109 if connection is not None :
87110 connection .receive_datagram (event .data , event .address , now = self .context .time ())
88- await self ._handle_events ( connection , event . address )
111+ await self ._wake_up_timer ( state )
89112 elif isinstance (event , Closed ):
90113 pass
91114
@@ -94,42 +117,50 @@ async def send_all(self, connection: QuicConnection) -> None:
94117 await self .send (RawData (data = data , address = address ))
95118
96119 async def _handle_events (
97- self , connection : QuicConnection , client : Optional [Tuple [str , int ]] = None
120+ self , state : ConnectionState , client : Optional [Tuple [str , int ]] = None
98121 ) -> None :
122+ connection = state .connection
99123 event = connection .next_event ()
100124 while event is not None :
101125 if isinstance (event , ConnectionTerminated ):
102- pass
126+ await state .timer .stop ()
127+ for cid in state .cids :
128+ del self .connections [cid ]
129+ state .cids = set ()
103130 elif isinstance (event , ProtocolNegotiated ):
104- self . http_connections [ connection ] = H3Protocol (
131+ state . h3_protocol = H3Protocol (
105132 self .app ,
106133 self .config ,
107134 self .context ,
108135 self .task_group ,
109136 client ,
110137 self .server ,
111138 connection ,
112- partial (self .send_all , connection ),
139+ partial (self ._wake_up_timer , state ),
113140 )
114141 elif isinstance (event , ConnectionIdIssued ):
115- self .connections [event .connection_id ] = connection
142+ state .add_cid (event .connection_id )
143+ self .connections [event .connection_id ] = state
116144 elif isinstance (event , ConnectionIdRetired ):
145+ state .remove_cid (event .connection_id )
117146 del self .connections [event .connection_id ]
118147
119- if connection in self . http_connections :
120- await self . http_connections [ connection ] .handle (event )
148+ elif state . h3_protocol is not None :
149+ await state . h3_protocol .handle (event )
121150
122151 event = connection .next_event ()
123152
153+ async def _wake_up_timer (self , state : ConnectionState ) -> None :
154+ # When new output is send, or new input is received, we
155+ # fire the timer right away so we update our state.
156+ await state .timer .schedule (0.0 )
157+
158+ async def _timeout (self , state : ConnectionState ) -> None :
159+ connection = state .connection
160+ now = self .context .time ()
161+ when = connection .get_timer ()
162+ if when is not None and now > when :
163+ connection .handle_timer (now )
164+ await self ._handle_events (state , None )
124165 await self .send_all (connection )
125-
126- timer = connection .get_timer ()
127- if timer is not None :
128- self .task_group .spawn (self ._handle_timer , timer , connection )
129-
130- async def _handle_timer (self , timer : float , connection : QuicConnection ) -> None :
131- wait = max (0 , timer - self .context .time ())
132- await self .context .sleep (wait )
133- if connection ._close_at is not None :
134- connection .handle_timer (now = self .context .time ())
135- await self ._handle_events (connection , None )
166+ await state .timer .schedule (connection .get_timer ())
0 commit comments