1+ {-# LANGUAGE LambdaCase #-}
2+ {-# LANGUAGE OverloadedStrings #-}
3+ {-# LANGUAGE ScopedTypeVariables #-}
4+
5+ module Network.Transport.QUIC.Internal (
6+ createTransport ,
7+ QUICAddr (.. ),
8+ encodeQUICAddr ,
9+ decodeQUICAddr ,
10+
11+ -- * Re-export to generate credentials
12+ Credential ,
13+ credentialLoadX509 ,
14+ ) where
15+
16+ import Control.Concurrent (ThreadId , forkFinally , killThread , myThreadId )
17+ import Control.Concurrent.STM (atomically )
18+ import Control.Concurrent.STM.TQueue (
19+ TQueue ,
20+ newTQueueIO ,
21+ readTQueue ,
22+ writeTQueue ,
23+ )
24+ import Control.Exception (bracket , try )
25+ import Control.Monad (void )
26+ import Data.Bifunctor (first )
27+ import Data.ByteString (StrictByteString )
28+ import Data.ByteString qualified as BS
29+ import Data.Foldable (traverse_ )
30+ import Data.Functor (($>) , (<&>) )
31+ import Data.IORef (IORef , newIORef , readIORef , writeIORef )
32+ import Data.List.NonEmpty (NonEmpty )
33+ import Data.Set (Set )
34+ import Data.Set qualified as Set
35+ import GHC.IORef (atomicModifyIORef'_ )
36+ import Network.QUIC (Stream )
37+ import Network.QUIC qualified as QUIC
38+ import Network.QUIC.Client qualified as QUIC.Client
39+ import Network.QUIC.Server qualified as QUIC.Server
40+ import Network.TLS (Credential )
41+ import Network.Transport (
42+ ConnectErrorCode (ConnectNotFound ),
43+ ConnectHints ,
44+ Connection (.. ),
45+ ConnectionId ,
46+ EndPoint (.. ),
47+ EndPointAddress ,
48+ Event (.. ),
49+ EventErrorCode (EventEndPointFailed ),
50+ NewEndPointErrorCode ,
51+ NewMulticastGroupErrorCode (NewMulticastGroupUnsupported ),
52+ Reliability ,
53+ ResolveMulticastGroupErrorCode (ResolveMulticastGroupUnsupported ),
54+ SendErrorCode (.. ),
55+ Transport (.. ),
56+ TransportError (.. ),
57+ )
58+ import Network.Transport.QUIC.Internal.Configuration (credentialLoadX509 , mkClientConfig , mkServerConfig )
59+ import Network.Transport.QUIC.Internal.QUICAddr (QUICAddr (.. ), decodeQUICAddr , encodeQUICAddr )
60+ import Network.Transport.QUIC.Internal.TransportState (TransportState , newTransportState , registerEndpoint , traverseTransportState )
61+
62+ {- | Create a new Transport.
63+
64+ Only a single transport should be created per Haskell process
65+ (threads can, and should, create their own endpoints though).
66+ -}
67+ createTransport ::
68+ QUICAddr ->
69+ NonEmpty Credential ->
70+ IO Transport
71+ createTransport quicAddr creds = do
72+ transportState <- newTransportState
73+ pure $
74+ Transport
75+ (newEndpoint transportState quicAddr creds)
76+ (closeQUICTransport transportState)
77+
78+ newEndpoint ::
79+ TransportState ->
80+ QUICAddr ->
81+ NonEmpty Credential ->
82+ IO (Either (TransportError NewEndPointErrorCode ) EndPoint )
83+ newEndpoint transportState quicAddr@ (QUICAddr host port) creds = do
84+ eventQueue <- newTQueueIO
85+
86+ state <- EndpointState <$> newIORef mempty
87+
88+ serverConfig <- mkServerConfig host port creds
89+ serverThread <-
90+ forkFinally
91+ ( QUIC.Server. run
92+ serverConfig
93+ ( withQUICStream $
94+ -- TODO: create a bidirectional stream
95+ -- which can be re-used for sending
96+ \ stream ->
97+ -- We register which threads are actively receiving or sending
98+ -- data such that we can cleanly stop
99+ withThreadRegistered state $ do
100+ -- TODO: how to ensure positivity of ConnectionId? QUIC StreamID should be a 62 bit integer,
101+ -- so there's room to make it a positive 64 bit integer (ConnectionId ~ Word64)
102+ let connId = fromIntegral (QUIC. streamId stream)
103+ receiveLoop connId stream eventQueue
104+ )
105+ )
106+ ( \ case
107+ Left exc -> atomically (writeTQueue eventQueue (ErrorEvent (TransportError EventEndPointFailed (show exc))))
108+ Right _ -> pure ()
109+ )
110+
111+ let endpoint =
112+ EndPoint
113+ (atomically (readTQueue eventQueue))
114+ (encodeQUICAddr quicAddr)
115+ (connectQUIC creds)
116+ (pure . Left $ TransportError NewMulticastGroupUnsupported " Multicast not supported" )
117+ (pure . Left . const (TransportError ResolveMulticastGroupUnsupported " Multicast not supported" ))
118+ (stopAllThreads state >> killThread serverThread >> atomically (writeTQueue eventQueue EndPointClosed ))
119+ void $ transportState `registerEndpoint` endpoint
120+ pure $ Right endpoint
121+ where
122+ receiveLoop ::
123+ ConnectionId ->
124+ QUIC. Stream ->
125+ TQueue Event ->
126+ IO ()
127+ receiveLoop connId stream eventQueue = do
128+ incoming <- QUIC. recvStream stream 1024 -- TODO: variable length?
129+ -- TODO: check some state whether we should stop all connections
130+ if BS. null incoming
131+ then do
132+ atomically (writeTQueue eventQueue (ConnectionClosed connId))
133+ else do
134+ atomically (writeTQueue eventQueue (Received connId [incoming]))
135+ receiveLoop connId stream eventQueue
136+
137+ withQUICStream :: (QUIC. Stream -> IO a ) -> QUIC. Connection -> IO a
138+ withQUICStream f conn =
139+ bracket
140+ (QUIC. waitEstablished conn >> QUIC. acceptStream conn)
141+ (\ stream -> QUIC. closeStream stream >> QUIC.Server. stop conn)
142+ f
143+
144+ connectQUIC ::
145+ NonEmpty Credential ->
146+ EndPointAddress ->
147+ Reliability ->
148+ ConnectHints ->
149+ IO (Either (TransportError ConnectErrorCode ) Connection )
150+ connectQUIC creds endpointAddress _reliability _connectHints =
151+ case decodeQUICAddr endpointAddress of
152+ Left errmsg -> pure $ Left $ TransportError ConnectNotFound (" Could not decode QUIC address: " <> errmsg)
153+ Right (QUICAddr hostname port) ->
154+ try $ do
155+ clientConfig <- mkClientConfig hostname port creds
156+
157+ QUIC.Client. run clientConfig $ \ conn -> do
158+ QUIC. waitEstablished conn
159+ stream <- QUIC. stream conn
160+
161+ pure $
162+ Connection
163+ (sendQUIC stream)
164+ (QUIC. closeStream stream)
165+ where
166+ sendQUIC :: Stream -> [StrictByteString ] -> IO (Either (TransportError SendErrorCode ) () )
167+ sendQUIC stream payloads =
168+ try (QUIC. sendStreamMany stream payloads)
169+ <&> first
170+ ( \ case
171+ QUIC. StreamIsClosed -> TransportError SendClosed " QUIC stream is closed"
172+ QUIC. ConnectionIsClosed reason -> TransportError SendClosed (show reason)
173+ other -> TransportError SendFailed (show other)
174+ )
175+
176+ closeQUICTransport :: TransportState -> IO ()
177+ closeQUICTransport = flip traverseTransportState (\ _ endpoint -> closeEndPoint endpoint)
178+
179+ {- | We keep track of all threads actively listening on QUIC streams
180+ so that we can cleanly stop these threads when closing the endpoint.
181+
182+ See 'withThreadRegistered' for a combinator which automatically keeps
183+ track of these threads
184+ -}
185+ newtype EndpointState = EndpointState
186+ { threads :: IORef (Set ThreadId )
187+ }
188+
189+ withThreadRegistered :: EndpointState -> IO a -> IO a
190+ withThreadRegistered state f =
191+ bracket
192+ registerThread
193+ unregisterThread
194+ (const f)
195+ where
196+ registerThread =
197+ myThreadId
198+ >>= \ tid ->
199+ atomicModifyIORef'_ (threads state) (Set. insert tid)
200+ $> tid
201+
202+ unregisterThread tid =
203+ atomicModifyIORef'_ (threads state) (Set. insert tid)
204+
205+ stopAllThreads :: EndpointState -> IO ()
206+ stopAllThreads (EndpointState tds) = do
207+ readIORef tds >>= traverse_ killThread
208+ writeIORef tds mempty -- so that we can call `closeQUICTransport` even after the endpoint has been closed
0 commit comments