Skip to content

Commit d73dce6

Browse files
author
Sebastien Stormacq
committed
remove the fatal error to make testinge easier
1 parent b341bb4 commit d73dce6

File tree

2 files changed

+94
-51
lines changed

2 files changed

+94
-51
lines changed

Sources/AWSLambdaRuntime/Lambda+LocalServer.swift

Lines changed: 66 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,42 @@ internal struct LambdaHTTPServer {
391391

392392
logger.trace("/invoke received invocation, pushing it to the pool and wait for a lambda response")
393393
// detect concurrent invocations of POST and gently decline the requests while we're processing one.
394-
if !self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body)) {
394+
self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body))
395+
396+
// wait for the lambda function to process the request
397+
// when POST /invoke is called multiple times before a response is process, the
398+
// `for try await ... in` loop will throw an error and we will return a 400 error to the client
399+
do {
400+
for try await response in self.responsePool {
401+
logger[metadataKey: "response requestId"] = "\(response.requestId ?? "nil")"
402+
logger.trace("Received response to return to client")
403+
if response.requestId == requestId {
404+
logger.trace("/invoke requestId is valid, sending the response")
405+
// send the response to the client
406+
// if the response is final, we can send it and return
407+
// if the response is not final, we can send it and wait for the next response
408+
try await self.sendResponse(response, outbound: outbound, logger: logger)
409+
if response.final == true {
410+
logger.trace("/invoke returning")
411+
return // if the response is final, we can return and close the connection
412+
}
413+
} else {
414+
logger.error(
415+
"Received response for a different requestId",
416+
metadata: ["response requestId": "\(response.requestId ?? "")"]
417+
)
418+
let response = LocalServerResponse(
419+
id: requestId,
420+
status: .badRequest,
421+
body: ByteBuffer(string: "The responseId is not equal to the requestId.")
422+
)
423+
try await self.sendResponse(response, outbound: outbound, logger: logger)
424+
}
425+
}
426+
// What todo when there is no more responses to process?
427+
// This should not happen as the async iterator blocks until there is a response to process
428+
fatalError("No more responses to process - the async for loop should not return")
429+
} catch is LambdaHTTPServer.Pool<LambdaHTTPServer.LocalServerResponse>.PoolError {
395430
let response = LocalServerResponse(
396431
id: requestId,
397432
status: .badRequest,
@@ -401,39 +436,7 @@ internal struct LambdaHTTPServer {
401436
)
402437
)
403438
try await self.sendResponse(response, outbound: outbound, logger: logger)
404-
return
405-
}
406-
407-
// wait for the lambda function to process the request
408-
for try await response in self.responsePool {
409-
logger[metadataKey: "response requestId"] = "\(response.requestId ?? "nil")"
410-
logger.trace("Received response to return to client")
411-
if response.requestId == requestId {
412-
logger.trace("/invoke requestId is valid, sending the response")
413-
// send the response to the client
414-
// if the response is final, we can send it and return
415-
// if the response is not final, we can send it and wait for the next response
416-
try await self.sendResponse(response, outbound: outbound, logger: logger)
417-
if response.final == true {
418-
logger.trace("/invoke returning")
419-
return // if the response is final, we can return and close the connection
420-
}
421-
} else {
422-
logger.error(
423-
"Received response for a different requestId",
424-
metadata: ["response requestId": "\(response.requestId ?? "")"]
425-
)
426-
let response = LocalServerResponse(
427-
id: requestId,
428-
status: .badRequest,
429-
body: ByteBuffer(string: "The responseId is not equal to the requestId.")
430-
)
431-
try await self.sendResponse(response, outbound: outbound, logger: logger)
432-
}
433439
}
434-
// What todo when there is no more responses to process?
435-
// This should not happen as the async iterator blocks until there is a response to process
436-
fatalError("No more responses to process - the async for loop should not return")
437440

