Skip to content

Commit 0ed15f9

Browse files
author
Sebastien Stormacq
committed
fix parallel invocation for non streaming lambda functions
1 parent 009b5c6 commit 0ed15f9

File tree

1 file changed

+130
-58
lines changed

1 file changed

+130
-58
lines changed

Sources/AWSLambdaRuntime/Lambda+LocalServer.swift

Lines changed: 130 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -401,38 +401,22 @@ internal struct LambdaHTTPServer {
401401
self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body))
402402

403403
// wait for the lambda function to process the request
404-
// when POST /invoke is called multiple times before a response is processed,
405-
// the `for try await ... in` loop will throw an error and we will return a 400 error to the client
404+
// Handle streaming responses by collecting all chunks for this requestId
406405
do {
407-
for try await response in self.responsePool {
406+
var isComplete = false
407+
while !isComplete {
408+
let response = try await self.responsePool.next(for: requestId)
408409
logger[metadataKey: "response_requestId"] = "\(response.requestId ?? "nil")"
409-
logger.trace("Received response to return to client")
410-
if response.requestId == requestId {
411-
logger.trace("/invoke requestId is valid, sending the response")
412-
// send the response to the client
413-
// if the response is final, we can send it and return
414-
// if the response is not final, we can send it and wait for the next response
415-
try await self.sendResponse(response, outbound: outbound, logger: logger)
416-
if response.final == true {
417-
logger.trace("/invoke returning")
418-
return // if the response is final, we can return and close the connection
419-
}
420-
} else {
421-
logger.error(
422-
"Received response for a different requestId",
423-
metadata: ["response requestId": "\(response.requestId ?? "")"]
424-
)
425-
let response = LocalServerResponse(
426-
id: requestId,
427-
status: .badRequest,
428-
body: ByteBuffer(string: "The responseId is not equal to the requestId.")
429-
)
430-
try await self.sendResponse(response, outbound: outbound, logger: logger)
410+
logger.trace("Received response chunk to return to client")
411+
412+
// send the response chunk to the client
413+
try await self.sendResponse(response, outbound: outbound, logger: logger)
414+
415+
if response.final == true {
416+
logger.trace("/invoke complete, returning")
417+
isComplete = true
431418
}
432419
}
433-
// What todo when there is no more responses to process?
434-
// This should not happen as the async iterator blocks until there is a response to process
435-
fatalError("No more responses to process - the async for loop should not return")
436420
} catch is LambdaHTTPServer.Pool<LambdaHTTPServer.LocalServerResponse>.PoolError {
437421
logger.trace("PoolError catched")
438422
// detect concurrent invocations of POST and gently decline the requests while we're processing one.
@@ -587,60 +571,127 @@ internal struct LambdaHTTPServer {
587571

588572
enum State: ~Copyable {
589573
case buffer(Deque<T>)
590-
case continuation(CheckedContinuation<T, any Error>)
574+
case waitingForAny(CheckedContinuation<T, any Error>) // FIFO waiting (for invocations)
575+
case waitingForSpecific([String: CheckedContinuation<T, any Error>]) // RequestId-based waiting (for responses)
591576
}
592577

593578
private let lock = Mutex<State>(.buffer([]))
594579

595580
/// enqueue an element, or give it back immediately to the iterator if it is waiting for an element
596-
public func push(_ invocation: T) {
597-
// if the iterator is waiting for an element on `next()``, give it to it
598-
// otherwise, enqueue the element
599-
let maybeContinuation = self.lock.withLock { state -> CheckedContinuation<T, any Error>? in
581+
public func push(_ item: T) {
582+
let continuationToResume = self.lock.withLock { state -> CheckedContinuation<T, any Error>? in
600583
switch consume state {
601-
case .continuation(let continuation):
602-
state = .buffer([])
603-
return continuation
604-
605584
case .buffer(var buffer):
606-
buffer.append(invocation)
585+
buffer.append(item)
607586
state = .buffer(buffer)
608587
return nil
588+
589+
case .waitingForAny(let continuation):
590+
// Someone is waiting for any item (FIFO)
591+
state = .buffer([])
592+
return continuation
593+
594+
case .waitingForSpecific(var continuations):
595+
// Check if this item matches any waiting continuation
596+
if let response = item as? LocalServerResponse,
597+
let requestId = response.requestId,
598+
let continuation = continuations.removeValue(forKey: requestId)
599+
{
600+
// Found a matching continuation
601+
if continuations.isEmpty {
602+
state = .buffer([])
603+
} else {
604+
state = .waitingForSpecific(continuations)
605+
}
606+
return continuation
607+
} else {
608+
// No matching continuation, add to buffer
609+
var buffer = Deque<T>()
610+
buffer.append(item)
611+
state = .buffer(buffer)
612+
return nil
613+
}
609614
}
610615
}
611616

612617
// Resume continuation outside the lock to prevent potential deadlocks
613-
maybeContinuation?.resume(returning: invocation)
618+
continuationToResume?.resume(returning: item)
614619
}
615620

616-
func next() async throws -> T? {
617-
// exit the async for loop if the task is cancelled
621+
/// Unified next() method that handles both FIFO and requestId-specific waiting
622+
private func _next(for requestId: String?) async throws -> T {
623+
// exit if the task is cancelled
618624
guard !Task.isCancelled else {
619-
return nil
625+
throw CancellationError()
620626
}
621627

622628
return try await withTaskCancellationHandler {
623629
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
624630
let nextAction: Result<T, PoolError>? = self.lock.withLock { state -> Result<T, PoolError>? in
625631
switch consume state {
626632
case .buffer(var buffer):
627-
if let first = buffer.popFirst() {
628-
state = .buffer(buffer)
629-
return .success(first)
633+
if let requestId = requestId {
634+
// Look for oldest (first) item for this requestId in buffer
635+
if let index = buffer.firstIndex(where: { item in
636+
if let response = item as? LocalServerResponse {
637+
return response.requestId == requestId
638+
}
639+
return false
640+
}) {
641+
let item = buffer.remove(at: index)
642+
state = .buffer(buffer)
643+
return .success(item)
644+
} else {
645+
// No matching item, wait for it
646+
var continuations: [String: CheckedContinuation<T, any Error>] = [:]
647+
continuations[requestId] = continuation
648+
state = .waitingForSpecific(continuations)
649+
return nil
650+
}
630651
} else {
631-
state = .continuation(continuation)
632-
return nil
652+
// FIFO mode - take first item
653+
if let first = buffer.popFirst() {
654+
state = .buffer(buffer)
655+
return .success(first)
656+
} else {
657+
state = .waitingForAny(continuation)
658+
return nil
659+
}
633660
}
634661

635-
case .continuation(let previousContinuation):
636-
state = .buffer([])
637-
return .failure(PoolError(cause: .nextCalledTwice(previousContinuation)))
662+
case .waitingForAny(let previousContinuation):
663+
if requestId == nil {
664+
// Another FIFO call while already waiting
665+
state = .buffer([])
666+
return .failure(PoolError(cause: .nextCalledTwice(previousContinuation)))
667+
} else {
668+
// Can't mix FIFO and specific waiting
669+
state = .waitingForAny(previousContinuation)
670+
return .failure(PoolError(cause: .mixedWaitingModes))
671+
}
672+
673+
case .waitingForSpecific(var continuations):
674+
if let requestId = requestId {
675+
if continuations[requestId] != nil {
676+
// Already waiting for this requestId
677+
state = .waitingForSpecific(continuations)
678+
return .failure(PoolError(cause: .duplicateRequestIdWait(requestId)))
679+
} else {
680+
continuations[requestId] = continuation
681+
state = .waitingForSpecific(continuations)
682+
return nil
683+
}
684+
} else {
685+
// Can't mix FIFO and specific waiting
686+
state = .waitingForSpecific(continuations)
687+
return .failure(PoolError(cause: .mixedWaitingModes))
688+
}
638689
}
639690
}
640691

641692
switch nextAction {
642-
case .success(let action):
643-
continuation.resume(returning: action)
693+
case .success(let item):
694+
continuation.resume(returning: item)
644695
case .failure(let error):
645696
if case let .nextCalledTwice(prevContinuation) = error.cause {
646697
prevContinuation.resume(throwing: error)
@@ -653,22 +704,37 @@ internal struct LambdaHTTPServer {
653704
}
654705
} onCancel: {
655706
// Ensure we properly handle cancellation by checking if we have a stored continuation
656-
let continuationToCancel = self.lock.withLock { state -> CheckedContinuation<T, any Error>? in
707+
let continuationsToCancel = self.lock.withLock { state -> [String: CheckedContinuation<T, any Error>] in
657708
switch consume state {
658709
case .buffer(let buffer):
659710
state = .buffer(buffer)
660-
return nil
661-
case .continuation(let continuation):
711+
return [:]
712+
case .waitingForAny(let continuation):
662713
state = .buffer([])
663-
return continuation
714+
return ["": continuation] // Use empty string as key for single continuation
715+
case .waitingForSpecific(let continuations):
716+
state = .buffer([])
717+
return continuations
664718
}
665719
}
666720

667-
// Resume the continuation outside the lock to avoid potential deadlocks
668-
continuationToCancel?.resume(throwing: CancellationError())
721+
// Resume all continuations outside the lock to avoid potential deadlocks
722+
for continuation in continuationsToCancel.values {
723+
continuation.resume(throwing: CancellationError())
724+
}
669725
}
670726
}
671727

728+
/// Simple FIFO next() method - used by AsyncIteratorProtocol
729+
func next() async throws -> T? {
730+
try await _next(for: nil)
731+
}
732+
733+
/// RequestId-specific next() method for LocalServerResponse - NOT part of AsyncIteratorProtocol
734+
func next(for requestId: String) async throws -> T {
735+
try await _next(for: requestId)
736+
}
737+
672738
func makeAsyncIterator() -> Pool {
673739
self
674740
}
@@ -679,11 +745,17 @@ internal struct LambdaHTTPServer {
679745
switch self.cause {
680746
case .nextCalledTwice:
681747
return "Concurrent invocations to next(). This is not allowed."
748+
case .duplicateRequestIdWait(let requestId):
749+
return "Already waiting for requestId: \(requestId)"
750+
case .mixedWaitingModes:
751+
return "Cannot mix FIFO waiting (next()) with specific waiting (next(for:))"
682752
}
683753
}
684754

685755
enum Cause {
686756
case nextCalledTwice(CheckedContinuation<T, any Error>)
757+
case duplicateRequestIdWait(String)
758+
case mixedWaitingModes
687759
}
688760
}
689761
}

0 commit comments

Comments
 (0)