diff --git a/Sources/NIOCore/IO.swift b/Sources/NIOCore/IO.swift index 59d7cced5e..61e25f677d 100644 --- a/Sources/NIOCore/IO.swift +++ b/Sources/NIOCore/IO.swift @@ -62,7 +62,7 @@ public struct IOError: Swift.Error { .reason(self.failureDescription) } - private enum Error { + package enum Error { #if os(Windows) case windows(DWORD) case winsock(CInt) @@ -70,7 +70,7 @@ public struct IOError: Swift.Error { case errno(CInt) } - private let error: Error + package let error: Error /// The `errno` that was set for the operation. public var errnoCode: CInt { diff --git a/Sources/NIOCore/SystemCallHelpers.swift b/Sources/NIOCore/SystemCallHelpers.swift index ee44c9b9ca..2fdf24ca8f 100644 --- a/Sources/NIOCore/SystemCallHelpers.swift +++ b/Sources/NIOCore/SystemCallHelpers.swift @@ -114,6 +114,10 @@ internal func syscall( case (EWOULDBLOCK, true): return .wouldBlock(0) #endif + #if os(Windows) + case (WSAEWOULDBLOCK, true): + return .wouldBlock(0) + #endif default: preconditionIsNotUnacceptableErrno(err: err, where: function) throw IOError(errnoCode: err, reason: function) diff --git a/Sources/NIOPosix/BSDSocketAPIWindows.swift b/Sources/NIOPosix/BSDSocketAPIWindows.swift index b51aaca02b..9915dc6757 100644 --- a/Sources/NIOPosix/BSDSocketAPIWindows.swift +++ b/Sources/NIOPosix/BSDSocketAPIWindows.swift @@ -195,7 +195,11 @@ extension NIOBSDSocket { ) throws -> NIOBSDSocket.Handle? { let socket: NIOBSDSocket.Handle = WinSDK.accept(s, addr, addrlen) if socket == WinSDK.INVALID_SOCKET { - throw IOError(winsock: WSAGetLastError(), reason: "accept") + let lastError = WSAGetLastError() + if lastError == WSAEWOULDBLOCK { + return nil + } + throw IOError(winsock: lastError, reason: "accept") } return socket } diff --git a/Sources/NIOPosix/BaseSocketChannel.swift b/Sources/NIOPosix/BaseSocketChannel.swift index d928416451..c772c0fd13 100644 --- a/Sources/NIOPosix/BaseSocketChannel.swift +++ b/Sources/NIOPosix/BaseSocketChannel.swift @@ -15,6 +15,9 @@ import Atomics import NIOConcurrencyHelpers import NIOCore +#if os(Windows) +import WinSDK +#endif private struct SocketChannelLifecycleManager { // MARK: Types @@ -1215,7 +1218,7 @@ class BaseSocketChannel: SelectableChannel, Chan /// - err: The `Error` which was thrown by `readFromSocket`. /// - Returns: `true` if the `Channel` should be closed, `false` otherwise. func shouldCloseOnReadError(_ err: Error) -> Bool { - true + return true } /// Handles an error reported by the selector. diff --git a/Sources/NIOPosix/SelectorGeneric.swift b/Sources/NIOPosix/SelectorGeneric.swift index b538e3f4f5..ee7eb761b4 100644 --- a/Sources/NIOPosix/SelectorGeneric.swift +++ b/Sources/NIOPosix/SelectorGeneric.swift @@ -216,7 +216,9 @@ internal class Selector { @usableFromInline typealias EventType = WinSDK.pollfd @usableFromInline - var pollFDs = [WinSDK.pollfd]() + var pollFDs = [pollfd]() + @usableFromInline + var deregisteredFDs = [Bool]() #else #error("Unsupported platform, no suitable selector backend (we need kqueue or epoll support)") #endif diff --git a/Sources/NIOPosix/SelectorWSAPoll.swift b/Sources/NIOPosix/SelectorWSAPoll.swift index c02cf1d3be..8db3c68d10 100644 --- a/Sources/NIOPosix/SelectorWSAPoll.swift +++ b/Sources/NIOPosix/SelectorWSAPoll.swift @@ -67,6 +67,8 @@ extension Selector: _SelectorBackendProtocol { func initialiseState0() throws { self.pollFDs.reserveCapacity(16) + self.deregisteredFDs.reserveCapacity(16) + self.lifecycleState = .open } func deinitAssertions0() { @@ -102,7 +104,7 @@ extension Selector: _SelectorBackendProtocol { } } else { let result = self.pollFDs.withUnsafeMutableBufferPointer { ptr in - WSAPoll(ptr.baseAddress!, UInt32(ptr.count), time) + WSAPoll(ptr.baseAddress!, UInt32(ptr.count), 1) } if result > 0 { @@ -131,6 +133,19 @@ extension Selector: _SelectorBackendProtocol { try body((SelectorEvent(io: selectorEvent, registration: registration))) } + + // now clean up any deregistered fds + // In reverse order so we don't have to copy elements out of the array + // If we do in in normal order, we'll have to shift all elements after the removed one + for i in self.deregisteredFDs.indices.reversed() { + if self.deregisteredFDs[i] { + // remove this one + let fd = self.pollFDs[i].fd + self.pollFDs.remove(at: i) + self.deregisteredFDs.remove(at: i) + self.registrations.removeValue(forKey: Int(fd)) + } + } } else if result == 0 { // nothing has happened } else if result == WinSDK.SOCKET_ERROR { @@ -149,6 +164,7 @@ extension Selector: _SelectorBackendProtocol { // that will allow O(1) access here. let poll = pollfd(fd: UInt64(fileDescriptor), events: interested.wsaPollEvent, revents: 0) self.pollFDs.append(poll) + self.deregisteredFDs.append(false) } func reregister0( @@ -158,7 +174,9 @@ extension Selector: _SelectorBackendProtocol { newInterested: SelectorEventSet, registrationID: SelectorRegistrationID ) throws { - fatalError("TODO: Unimplemented") + if let index = self.pollFDs.firstIndex(where: { $0.fd == UInt64(fileDescriptor) }) { + self.pollFDs[index].events = newInterested.wsaPollEvent + } } func deregister0( @@ -167,13 +185,15 @@ extension Selector: _SelectorBackendProtocol { oldInterested: SelectorEventSet, registrationID: SelectorRegistrationID ) throws { - fatalError("TODO: Unimplemented") + if let index = self.pollFDs.firstIndex(where: { $0.fd == UInt64(fileDescriptor) }) { + self.deregisteredFDs[index] = true + } } func wakeup0() throws { // will be called from a different thread - let result = try self.myThread.withHandleUnderLock { handle in - QueueUserAPC(wakeupTarget, handle, 0) + let result = try self.myThread.withHandleUnderLock { threadHandle in + return QueueUserAPC(wakeupTarget, threadHandle.handle, 0) } if result == 0 { let errorCode = GetLastError() @@ -185,6 +205,7 @@ extension Selector: _SelectorBackendProtocol { func close0() throws { self.pollFDs.removeAll() + self.deregisteredFDs.removeAll() } } diff --git a/Sources/NIOPosix/SocketChannel.swift b/Sources/NIOPosix/SocketChannel.swift index 504c2ad834..10346bfe38 100644 --- a/Sources/NIOPosix/SocketChannel.swift +++ b/Sources/NIOPosix/SocketChannel.swift @@ -383,12 +383,12 @@ final class ServerSocketChannel: BaseSocketChannel, @unchecked Sen } guard let err = err as? IOError else { return true } - switch err.errnoCode { - case ECONNABORTED, - EMFILE, - ENFILE, - ENOBUFS, - ENOMEM: + switch err.error { + case .errno(ECONNABORTED), + .errno(EMFILE), + .errno(ENFILE), + .errno(ENOBUFS), + .errno(ENOMEM): // These are errors we may be able to recover from. The user may just want to stop accepting connections for example // or provide some other means of back-pressure. This could be achieved by a custom ChannelDuplexHandler. return false @@ -856,14 +856,22 @@ final class DatagramChannel: BaseSocketChannel, @unchecked Sendable { private func shouldCloseOnErrnoCode(_ errnoCode: CInt) -> Bool { switch errnoCode { + case ECONNREFUSED, ENOMEM: + // These are errors we may be able to recover from. + return false + default: + return true + } + } + + private func shouldCloseOnError(_ error: IOError.Error) -> Bool { + switch error { // ECONNREFUSED can happen on linux if the previous sendto(...) failed. // See also: // - https://bugzilla.redhat.com/show_bug.cgi?id=1375 // - https://lists.gt.net/linux/kernel/39575 - case ECONNREFUSED, - ENOMEM: - // These are errors we may be able to recover from. - return false + case .errno(let code): + return self.shouldCloseOnErrnoCode(code) default: return true } @@ -871,7 +879,7 @@ final class DatagramChannel: BaseSocketChannel, @unchecked Sendable { override func shouldCloseOnReadError(_ err: Error) -> Bool { guard let err = err as? IOError else { return true } - return self.shouldCloseOnErrnoCode(err.errnoCode) + return self.shouldCloseOnError(err.error) } override func error() -> ErrorResult { diff --git a/Sources/NIOPosix/Thread.swift b/Sources/NIOPosix/Thread.swift index b730aba1b2..77640fca9c 100644 --- a/Sources/NIOPosix/Thread.swift +++ b/Sources/NIOPosix/Thread.swift @@ -92,18 +92,14 @@ final class NIOThread: Sendable { static var currentThreadName: String? { #if os(Windows) - ThreadOpsSystem.threadName(.init(GetCurrentThread())) + ThreadOpsSystem.threadName(.init(handle: GetCurrentThread())) #else ThreadOpsSystem.threadName(.init(handle: pthread_self())) #endif } static var currentThreadID: UInt { - #if os(Windows) - UInt(bitPattern: .init(bitPattern: ThreadOpsSystem.currentThread)) - #else UInt(bitPattern: .init(bitPattern: ThreadOpsSystem.currentThread.handle)) - #endif } @discardableResult diff --git a/Sources/NIOPosix/ThreadWindows.swift b/Sources/NIOPosix/ThreadWindows.swift index 4da8f48b01..fcb2d4bf4b 100644 --- a/Sources/NIOPosix/ThreadWindows.swift +++ b/Sources/NIOPosix/ThreadWindows.swift @@ -18,13 +18,15 @@ import WinSDK typealias ThreadOpsSystem = ThreadOpsWindows enum ThreadOpsWindows: ThreadOps { - typealias ThreadHandle = HANDLE + struct ThreadHandle: @unchecked Sendable { + let handle: HANDLE + } typealias ThreadSpecificKey = DWORD typealias ThreadSpecificKeyDestructor = @convention(c) (UnsafeMutableRawPointer?) -> Void static func threadName(_ thread: ThreadOpsSystem.ThreadHandle) -> String? { var pszBuffer: PWSTR? - GetThreadDescription(thread, &pszBuffer) + GetThreadDescription(thread.handle, &pszBuffer) guard let buffer = pszBuffer else { return nil } let string: String = String(decodingCString: buffer, as: UTF16.self) LocalFree(buffer) @@ -41,11 +43,27 @@ enum ThreadOpsWindows: ThreadOps { let routine: @convention(c) (UnsafeMutableRawPointer?) -> CUnsignedInt = { let boxed = Unmanaged.fromOpaque($0!).takeRetainedValue() let (body, name) = (boxed.value.body, boxed.value.name) - let hThread: ThreadOpsSystem.ThreadHandle = GetCurrentThread() + + // Get a real thread handle instead of pseudo-handle + var realHandle: HANDLE? = nil + let success = DuplicateHandle( + GetCurrentProcess(), // Source process + GetCurrentThread(), // Source handle (pseudo-handle) + GetCurrentProcess(), // Target process + &realHandle, // Target handle (real handle) + 0, // Desired access (0 = same as source) + false, // Inherit handle + DWORD(DUPLICATE_SAME_ACCESS) // Options + ) + + guard success, let realHandle else { + fatalError("DuplicateHandle failed: \(GetLastError())") + } + let hThread = ThreadOpsSystem.ThreadHandle(handle: realHandle) if let name = name { _ = name.withCString(encodedAs: UTF16.self) { - SetThreadDescription(hThread, $0) + SetThreadDescription(hThread.handle, $0) } } @@ -58,15 +76,28 @@ enum ThreadOpsWindows: ThreadOps { } static func isCurrentThread(_ thread: ThreadOpsSystem.ThreadHandle) -> Bool { - CompareObjectHandles(thread, GetCurrentThread()) + CompareObjectHandles(thread.handle, GetCurrentThread()) } static var currentThread: ThreadOpsSystem.ThreadHandle { - GetCurrentThread() + var realHandle: HANDLE? = nil + let success = DuplicateHandle( + GetCurrentProcess(), + GetCurrentThread(), + GetCurrentProcess(), + &realHandle, + 0, + false, + DWORD(DUPLICATE_SAME_ACCESS) + ) + guard success, let realHandle else { + fatalError("DuplicateHandle failed: \(GetLastError())") + } + return ThreadHandle(handle: realHandle) } static func joinThread(_ thread: ThreadOpsSystem.ThreadHandle) { - let dwResult: DWORD = WaitForSingleObject(thread, INFINITE) + let dwResult: DWORD = WaitForSingleObject(thread.handle, INFINITE) assert(dwResult == WAIT_OBJECT_0, "WaitForSingleObject: \(GetLastError())") } @@ -88,7 +119,7 @@ enum ThreadOpsWindows: ThreadOps { } static func compareThreads(_ lhs: ThreadOpsSystem.ThreadHandle, _ rhs: ThreadOpsSystem.ThreadHandle) -> Bool { - CompareObjectHandles(lhs, rhs) + CompareObjectHandles(lhs.handle, rhs.handle) } }