438441
// client uses incorrect HTTP method
439442
case (_, let url) where url.hasSuffix(self.invocationEndpoint):
@@ -579,9 +582,7 @@ internal struct LambdaHTTPServer {
579582
private let lock = Mutex<State>(.buffer([]))
580583

581584
/// enqueue an element, or give it back immediately to the iterator if it is waiting for an element
582-
/// Returns true when we receive a element and the pool was in "waiting for continuation" state, false otherwise
583-
@discardableResult
584-
public func push(_ invocation: T) -> Bool {
585+
public func push(_ invocation: T) {
585586

586587
// if the iterator is waiting for an element on `next()``, give it to it
587588
// otherwise, enqueue the element
@@ -598,12 +599,7 @@ internal struct LambdaHTTPServer {
598599
}
599600
}
600601

601-
if let maybeContinuation {
602-
maybeContinuation.resume(returning: invocation)
603-
return true
604-
} else {
605-
return false
606-
}
602+
maybeContinuation?.resume(returning: invocation)
607603
}
608604

609605
func next() async throws -> T? {
@@ -614,34 +610,39 @@ internal struct LambdaHTTPServer {
614610

615611
return try await withTaskCancellationHandler {
616612
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
617-
let nextAction = self.lock.withLock { state -> T? in
613+
let (nextAction, nextError) = self.lock.withLock { state -> (T?, PoolError?) in
618614
switch consume state {
619615
case .buffer(var buffer):
620616
if let first = buffer.popFirst() {
621617
state = .buffer(buffer)
622-
return first
618+
return (first, nil)
623619
} else {
624620
state = .continuation(continuation)
625-
return nil
621+
return (nil, nil)
626622
}
627623

628-
case .continuation(_):
629-
fatalError("\(self.poolName) : Concurrent invocations to next(). This is not allowed.")
624+
case .continuation(let previousContinuation):
625+
state = .buffer([])
626+
return (nil, PoolError(cause: .nextCalledTwice([previousContinuation, continuation])))
630627
}
631628
}
632629

633-
guard let nextAction else { return }
634-
635-
continuation.resume(returning: nextAction)
630+
if let nextError,
631+
case let .nextCalledTwice(continuations) = nextError.cause
632+
{
633+
for continuation in continuations { continuation?.resume(throwing: nextError) }
634+
} else if let nextAction {
635+
continuation.resume(returning: nextAction)
636+
}
636637
}
637638
} onCancel: {
638639
self.lock.withLock { state in
639640
switch consume state {
640641
case .buffer(let buffer):
641642
state = .buffer(buffer)
642643
case .continuation(let continuation):
643-
continuation?.resume(throwing: CancellationError())
644644
state = .buffer([])
645+
continuation?.resume(throwing: CancellationError())
645646
}
646647
}
647648
}
@@ -650,6 +651,20 @@ internal struct LambdaHTTPServer {
650651
func makeAsyncIterator() -> Pool {
651652
self
652653
}
654+
655+
struct PoolError: Error {
656+
let cause: Cause
657+
var message: String {
658+
switch self.cause {
659+
case .nextCalledTwice:
660+
return "Concurrent invocations to next(). This is not allowed."
661+
}
662+
}
663+
664+
enum Cause {
665+
case nextCalledTwice([CheckedContinuation<T, any Error>?])
666+
}
667+
}
653668
}
654669

655670
private struct LocalServerResponse: Sendable {

Tests/AWSLambdaRuntimeTests/PoolTests.swift

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,32 @@ struct PoolTests {
158158
#expect(receivedValues.count == producerCount * messagesPerProducer)
159159
#expect(Set(receivedValues).count == producerCount * messagesPerProducer)
160160
}
161+
162+
@Test
163+
@available(LambdaSwift 2.0, *)
164+
func testConcurrentNext() async throws {
165+
let pool = LambdaHTTPServer.Pool<String>()
166+
167+
// Create two tasks that will both wait for elements to be available
168+
await #expect(throws: LambdaHTTPServer.Pool<Swift.String>.PoolError.self) {
169+
try await withThrowingTaskGroup(of: Void.self) { group in
170+
171+
// one of the two task will throw a PoolError
172+
173+
group.addTask {
174+
for try await _ in pool {
175+
}
176+
Issue.record("Loop 1 should not complete")
177+
}
178+
179+
group.addTask {
180+
for try await _ in pool {
181+
}
182+
Issue.record("Loop 2 should not complete")
183+
}
184+
try await group.waitForAll()
185+
}
186+
}
187+
}
188+
161189
}

0 commit comments

Comments
 (0)