@@ -27,8 +27,8 @@ func withMockServer<Result>(
2727 _ body: ( _ port: Int ) async throws -> Result
2828) async throws -> Result {
2929 let eventLoopGroup = NIOSingletons . posixEventLoopGroup
30- let server = MockLambdaServer ( behavior: behaviour, port: port, keepAlive: keepAlive)
31- let port = try await server. start ( ) . get ( )
30+ let server = MockLambdaServer ( behavior: behaviour, port: port, keepAlive: keepAlive, eventLoopGroup : eventLoopGroup )
31+ let port = try await server. start ( )
3232
3333 let result : Swift . Result < Result , any Error >
3434 do {
@@ -37,13 +37,13 @@ func withMockServer<Result>(
3737 result = . failure( error)
3838 }
3939
40- try ? await server. stop ( ) . get ( )
40+ try ? await server. stop ( )
4141 return try result. get ( )
4242}
4343
44- final class MockLambdaServer {
44+ final class MockLambdaServer < Behavior : LambdaServerBehavior > {
4545 private let logger = Logger ( label: " MockLambdaServer " )
46- private let behavior : LambdaServerBehavior
46+ private let behavior : Behavior
4747 private let host : String
4848 private let port : Int
4949 private let keepAlive : Bool
@@ -52,7 +52,13 @@ final class MockLambdaServer {
5252 private var channel : Channel ?
5353 private var shutdown = false
5454
55- init ( behavior: LambdaServerBehavior , host: String = " 127.0.0.1 " , port: Int = 7000 , keepAlive: Bool = true ) {
55+ init (
56+ behavior: Behavior ,
57+ host: String = " 127.0.0.1 " ,
58+ port: Int = 7000 ,
59+ keepAlive: Bool = true ,
60+ eventLoopGroup: MultiThreadedEventLoopGroup
61+ ) {
5662 self . group = NIOSingletons . posixEventLoopGroup
5763 self . behavior = behavior
5864 self . host = host
@@ -64,39 +70,41 @@ final class MockLambdaServer {
6470 assert ( shutdown)
6571 }
6672
67- func start( ) -> EventLoopFuture < Int > {
68- let bootstrap = ServerBootstrap ( group: group)
73+ fileprivate func start( ) async throws -> Int {
74+ let logger = self . logger
75+ let keepAlive = self . keepAlive
76+ let behavior = self . behavior
77+
78+ let channel = try await ServerBootstrap ( group: group)
6979 . serverChannelOption ( ChannelOptions . socket ( SocketOptionLevel ( SOL_SOCKET) , SO_REUSEADDR) , value: 1 )
7080 . childChannelInitializer { channel in
7181 do {
7282 try channel. pipeline. syncOperations. configureHTTPServerPipeline ( withErrorHandling: true )
7383 try channel. pipeline. syncOperations. addHandler (
74- HTTPHandler ( logger: self . logger, keepAlive: self . keepAlive, behavior: self . behavior)
84+ HTTPHandler ( logger: logger, keepAlive: keepAlive, behavior: behavior)
7585 )
7686 return channel. eventLoop. makeSucceededVoidFuture ( )
7787 } catch {
7888 return channel. eventLoop. makeFailedFuture ( error)
7989 }
8090 }
81- return bootstrap. bind ( host: self . host, port: self . port) . flatMap { channel in
82- self . channel = channel
83- guard let localAddress = channel. localAddress else {
84- return channel. eventLoop. makeFailedFuture ( ServerError . cantBind)
85- }
86- self . logger. info ( " \( self ) started and listening on \( localAddress) " )
87- return channel. eventLoop. makeSucceededFuture ( localAddress. port!)
91+ . bind ( host: self . host, port: self . port)
92+ . get ( )
93+
94+ self . channel = channel
95+ guard let localAddress = channel. localAddress else {
96+ throw ServerError . cantBind
8897 }
98+ self . logger. info ( " \( self ) started and listening on \( localAddress) " )
99+ return localAddress. port!
89100 }
90101
91- func stop( ) -> EventLoopFuture < Void > {
102+ fileprivate func stop( ) async throws {
92103 self . logger. info ( " stopping \( self ) " )
93- guard let channel = self . channel else {
94- return self . group. next ( ) . makeFailedFuture ( ServerError . notReady)
95- }
96- return channel. close ( ) . always { _ in
97- self . shutdown = true
98- self . logger. info ( " \( self ) stopped " )
99- }
104+ let channel = self . channel!
105+ try ? await channel. close ( ) . get ( )
106+ self . shutdown = true
107+ self . logger. info ( " \( self ) stopped " )
100108 }
101109}
102110
@@ -221,32 +229,37 @@ final class HTTPHandler: ChannelInboundHandler {
221229 }
222230 let head = HTTPResponseHead ( version: HTTPVersion ( major: 1 , minor: 1 ) , status: status, headers: headers)
223231
232+ let logger = self . logger
224233 context. write ( wrapOutboundOut ( . head( head) ) ) . whenFailure { error in
225- self . logger. error ( " \( self ) write error \( error) " )
234+ logger. error ( " write error \( error) " )
226235 }
227236
228237 if let b = body {
229238 var buffer = context. channel. allocator. buffer ( capacity: b. utf8. count)
230239 buffer. writeString ( b)
231240 context. write ( wrapOutboundOut ( . body( . byteBuffer( buffer) ) ) ) . whenFailure { error in
232- self . logger. error ( " \( self ) write error \( error) " )
241+ logger. error ( " write error \( error) " )
233242 }
234243 }
235244
245+ let loopBoundContext = NIOLoopBound ( context, eventLoop: context. eventLoop)
246+
247+ let keepAlive = self . keepAlive
236248 context. writeAndFlush ( wrapOutboundOut ( . end( nil ) ) ) . whenComplete { result in
237249 if case . failure( let error) = result {
238- self . logger. error ( " \( self ) write error \( error) " )
250+ logger. error ( " write error \( error) " )
239251 }
240- if !self . keepAlive {
252+ if !keepAlive {
253+ let context = loopBoundContext. value
241254 context. close ( ) . whenFailure { error in
242- self . logger. error ( " \( self ) close error \( error) " )
255+ logger. error ( " close error \( error) " )
243256 }
244257 }
245258 }
246259 }
247260}
248261
249- protocol LambdaServerBehavior {
262+ protocol LambdaServerBehavior: Sendable {
250263 func getInvocation( ) -> GetInvocationResult
251264 func processResponse( requestId: String , response: String ? ) -> Result < Void , ProcessResponseError >
252265 func processError( requestId: String , error: ErrorResponse ) -> Result < Void , ProcessErrorError >
0 commit comments