Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -330,14 +330,13 @@ extension ValkeyChannelHandler {
}

@usableFromInline
enum GracefulShutdownAction {
case waitForPendingCommands(Context)
enum TriggerGracefulShutdownAction {
case closeConnection(Context)
case doNothing
}
/// Want to gracefully shutdown the handler
@usableFromInline
mutating func gracefulShutdown() -> GracefulShutdownAction {
mutating func triggerGracefulShutdown() -> TriggerGracefulShutdownAction {
switch consume self.state {
case .initialized:
self = .closed(nil)
Expand All @@ -346,11 +345,11 @@ extension ValkeyChannelHandler {
var pendingCommands = state.pendingCommands
pendingCommands.prepend(state.pendingHelloCommand)
self = .closing(.init(context: state.context, pendingCommands: pendingCommands))
return .waitForPendingCommands(state.context)
return .doNothing
case .active(let state):
if state.pendingCommands.count > 0 {
self = .closing(.init(context: state.context, pendingCommands: state.pendingCommands))
return .waitForPendingCommands(state.context)
return .doNothing
} else {
self = .closed(nil)
return .closeConnection(state.context)
Expand Down
9 changes: 9 additions & 0 deletions Sources/Valkey/Connection/ValkeyChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,15 @@ final class ValkeyChannelHandler: ChannelInboundHandler {
break
}
}

func triggerGracefulShutdown() {
switch self.stateMachine.triggerGracefulShutdown() {
case .closeConnection(let context):
context.close(mode: .all, promise: nil)
case .doNothing:
break
}
}
}

@available(valkeySwift 1.0, *)
Expand Down
8 changes: 8 additions & 0 deletions Sources/Valkey/Connection/ValkeyConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,14 @@ public final actor ValkeyConnection: ValkeyClientProtocol, Sendable {
try await self.channelHandler.waitOnActive().get()
}

/// Trigger graceful shutdown of connection
///
/// The connection will wait until all pending commands have been processed before
/// closing the connection.
func triggerGracefulShutdown() {
self.channelHandler.triggerGracefulShutdown()
}

/// Send RESP command to Valkey connection
/// - Parameter command: ValkeyCommand structure
/// - Returns: The command response as defined in the ValkeyCommand
Expand Down
20 changes: 10 additions & 10 deletions Tests/ValkeyTests/ValkeyChannelHandlerStateMachineTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ struct ValkeyChannelHandlerStateMachineTests {
var stateMachine = ValkeyChannelHandler.StateMachine<String>()
stateMachine.setConnected(context: "testGracefulShutdown")
stateMachine.receiveHelloResponse()
switch stateMachine.gracefulShutdown() {
switch stateMachine.triggerGracefulShutdown() {
case .closeConnection(let context):
#expect(context == "testGracefulShutdown")
default:
Expand All @@ -168,10 +168,10 @@ struct ValkeyChannelHandlerStateMachineTests {
case .throwError:
Issue.record("Invalid sendCommand action")
}
switch stateMachine.gracefulShutdown() {
case .waitForPendingCommands(let context):
#expect(context == "testGracefulShutdown")
default:
switch stateMachine.triggerGracefulShutdown() {
case .doNothing:
break
case .closeConnection:
Issue.record("Invalid waitForPendingCommands action")
}
expect(
Expand Down Expand Up @@ -207,10 +207,10 @@ struct ValkeyChannelHandlerStateMachineTests {
case .throwError:
Issue.record("Invalid sendCommand action")
}
switch stateMachine.gracefulShutdown() {
case .waitForPendingCommands(let context):
#expect(context == "testClosedClosingState")
default:
switch stateMachine.triggerGracefulShutdown() {
case .doNothing:
break
case .closeConnection:
Issue.record("Invalid waitForPendingCommands action")
}
expect(
Expand Down Expand Up @@ -333,7 +333,7 @@ struct ValkeyChannelHandlerStateMachineTests {
case .throwError:
Issue.record("Invalid sendCommand action")
}
_ = stateMachine.gracefulShutdown()
_ = stateMachine.triggerGracefulShutdown()
switch stateMachine.cancel(requestID: 23) {
case .failPendingCommandsAndClose(let context, let cancel, let closeConnectionDueToCancel):
#expect(context == "testCancelGracefulShutdown")
Expand Down
22 changes: 22 additions & 0 deletions Tests/ValkeyTests/ValkeyConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,28 @@ struct ConnectionTests {
try await channel.close()
}

@Test
@available(valkeySwift 1.0, *)
func testTriggerGracefulShutdown() async throws {
let channel = NIOAsyncTestingChannel()
let logger = Logger(label: "test")
let connection = try await ValkeyConnection.setupChannelAndConnect(channel, configuration: .init(), logger: logger)
try await channel.processHello()

async let fooResult = connection.get("foo").map { String(buffer: $0) }

let outbound = try await channel.waitForOutboundWrite(as: ByteBuffer.self)
#expect(outbound == RESPToken(.command(["GET", "foo"])).base)

await connection.triggerGracefulShutdown()
#expect(channel.isActive)

try await channel.writeInbound(RESPToken(.bulkString("Bar")).base)
#expect(try await fooResult == "Bar")

try await channel.closeFuture.get()
}

#if DistributedTracingSupport
@Suite
struct DistributedTracingTests {
Expand Down