diff --git a/.swift-version b/.swift-version index 9f55b2c..b502146 100644 --- a/.swift-version +++ b/.swift-version @@ -1 +1 @@ -3.0 +3.0.2 diff --git a/Package.swift b/Package.swift index d64feb4..7d91fb2 100644 --- a/Package.swift +++ b/Package.swift @@ -4,6 +4,7 @@ let package = Package( name: "PostgreSQL", dependencies: [ .Package(url: "https://github.com/Zewo/CLibpq.git", majorVersion: 0, minor: 13), - .Package(url: "https://github.com/Zewo/SQL.git", majorVersion: 0, minor: 14) + .Package(url: "https://github.com/Zewo/SQL.git", majorVersion: 0, minor: 14), + .Package(url: "https://github.com/Zewo/CLibvenice.git", majorVersion: 0, minor: 14), ] ) diff --git a/Sources/PostgreSQL/Connection.swift b/Sources/PostgreSQL/Connection.swift index 13e9d65..1c76b3d 100644 --- a/Sources/PostgreSQL/Connection.swift +++ b/Sources/PostgreSQL/Connection.swift @@ -1,5 +1,7 @@ @_exported import SQL +import Foundation import CLibpq +import CLibvenice import Axis public struct ConnectionError: Error, CustomStringConvertible { @@ -16,7 +18,6 @@ public final class Connection: ConnectionProtocol { public var username: String? public var password: String? public var options: String? - public var tty: String? public init?(uri: URL) { do { @@ -40,14 +41,13 @@ public final class Connection: ConnectionProtocol { self.password = uri.password } - public init(host: String, port: Int = 5432, databaseName: String, username: String? = nil, password: String? = nil, options: String? = nil, tty: String? = nil) { + public init(host: String, port: Int = 5432, databaseName: String, username: String? = nil, password: String? = nil, options: String? = nil) { self.host = host self.port = port self.databaseName = databaseName self.username = username self.password = password self.options = options - self.tty = tty } } @@ -102,6 +102,7 @@ public final class Connection: ConnectionProtocol { public var logger: Logger? private var connection: OpaquePointer? = nil + private var fd: Int32 = -1 public let connectionInfo: ConnectionInfo @@ -118,18 +119,63 @@ public final class Connection: ConnectionProtocol { } public func open() throws { - connection = PQsetdbLogin( - connectionInfo.host, - String(connectionInfo.port), - connectionInfo.options ?? "", - connectionInfo.tty ?? "", - connectionInfo.databaseName, - connectionInfo.username ?? "", - connectionInfo.password ?? "" - ) - - if let error = mostRecentError { - throw error + guard connection == nil else { + throw ConnectionError(description: "Connection already opened.") + } + + var components = URLComponents() + components.scheme = "postgres" + components.host = connectionInfo.host + components.port = connectionInfo.port + components.user = connectionInfo.username + components.password = connectionInfo.password + components.path = "/\(connectionInfo.databaseName)" + if let options = connectionInfo.options { + components.queryItems = [URLQueryItem(name: "options", value: options)] + } + let url = components.url!.absoluteString + + connection = PQconnectStart(url) + + guard connection != nil else { + throw ConnectionError(description: "Could not allocate connection.") + } + + guard PQstatus(connection) != CONNECTION_BAD else { + throw ConnectionError(description: "Could not start connection.") + } + + fd = PQsocket(connection) + guard fd >= 0 else { + throw mostRecentError ?? ConnectionError(description: "Could not get file descriptor.") + } + + loop: while true { + let status = PQconnectPoll(connection) + switch status { + case PGRES_POLLING_OK: + break loop + case PGRES_POLLING_READING: + mill_fdwait(fd, FDW_IN, 15.seconds.fromNow().int64milliseconds, nil) + fdclean(fd) + case PGRES_POLLING_WRITING: + mill_fdwait(fd, FDW_OUT, 15.seconds.fromNow().int64milliseconds, nil) + fdclean(fd) + case PGRES_POLLING_ACTIVE: + break + case PGRES_POLLING_FAILED: + throw mostRecentError ?? ConnectionError(description: "Could not connect to Postgres Server.") + default: + break + } + } + + guard PQsetnonblocking(connection, 1) == 0 else { + throw mostRecentError ?? ConnectionError(description: "Could not set to non-blocking mode.") + } + + guard PQstatus(connection) == CONNECTION_OK else { + throw mostRecentError ?? ConnectionError(description: "Could not connect to Postgres Server.") } } @@ -142,87 +188,132 @@ public final class Connection: ConnectionProtocol { } public func close() { - PQfinish(connection) - connection = nil + if connection != nil { + PQfinish(connection!) + connection = nil + } } public func createSavePointNamed(_ name: String) throws { - try execute("SAVEPOINT \(name)", parameters: nil) + try execute("SAVEPOINT ?", parameters: [.string(name)]) } public func rollbackToSavePointNamed(_ name: String) throws { - try execute("ROLLBACK TO SAVEPOINT \(name)", parameters: nil) + try execute("ROLLBACK TO SAVEPOINT ?", parameters: [.string(name)]) } public func releaseSavePointNamed(_ name: String) throws { - try execute("RELEASE SAVEPOINT \(name)", parameters: nil) + try execute("RELEASE SAVEPOINT ?", parameters: [.string(name)]) } @discardableResult public func execute(_ statement: String, parameters: [Value?]?) throws -> Result { - var statement = statement.sqlStringWithEscapedPlaceholdersUsingPrefix("$") { return String($0 + 1) } defer { logger?.debug(statement) } - guard let parameters = parameters else { - guard let resultPointer = PQexec(connection, statement) else { - throw mostRecentError ?? ConnectionError(description: "Empty result") - } - - return try Result(resultPointer) - } - var parameterData = [UnsafePointer?]() var deallocators = [() -> ()]() defer { deallocators.forEach { $0() } } - for parameter in parameters { + if let parameters = parameters { + for parameter in parameters { + + guard let value = parameter else { + parameterData.append(nil) + continue + } - guard let value = parameter else { - parameterData.append(nil) + let data: AnyCollection + switch value { + case .buffer(let value): + data = AnyCollection(value.map { Int8($0) }) + + case .string(let string): + data = AnyCollection(string.utf8CString) + } + + let pointer = UnsafeMutablePointer.allocate(capacity: Int(data.count)) + deallocators.append { + pointer.deallocate(capacity: Int(data.count)) + } + + for (index, byte) in data.enumerated() { + pointer[index] = byte + } + + parameterData.append(pointer) + } + } + + let sendResult: Int32 = parameterData.withUnsafeBufferPointer { buffer in + if buffer.isEmpty { + return PQsendQuery(self.connection, statement) + } else { + return PQsendQueryParams(self.connection, + statement, + Int32(parameterData.count), + nil, + buffer.baseAddress!, + nil, + nil, + 0) + } + } + + guard sendResult == 1 else { + throw mostRecentError ?? ConnectionError(description: "Could not send query.") + } + + // write query + while true { + mill_fdwait(fd, FDW_OUT, -1, nil) + fdclean(fd) + let status = PQflush(connection) + guard status >= 0 else { + throw mostRecentError ?? ConnectionError(description: "Could not send query.") + } + guard status == 0 else { continue } + break + } - let data: AnyCollection - switch value { - case .buffer(let value): - data = AnyCollection(value.map { Int8($0) }) + // read response + var lastResult: OpaquePointer? = nil + while true { + guard PQconsumeInput(connection) == 1 else { + throw mostRecentError ?? ConnectionError(description: "Could not send query.") + } - case .string(let string): - data = AnyCollection(string.utf8CString) + guard PQisBusy(connection) == 0 else { + mill_fdwait(fd, FDW_IN, -1, nil) + fdclean(fd) + continue } - let pointer = UnsafeMutablePointer.allocate(capacity: Int(data.count)) - deallocators.append { - pointer.deallocate(capacity: Int(data.count)) + guard let result = PQgetResult(connection) else { + break } - for (index, byte) in data.enumerated() { - pointer[index] = byte + if lastResult != nil { + PQclear(lastResult!) + lastResult = nil } - parameterData.append(pointer) - } - - let result: OpaquePointer = try parameterData.withUnsafeBufferPointer { buffer in - guard let result = PQexecParams( - self.connection, - statement, - Int32(parameters.count), - nil, - buffer.isEmpty ? nil : buffer.baseAddress, - nil, - nil, - 0 - ) else { - throw mostRecentError ?? ConnectionError(description: "Empty result") + let status = PQresultStatus(result) + guard status == PGRES_COMMAND_OK || status == PGRES_TUPLES_OK else { + throw mostRecentError ?? ConnectionError(description: "Query failed.") } - return result + + lastResult = result } - return try Result(result) + guard lastResult != nil else { + throw mostRecentError ?? ConnectionError(description: "Query failed.") + } + return try Result(lastResult!) } } diff --git a/Tests/PostgreSQLTests/PostgreSQLTests.swift b/Tests/PostgreSQLTests/PostgreSQLTests.swift index 4226124..000c4e6 100644 --- a/Tests/PostgreSQLTests/PostgreSQLTests.swift +++ b/Tests/PostgreSQLTests/PostgreSQLTests.swift @@ -103,14 +103,15 @@ extension Album: ModelProtocol { // MARK: - Tests public class PostgreSQLTests: XCTestCase { - let connection = try! PostgreSQL.Connection(info: .init(URL(string: "postgres://localhost:5432/swift_test")!)) + var connection: Connection! let logger = Logger(name: "SQL Logger", appenders: [StandardOutputAppender()]) override public func setUp() { super.setUp() - + do { + connection = try! PostgreSQL.Connection(info: .init(URL(string: "postgres://localhost:5432/swift_test")!)) try connection.open() try connection.execute("DROP TABLE IF EXISTS albums") try connection.execute("DROP TABLE IF EXISTS artists")